Unverified commit a208e886 authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
@@ -5,33 +5,41 @@ from __future__ import absolute_import
import ctypes
import traceback
from numbers import Number, Integral
from numbers import Integral, Number
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLDataType, DGLByteArray, DGLContext
from ..base import _LIB, c_str, check_call, string_types
from ..object_generic import ObjectGeneric, convert_to_object
from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType
from . import ndarray as _nd
from . import object as _object
from .ndarray import NDArrayBase, _make_array
from .types import DGLValue, TypeCode
from .types import DGLPackedCFunc, DGLCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .object import ObjectBase
from . import object as _object
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLCFuncFinalizer,
DGLPackedCFunc,
DGLValue,
TypeCode,
_wrap_arg_func,
)
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
DGLRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
DGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ))
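The pairing above is the whole ownership story: Py_IncRef keeps the Python callback alive while native code holds a raw pointer to it, and _ctypes_free_resource releases that reference once the C side drops the handle. A minimal self-contained sketch of the same pattern (illustrative only, not part of this commit; it relies on CPython's id() being the object address):

import ctypes

def _free(rhandle):
    # Recover the py_object from the raw address and drop our reference.
    ctypes.pythonapi.Py_DecRef(ctypes.cast(rhandle, ctypes.py_object))

FINALIZER = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(_free)

callback = lambda x: x + 1
ctypes.pythonapi.Py_IncRef(ctypes.py_object(callback))  # native side co-owns it
handle = ctypes.c_void_p(id(callback))                  # raw PyObject* address
FINALIZER(handle)                                       # native side done: ref released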
def convert_to_dgl_func(pyfunc):
"""Convert a python function to DGL function
@@ -46,10 +54,15 @@ def convert_to_dgl_func(pyfunc):
The converted dgl function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args))
"""ctypes function"""
num_args = (
num_args.value if isinstance(num_args, ctypes.c_int) else num_args
)
pyargs = (
C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)
)
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
@@ -60,12 +73,16 @@ def convert_to_dgl_func(pyfunc):
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
raise ValueError(
"PackedFunction can only support one return value"
)
temp_args = []
values, tcodes, _ = _make_dgl_args((rv,), temp_args)
if not isinstance(ret, DGLRetValueHandle):
ret = DGLRetValueHandle(ret)
check_call(_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)))
check_call(
_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))
)
_ = temp_args
_ = rv
return 0
@@ -76,8 +93,11 @@ def convert_to_dgl_func(pyfunc):
# DGL_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.DGLFuncCreateFromCFunc(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)))
check_call(
_LIB.DGLFuncCreateFromCFunc(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)
)
)
return _CLASS_FUNCTION(handle, False)
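Downstream, this function is what lets plain Python callables cross the FFI boundary as packed functions. A hypothetical usage sketch (the callback and registry name are made up; convert_to_dgl_func and register_func are the entry points exposed by dgl._ffi.function):

from dgl._ffi.function import convert_to_dgl_func, register_func

def scale(x):
    return x * 2

packed = convert_to_dgl_func(scale)   # now invokable from the C runtime
register_func("demo.scale", scale)    # or register it globally so native code
                                      # can look it up by name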
@@ -104,8 +124,11 @@ def _make_dgl_args(args, temp_args):
temp_args.append(arg)
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_CONTAINER
if not arg.is_view else TypeCode.ARRAY_HANDLE)
type_codes[i] = (
TypeCode.NDARRAY_CONTAINER
if not arg.is_view
else TypeCode.ARRAY_HANDLE
)
elif isinstance(arg, _nd._DGL_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._dgl_handle)
type_codes[i] = arg.__class__._dgl_tcode
@@ -125,7 +148,8 @@ def _make_dgl_args(args, temp_args):
arr = DGLByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
ctypes.POINTER(ctypes.c_byte),
)
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
@@ -134,7 +158,7 @@ def _make_dgl_args(args, temp_args):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
# NOTE(minjie): module is not used in DGL
#elif isinstance(arg, _CLASS_MODULE):
# elif isinstance(arg, _CLASS_MODULE):
# values[i].v_handle = arg.handle
# type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase):
@@ -155,6 +179,7 @@ def _make_dgl_args(args, temp_args):
class FunctionBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global):
@@ -185,9 +210,16 @@ class FunctionBase(object):
values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
check_call(
_LIB.DGLFuncCall(
self.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
@@ -199,9 +231,16 @@ def __init_handle_by_constructor__(fconstructor, args):
values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
check_call(
_LIB.DGLFuncCall(
fconstructor.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.OBJECT_HANDLE
@@ -216,6 +255,7 @@ def _return_module(x):
handle = ModuleHandle(handle)
return _CLASS_MODULE(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
@@ -228,22 +268,32 @@ def _handle_return_func(x):
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
_handle_return_func, TypeCode.FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
_return_module, TypeCode.MODULE_HANDLE
)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(
x.v_handle, True
)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
_CLASS_MODULE = None
_CLASS_FUNCTION = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
@@ -3,18 +3,23 @@
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import DGLArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
from ..base import _LIB, c_str, check_call
from ..runtime_ctypes import DGLArrayHandle
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
_return_handle,
_wrap_arg_func,
)
DGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
_c_str_dltensor = c_str("dltensor")
_c_str_used_dltensor = c_str("used_dltensor")
# used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'):
if hasattr(ctypes, "pythonapi"):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
@@ -31,9 +36,13 @@ def _from_dlpack(dltensor):
handle = DGLArrayHandle()
check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, DGLPyCapsuleDestructor(0))
ctypes.pythonapi.PyCapsule_SetDestructor(
dltensor, DGLPyCapsuleDestructor(0)
)
return _make_array(handle, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
raise ValueError(
"Expect a dltensor field, PyCapsule can only be consumed once"
)
def _dlpack_deleter(pycapsule):
@@ -45,13 +54,17 @@ def _dlpack_deleter(pycapsule):
# work out always.
ptr = ctypes.cast(ptr, ctypes.c_void_p)
_LIB.DGLDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, DGLPyCapsuleDestructor(0))
ctypes.pythonapi.PyCapsule_SetDestructor(
pycapsule, DGLPyCapsuleDestructor(0)
)
_c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle, is_view=False):
@@ -89,27 +102,36 @@ class NDArrayBase(object):
dlpack : DLPack tensor view of the array data
"""
ptr = ctypes.c_void_p()
check_call(_LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment))
return ctypes.pythonapi.PyCapsule_New(ptr, _c_str_dltensor, _c_dlpack_deleter)
check_call(
_LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment)
)
return ctypes.pythonapi.PyCapsule_New(
ptr, _c_str_dltensor, _c_dlpack_deleter
)
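_from_dlpack above and to_dlpack here implement the two halves of the standard DLPack capsule protocol: a producer hands out a capsule named "dltensor", and whichever consumer takes ownership renames it to "used_dltensor" and clears the destructor, so each capsule is consumed exactly once. The same convention can be seen with PyTorch's DLPack helpers (illustrative, not part of this commit):

import torch
import torch.utils.dlpack

t = torch.arange(4)
capsule = torch.utils.dlpack.to_dlpack(t)     # capsule named "dltensor"
x = torch.utils.dlpack.from_dlpack(capsule)   # consumer renames it "used_dltensor"
try:
    torch.utils.dlpack.from_dlpack(capsule)   # second consumption fails loudly
except Exception:
    pass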
def _make_array(handle, is_view):
handle = ctypes.cast(handle, DGLArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_DGL_COMPATS = ()
def _reg_extension(cls, fcreate):
global _DGL_COMPATS
_DGL_COMPATS += (cls,)
if fcreate:
fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._dgl_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(fret, cls._dgl_tcode)
C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(
fret, cls._dgl_tcode
)
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
@@ -2,10 +2,16 @@
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..base import _LIB, c_str, check_call
from ..object_generic import _set_class_object_base
from .types import DGLValue, TypeCode
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLValue,
TypeCode,
_wrap_arg_func,
)
ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None
@@ -13,10 +19,12 @@ __init_by_constructor__ = None
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
def _register_object(index, cls):
"""register object class in python"""
OBJECT_TYPE[index] = cls
def _return_object(x):
"""Construct a object object from the given DGLValue object"""
handle = x.v_handle
@@ -34,32 +42,41 @@ def _return_object(x):
RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)
_return_object, TypeCode.OBJECT_HANDLE
)
class ObjectBase(object):
"""Object base class"""
__slots__ = ["handle"]
# pylint: disable=no-member
def __del__(self):
if _LIB is not None and hasattr(self, 'handle'):
if _LIB is not None and hasattr(self, "handle"):
check_call(_LIB.DGLObjectFree(self.handle))
def __getattr__(self, name):
if name == 'handle':
raise AttributeError("'handle' is a reserved attribute name that should not be used")
if name == "handle":
raise AttributeError(
"'handle' is a reserved attribute name that should not be used"
)
ret_val = DGLValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.DGLObjectGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
check_call(
_LIB.DGLObjectGetAttr(
self.handle,
c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success),
)
)
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
"'%s' object has no attribute '%s'" % (str(type(self)), name)
)
return RETURN_SWITCH[ret_type_code.value](ret_val)
def __init_handle_by_constructor__(self, fconstructor, *args):
@@ -81,9 +98,12 @@ class ObjectBase(object):
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args) # pylint: disable=not-callable
handle = __init_by_constructor__(
fconstructor, args
) # pylint: disable=not-callable
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
_set_class_object_base(ObjectBase)
@@ -3,17 +3,22 @@
from __future__ import absolute_import as _abs
import ctypes
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLDataType, DGLContext
from ..base import _LIB, check_call, py_str
from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType, TypeCode
class DGLValue(ctypes.Union):
"""DGLValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p),
("v_type", DGLDataType),
("v_ctx", DGLContext)]
_fields_ = [
("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p),
("v_type", DGLDataType),
("v_ctx", DGLContext),
]
DGLPackedCFunc = ctypes.CFUNCTYPE(
@@ -22,12 +27,11 @@ DGLPackedCFunc = ctypes.CFUNCTYPE(
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_void_p)
ctypes.c_void_p,
)
DGLCFuncFinalizer = ctypes.CFUNCTYPE(
None,
ctypes.c_void_p)
DGLCFuncFinalizer = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _return_handle(x):
@@ -37,6 +41,7 @@ def _return_handle(x):
handle = ctypes.c_void_p(handle)
return handle
def _return_bytes(x):
"""return handle"""
handle = x.v_handle
@@ -47,16 +52,20 @@ def _return_bytes(x):
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
raise RuntimeError("memmove failed")
return res
def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code)
def _wrap_func(x):
check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x)
return _wrap_func
RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
@@ -64,7 +73,9 @@ RETURN_SWITCH = {
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id),
TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
}
C_TO_PY_ARG_SWITCH = {
@@ -74,5 +85,7 @@ C_TO_PY_ARG_SWITCH = {
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id),
TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
}
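RETURN_SWITCH and C_TO_PY_ARG_SWITCH are plain dispatch tables: the C side reports a TypeCode, and the matching lambda reads the corresponding field out of the DGLValue union. A stripped-down sketch of the mechanism (toy union and codes, not the real DGL types):

import ctypes

class Value(ctypes.Union):
    _fields_ = [("v_int64", ctypes.c_int64), ("v_float64", ctypes.c_double)]

INT, FLOAT = 0, 2                      # mirrors TypeCode.INT / TypeCode.FLOAT
SWITCH = {INT: lambda v: v.v_int64, FLOAT: lambda v: v.v_float64}

v = Value()
v.v_float64 = 2.5
assert SWITCH[FLOAT](v) == 2.5         # unpacker selected by the type code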
@@ -3,22 +3,24 @@
"""ctypes library and helper functions """
from __future__ import absolute_import
import sys
import os
import ctypes
import logging
import os
import sys
import numpy as np
from . import libinfo
#----------------------------
# ----------------------------
# library loading
#----------------------------
# ----------------------------
if sys.version_info[0] == 3:
string_types = (str,)
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
py_str = lambda x: x.decode("utf-8")
else:
string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32)
@@ -27,8 +29,10 @@ else:
class DGLError(Exception):
"""Error thrown by DGL function"""
pass # pylint: disable=unnecessary-pass
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
@@ -39,6 +43,7 @@ def _load_lib():
lib.DGLGetLastError.restype = ctypes.c_char_p
return lib, basename, dirname
# version number
__version__ = libinfo.__version__
# library instance of nnvm
@@ -47,9 +52,9 @@ _LIB, _LIB_NAME, _DIR_NAME = _load_lib()
# The FFI mode of DGL
_FFI_MODE = os.environ.get("DGL_FFI", "auto")
#----------------------------
# ----------------------------
# helper function in ctypes.
#----------------------------
# ----------------------------
def check_call(ret):
"""Check the return value of C API call
@@ -77,7 +82,7 @@ def c_str(string):
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))
return ctypes.c_char_p(string.encode("utf-8"))
def c_array(ctype, values):
@@ -111,10 +116,13 @@ def decorate(func, fwrapped):
The wrapped function
"""
import decorator
return decorator.decorate(func, fwrapped)
tensor_adapter_loaded = False
def load_tensor_adapter(backend, version):
"""Tell DGL to load a tensoradapter library for given backend and version.
@@ -126,17 +134,17 @@ def load_tensor_adapter(backend, version):
The version number of the backend.
"""
global tensor_adapter_loaded
version = version.split('+')[0]
if sys.platform.startswith('linux'):
basename = 'libtensoradapter_%s_%s.so' % (backend, version)
elif sys.platform.startswith('darwin'):
basename = 'libtensoradapter_%s_%s.dylib' % (backend, version)
elif sys.platform.startswith('win'):
basename = 'tensoradapter_%s_%s.dll' % (backend, version)
version = version.split("+")[0]
if sys.platform.startswith("linux"):
basename = "libtensoradapter_%s_%s.so" % (backend, version)
elif sys.platform.startswith("darwin"):
basename = "libtensoradapter_%s_%s.dylib" % (backend, version)
elif sys.platform.startswith("win"):
basename = "tensoradapter_%s_%s.dll" % (backend, version)
else:
raise NotImplementedError('Unsupported system: %s' % sys.platform)
path = os.path.join(_DIR_NAME, 'tensoradapter', backend, basename)
tensor_adapter_loaded = (_LIB.DGLLoadTensorAdapter(path.encode('utf-8')) == 0)
raise NotImplementedError("Unsupported system: %s" % sys.platform)
path = os.path.join(_DIR_NAME, "tensoradapter", backend, basename)
tensor_adapter_loaded = _LIB.DGLLoadTensorAdapter(path.encode("utf-8")) == 0
if not tensor_adapter_loaded:
logger = logging.getLogger("dgl-core")
logger.debug("Memory optimization with PyTorch is not enabled.")
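The scheme above resolves to one shared library per backend/version/platform triple, after the local-version tag following "+" is dropped. For instance (hypothetical version string):

version = "1.12.0+cu113".split("+")[0]           # -> "1.12.0"
basename = "libtensoradapter_%s_%s.so" % ("pytorch", version)
# linux:  libtensoradapter_pytorch_1.12.0.so
# darwin: libtensoradapter_pytorch_1.12.0.dylib
# win:    tensoradapter_pytorch_1.12.0.dll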
@@ -2,9 +2,10 @@
"""Function namespace."""
from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
import sys
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str, string_types
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -13,21 +14,31 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_dgl_func
from ._cy3.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_dgl_func
from ._cy2.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_dgl_func
from ._ctypes.function import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
"""The PackedFunc object.
@@ -51,11 +62,13 @@ class Function(_FunctionBase):
dgl.register_func: How to register global function.
dgl.get_global_func: How to get global function.
"""
pass # pylint: disable=unnecessary-pass
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
@@ -97,13 +110,16 @@ class ModuleBase(object):
The result function.
"""
ret_handle = FunctionHandle()
check_call(_LIB.DGLModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
check_call(
_LIB.DGLModGetFunction(
self.handle,
c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle),
)
)
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
raise AttributeError("Module has no function '%s'" % name)
return Function(ret_handle, False)
def import_module(self, module):
@@ -175,13 +191,16 @@ def register_func(func_name, f=None, override=False):
raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
if not isinstance(myf, Function):
myf = convert_to_dgl_func(myf)
check_call(_LIB.DGLFuncRegisterGlobal(
c_str(func_name), myf.handle, ioverride))
check_call(
_LIB.DGLFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)
)
return myf
if f:
return register(f)
return register
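Because the inner register closure is returned when f is omitted, register_func works both as a direct call and as a decorator. Hypothetical usage (function names made up; the class docstring above points to dgl.register_func as the public entry):

import dgl

dgl.register_func("demo.add", lambda a, b: a + b)   # direct form

@dgl.register_func("demo.mul", override=True)       # decorator form
def mul(a, b):
    return a * b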
@@ -214,7 +233,6 @@ def get_global_func(name, allow_missing=False):
raise ValueError("Cannot find global function %s" % name)
def list_global_func_names():
"""Get list of global functions registered.
@@ -226,8 +244,9 @@ def list_global_func_names():
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.DGLFuncListGlobalNames(ctypes.byref(size),
ctypes.byref(plist)))
check_call(
_LIB.DGLFuncListGlobalNames(ctypes.byref(size), ctypes.byref(plist))
)
fnames = []
for i in range(size.value):
fnames.append(py_str(plist[i]))
@@ -249,8 +268,10 @@ def extract_ext_funcs(finit):
The extracted functions
"""
fdict = {}
def _list(name, func):
fdict[name] = func
myf = convert_to_dgl_func(_list)
ret = finit(myf.handle)
_ = myf
@@ -258,11 +279,13 @@
raise RuntimeError("cannot initialize with %s" % finit)
return fdict
def _get_api(f):
flocal = f
flocal.is_global = True
return flocal
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
@@ -272,8 +295,7 @@
target_module_name : str
The target module name if different from namespace
"""
target_module_name = (
target_module_name if target_module_name else namespace)
target_module_name = target_module_name if target_module_name else namespace
if namespace.startswith("dgl."):
_init_api_prefix(target_module_name, namespace[4:])
else:
@@ -284,10 +306,10 @@ def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
if name.startswith("_") and not name.startswith('_deprecate'):
if name.startswith("_") and not name.startswith("_deprecate"):
# internal APIs are ignored
continue
name_split = name.rsplit('.', 1)
name_split = name.rsplit(".", 1)
if name_split[0] != prefix:
continue
@@ -300,12 +322,13 @@ def _init_api_prefix(module_name, prefix):
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname)
ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff)
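This is the machinery that materializes C++-registered functions as Python attributes: every global function whose dotted prefix matches is wrapped by _get_api and attached to the target module, so a submodule bootstraps its C API with a single call. A sketch of the established pattern (the namespace here is hypothetical):

from dgl._ffi.function import _init_api

# At the bottom of a DGL submodule: pulls every registered global function
# named "demo_namespace.<Name>" into this module as an attribute <Name>.
_init_api("dgl.demo_namespace")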
def _init_internal_api():
for name in list_global_func_names():
if not name.startswith("_") or name.startswith('_deprecate'):
if not name.startswith("_") or name.startswith("_deprecate"):
# normal APIs are ignored
continue
target_module = sys.modules["dgl._api_internal"]
@@ -316,7 +339,8 @@ def _init_internal_api():
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname)
ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
"""Library information."""
from __future__ import absolute_import
import sys
import os
import pathlib
import sys
def find_lib_path(name=None, search_path=None, optional=False):
@@ -30,13 +31,21 @@ def find_lib_path(name=None, search_path=None, optional=False):
dll_path = []
if os.environ.get('DGL_LIBRARY_PATH', None):
dll_path.append(os.environ['DGL_LIBRARY_PATH'])
if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")])
if os.environ.get("DGL_LIBRARY_PATH", None):
dll_path.append(os.environ["DGL_LIBRARY_PATH"])
if sys.platform.startswith("linux") and os.environ.get(
"LD_LIBRARY_PATH", None
):
dll_path.extend(
[p.strip() for p in os.environ["LD_LIBRARY_PATH"].split(":")]
)
elif sys.platform.startswith("darwin") and os.environ.get(
"DYLD_LIBRARY_PATH", None
):
dll_path.extend(
[p.strip() for p in os.environ["DYLD_LIBRARY_PATH"].split(":")]
)
# Pip lib directory
dll_path.append(os.path.join(ffi_dir, ".."))
@@ -54,17 +63,21 @@
elif isinstance(search_path, str):
dll_path.append(search_path)
else:
raise ValueError("type(search_path)={} is invalid".format(type(search_path)))
dll_path = [str(x.absolute()) if isinstance(x, pathlib.Path)
else os.path.abspath(x) for x in dll_path]
raise ValueError(
"type(search_path)={} is invalid".format(type(search_path))
)
dll_path = [
str(x.absolute()) if isinstance(x, pathlib.Path) else os.path.abspath(x)
for x in dll_path
]
if name is None:
if sys.platform.startswith('win32'):
name = ['libdgl.dll', 'dgl.dll']
elif sys.platform.startswith('darwin'):
name = 'libdgl.dylib'
if sys.platform.startswith("win32"):
name = ["libdgl.dll", "dgl.dll"]
elif sys.platform.startswith("darwin"):
name = "libdgl.dylib"
else:
name = 'libdgl.so'
name = "libdgl.so"
if isinstance(name, str):
name = [name]
@@ -76,9 +89,11 @@
lib_found = [p for p in lib_dll_path if os.path.isfile(p)]
if not lib_found:
message = ('Cannot find the files.\n' +
'List of candidates:\n' +
str('\n'.join(lib_dll_path)))
message = (
"Cannot find the files.\n"
+ "List of candidates:\n"
+ str("\n".join(lib_dll_path))
)
if not optional:
raise RuntimeError(message)
return None
@@ -2,13 +2,20 @@
"""Runtime NDArray api"""
from __future__ import absolute_import
import sys
import ctypes
import sys
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, dgl_shape_index_t
from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
from .runtime_ctypes import (
DGLArray,
DGLArrayHandle,
DGLContext,
DGLDataType,
TypeCode,
dgl_shape_index_t,
)
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -17,15 +24,31 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
def context(dev_type, dev_id=0):
"""Construct a DGL context with given device type and id.
@@ -63,10 +86,9 @@
def numpyasarray(np_data):
"""Return a DGLArray representation of a numpy array.
"""
"""Return a DGLArray representation of a numpy array."""
data = np_data
assert data.flags['C_CONTIGUOUS']
assert data.flags["C_CONTIGUOUS"]
arr = DGLArray()
shape = c_array(dgl_shape_index_t, data.shape)
arr.data = data.ctypes.data_as(ctypes.c_void_p)
@@ -102,14 +124,18 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAlloc(
shape, ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
ctx.device_type,
ctx.device_id,
ctypes.byref(handle)))
check_call(
_LIB.DGLArrayAlloc(
shape,
ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
ctx.device_type,
ctx.device_id,
ctypes.byref(handle),
)
)
return _make_array(handle, False)
@@ -135,18 +161,23 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
arr : dgl.nd.NDArray
The array dgl supported.
"""
name = ctypes.c_char_p(name.encode('utf-8'))
name = ctypes.c_char_p(name.encode("utf-8"))
shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAllocSharedMem(
name, shape, ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
is_create,
ctypes.byref(handle)))
check_call(
_LIB.DGLArrayAllocSharedMem(
name,
shape,
ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
is_create,
ctypes.byref(handle),
)
)
return _make_array(handle, False)
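empty_shared_mem backs the array with a named shared-memory segment, so a second process passing the same name with is_create=False attaches to the very same buffer. A hedged sketch (segment name made up):

# process A: create a 4-float shared segment
arr = empty_shared_mem("demo_shm", True, (4,), "float32")

# process B: attach to the existing segment under the same name
arr_view = empty_shared_mem("demo_shm", False, (4,), "float32")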
@@ -171,10 +202,14 @@ def from_dlpack(dltensor):
class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
return tuple(
self.handle.contents.shape[i]
for i in range(self.handle.contents.ndim)
)
@property
def dtype(self):
@@ -219,17 +254,19 @@ class NDArrayBase(_NDArrayBase):
def __setitem__(self, in_slice, value):
"""Set ndarray value"""
if (not isinstance(in_slice, slice) or
in_slice.start is not None
or in_slice.stop is not None):
raise ValueError('Array only support set from numpy array')
if (
not isinstance(in_slice, slice)
or in_slice.start is not None
or in_slice.stop is not None
):
raise ValueError("Array only support set from numpy array")
if isinstance(value, NDArrayBase):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)):
self.copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))
raise TypeError("type %s not supported" % str(type(value)))
def copyfrom(self, source_array):
"""Perform a synchronized copy from the array.
@@ -252,8 +289,10 @@ class NDArrayBase(_NDArrayBase):
try:
source_array = np.asarray(source_array, dtype=self.dtype)
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))
raise TypeError(
"array must be an array_like data,"
+ "type %s is not supported" % str(type(source_array))
)
t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
@@ -262,12 +301,17 @@
dtype = str(t)
if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape))
raise ValueError(
"array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape
)
)
source_array = np.ascontiguousarray(source_array, dtype=dtype)
assert source_array.flags['C_CONTIGUOUS']
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
nbytes = ctypes.c_size_t(
source_array.size * source_array.dtype.itemsize
)
check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes))
return self
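copyfrom is a synchronized host-side copy: the source is coerced to the array's dtype, shape-checked, made C-contiguous, and handed to DGLArrayCopyFromBytes as one flat byte transfer. Assuming arr was created by empty((2, 2), "float32") from this file (asnumpy, the reverse copy, is defined just below):

import numpy as np

src = np.arange(4, dtype="float32").reshape(2, 2)
arr.copyfrom(src)                                # shapes must match exactly
np.testing.assert_allclose(arr.asnumpy(), src)   # round-trips through C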
@@ -293,7 +337,7 @@ class NDArrayBase(_NDArrayBase):
t.lanes = 1
dtype = str(t)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS']
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes))
@@ -310,20 +354,17 @@
if isinstance(target, DGLContext):
target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase):
check_call(_LIB.DGLArrayCopyFromTo(
self.handle, target.handle))
check_call(_LIB.DGLArrayCopyFromTo(self.handle, target.handle))
else:
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def pin_memory_(self):
"""Pin host memory and map into GPU address space (in-place)
"""
"""Pin host memory and map into GPU address space (in-place)"""
check_call(_LIB.DGLArrayPinData(self.handle))
def unpin_memory_(self):
"""Unpin host memory pinned by pin_memory_()
"""
"""Unpin host memory pinned by pin_memory_()"""
check_call(_LIB.DGLArrayUnpinData(self.handle))
def record_stream(self, stream):
@@ -340,6 +381,7 @@
"""
check_call(_LIB.DGLArrayRecordStream(self.handle, stream))
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
@@ -353,6 +395,7 @@ def free_extension_handle(handle, type_code):
"""
check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None):
"""Register a extension class to DGL.
@@ -398,6 +441,8 @@ def register_extension(cls, fcreate=None):
return self.handle.value
"""
if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
raise ValueError(
"Cannot register create when extension tcode is same as buildin"
)
_reg_extension(cls, fcreate)
return cls
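A registrable extension class only needs a class-level _dgl_tcode at or above TypeCode.EXT_BEGIN plus a _dgl_handle property, mirroring the docstring fragment above. Minimal sketch (the class and tcode value are made up):

@register_extension
class MyExt(object):
    _dgl_tcode = 17                    # must be >= TypeCode.EXT_BEGIN (15)

    def __init__(self, handle):
        self.handle = handle

    @property
    def _dgl_handle(self):
        return self.handle.value       # raw address handed to the C side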
@@ -4,23 +4,27 @@ from __future__ import absolute_import
import ctypes
import sys
from .. import _api_internal
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
from .object_generic import ObjectGeneric, convert_to_object
from .base import _LIB, check_call, c_str, py_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" \
else ImportError # pylint: disable=invalid-name
# pylint: disable=invalid-name
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _register_object, ObjectBase as _ObjectBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _register_object, ObjectBase as _ObjectBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.object import _register_object, ObjectBase as _ObjectBase
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
def _new_object(cls):
@@ -36,11 +40,15 @@ class ObjectBase(_ObjectBase):
Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.
"""
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.DGLObjectListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
check_call(
_LIB.DGLObjectListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)
)
)
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
@@ -57,7 +65,7 @@ class ObjectBase(_ObjectBase):
def __reduce__(self):
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
return (_new_object, (cls,), self.__getstate__())
def __getstate__(self):
# TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized
@@ -100,7 +108,9 @@ def register_object(type_key=None):
def register(cls):
"""internal register function"""
tindex = ctypes.c_int()
ret = _LIB.DGLObjectTypeKey2Index(c_str(object_name), ctypes.byref(tindex))
ret = _LIB.DGLObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tindex)
)
if ret == 0:
_register_object(tindex.value, cls)
return cls
@@ -2,23 +2,28 @@
# pylint: disable=unused-import
from __future__ import absolute_import
from numbers import Number, Integral
from numbers import Integral, Number
from .. import _api_internal
from .base import string_types
# Object base class
_CLASS_OBJECT_BASE = None
def _set_class_object_base(cls):
global _CLASS_OBJECT_BASE
_CLASS_OBJECT_BASE = cls
class ObjectGeneric(object):
"""Base class for all classes that can be converted to object."""
def asobject(self):
"""Convert value to object"""
raise NotImplementedError()
def convert_to_object(value):
"""Convert a python value to corresponding object type.
@@ -40,9 +45,12 @@ def convert_to_object(value):
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECT_BASE) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
if not isinstance(item[0], _CLASS_OBJECT_BASE) and not isinstance(
item[0], string_types
):
raise ValueError(
"key of map must already been a container type"
)
vlist.append(item[0])
vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist)
@@ -4,14 +4,18 @@ from __future__ import absolute_import
import ctypes
import json
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
from .base import _LIB, check_call
dgl_shape_index_t = ctypes.c_int64
class TypeCode(object):
"""Type code used in API calls"""
INT = 0
UINT = 1
FLOAT = 2
@@ -28,22 +32,25 @@ class TypeCode(object):
NDARRAY_CONTAINER = 13
EXT_BEGIN = 15
class DGLByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
_fields_ = [
("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t),
]
class DGLDataType(ctypes.Structure):
"""DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
_fields_ = [
("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16),
]
CODE2STR = {0: "int", 1: "uint", 2: "float", 4: "handle"}
_cache = {}
def __new__(cls, type_str):
@@ -90,50 +97,54 @@ class DGLDataType(ctypes.Structure):
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
return (
self.bits == other.bits
and self.type_code == other.type_code
and self.lanes == other.lanes
)
def __ne__(self, other):
return not self.__eq__(other)
RPC_SESS_MASK = 128
class DGLContext(ctypes.Structure):
"""DGL context strucure."""
_fields_ = [("device_type", ctypes.c_int),
("device_id", ctypes.c_int)]
_fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
MASK2STR = {
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
5 : 'aocl',
6 : 'sdaccel',
7 : 'vulkan',
8 : 'metal',
9 : 'vpi',
10: 'rocm',
11: 'opengl',
12: 'ext_dev',
1: "cpu",
2: "gpu",
4: "opencl",
5: "aocl",
6: "sdaccel",
7: "vulkan",
8: "metal",
9: "vpi",
10: "rocm",
11: "opengl",
12: "ext_dev",
}
STR2MASK = {
'llvm': 1,
'stackvm': 1,
'cpu': 1,
'gpu': 2,
'cuda': 2,
'nvptx': 2,
'cl': 4,
'opencl': 4,
'aocl' : 5,
'aocl_sw_emu' : 5,
'sdaccel': 6,
'vulkan': 7,
'metal': 8,
'vpi': 9,
'rocm': 10,
'opengl': 11,
'ext_dev': 12,
"llvm": 1,
"stackvm": 1,
"cpu": 1,
"gpu": 2,
"cuda": 2,
"nvptx": 2,
"cl": 4,
"opencl": 4,
"aocl": 5,
"aocl_sw_emu": 5,
"sdaccel": 6,
"vulkan": 7,
"metal": 8,
"vpi": 9,
"rocm": 10,
"opengl": 11,
"ext_dev": 12,
}
_cache = {}
@@ -155,26 +166,25 @@ class DGLContext(ctypes.Structure):
@property
def exist(self):
"""Whether this device exist."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 0) != 0
return (
_api_internal._GetDeviceAttr(self.device_type, self.device_id, 0)
!= 0
)
@property
def max_threads_per_block(self):
"""Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 1)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 1)
@property
def warp_size(self):
"""Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 2)
@property
def max_shared_memory_per_block(self):
"""Total amount of shared memory per block in bytes."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 3)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 3)
@property
def compute_version(self):
@@ -187,26 +197,22 @@ class DGLContext(ctypes.Structure):
version : str
The version string in `major.minor` format.
"""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 4)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 4)
@property
def device_name(self):
"""Return the string name of device."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 5)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 5)
@property
def max_clock_rate(self):
"""Return the max clock frequency of device."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 6)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 6)
@property
def multi_processor_count(self):
"""Return the number of compute units of device."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 7)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 7)
@property
def max_thread_dimensions(self):
@@ -217,17 +223,20 @@ class DGLContext(ctypes.Structure):
dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
"""
return json.loads(_api_internal._GetDeviceAttr(
self.device_type, self.device_id, 8))
return json.loads(
_api_internal._GetDeviceAttr(self.device_type, self.device_id, 8)
)
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None))
def __eq__(self, other):
return (isinstance(other, DGLContext) and
self.device_id == other.device_id and
self.device_type == other.device_type)
return (
isinstance(other, DGLContext)
and self.device_id == other.device_id
and self.device_type == other.device_type
)
def __ne__(self, other):
return not self.__eq__(other)
@@ -237,9 +246,14 @@ class DGLContext(ctypes.Structure):
tbl_id = self.device_type / RPC_SESS_MASK - 1
dev_type = self.device_type % RPC_SESS_MASK
return "remote[%d]:%s(%d)" % (
tbl_id, DGLContext.MASK2STR[dev_type], self.device_id)
tbl_id,
DGLContext.MASK2STR[dev_type],
self.device_id,
)
return "%s(%d)" % (
DGLContext.MASK2STR[self.device_type], self.device_id)
DGLContext.MASK2STR[self.device_type],
self.device_id,
)
def __hash__(self):
return hash((self.device_type, self.device_id))
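Reading MASK2STR together with RPC_SESS_MASK: a local context prints as "<type>(<id>)", while a device_type offset by k * 128 denotes a device living on remote RPC session k - 1. Illustrative values only:

ctx = DGLContext(1, 0)                       # device_type 1 -> "cpu"
print(ctx)                                   # cpu(0)
remote = DGLContext(2 + RPC_SESS_MASK, 0)    # gpu 0 on remote session 0
print(remote)                                # remote[0]:gpu(0)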
@@ -247,13 +261,17 @@
class DGLArray(ctypes.Structure):
"""DGLValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("ctx", DGLContext),
("ndim", ctypes.c_int),
("dtype", DGLDataType),
("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64)]
_fields_ = [
("data", ctypes.c_void_p),
("ctx", DGLContext),
("ndim", ctypes.c_int),
("dtype", DGLDataType),
("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64),
]
DGLArrayHandle = ctypes.POINTER(DGLArray)
@@ -5,13 +5,13 @@ For applications, please use PyTorch's stream management, of which DGL is aware.
from __future__ import absolute_import
import ctypes
from .base import _LIB, check_call, _FFI_MODE
from .base import _FFI_MODE, _LIB, check_call
from .runtime_ctypes import DGLStreamHandle
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
def to_dgl_stream_handle(cuda_stream):
""" Convert torch.cuda.Stream to DGL stream handle
"""Convert torch.cuda.Stream to DGL stream handle
Parameters
----------
@@ -24,6 +24,7 @@ def to_dgl_stream_handle(cuda_stream):
"""
return ctypes.c_void_p(cuda_stream.cuda_stream)
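torch.cuda.Stream exposes its underlying CUDA stream as the integer cuda_stream attribute, so the conversion is just a pointer-width reinterpretation. Hypothetical usage (requires a CUDA-enabled PyTorch build):

import torch

s = torch.cuda.Stream()
handle = to_dgl_stream_handle(s)   # ctypes.c_void_p wrapping the raw stream
# e.g. pass it to NDArray.record_stream(...) from ndarray.py above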
def _dgl_get_stream(ctx):
"""Get the current CUDA stream of the given DGL context.
@@ -37,6 +38,9 @@ def _dgl_get_stream(ctx):
DGLStreamHandle of the current CUDA stream.
"""
current_cuda_stream = DGLStreamHandle()
check_call(_LIB.DGLGetStream(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)))
check_call(
_LIB.DGLGetStream(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)
)
)
return current_cuda_stream
from __future__ import absolute_import
import sys
import os
import json
import importlib
import json
import logging
import os
import sys
from . import backend
from .set_default_backend import set_default_backend
@@ -13,13 +13,18 @@ _enabled_apis = set()
logger = logging.getLogger("dgl-core")
def _gen_missing_api(api, mod_name):
def _missing_api(*args, **kwargs):
raise ImportError('API "%s" is not supported by backend "%s".'
' You can switch to other backends by setting'
' the DGLBACKEND environment.' % (api, mod_name))
raise ImportError(
'API "%s" is not supported by backend "%s".'
" You can switch to other backends by setting"
" the DGLBACKEND environment." % (api, mod_name)
)
return _missing_api
def load_backend(mod_name):
# Load backend does four things:
# (1) Import backend framework (PyTorch, MXNet, Tensorflow, etc.)
@@ -28,40 +33,46 @@ def load_backend(mod_name):
# (3) Sets up the tensoradapter library path.
# (4) Import the Python wrappers of the backend framework. DGL does this last because
# it already depends on both the backend framework and the DGL C library.
if mod_name == 'pytorch':
if mod_name == "pytorch":
import torch
mod = torch
elif mod_name == 'mxnet':
elif mod_name == "mxnet":
import mxnet
mod = mxnet
elif mod_name == 'tensorflow':
elif mod_name == "tensorflow":
import tensorflow
mod = tensorflow
else:
raise NotImplementedError('Unsupported backend: %s' % mod_name)
raise NotImplementedError("Unsupported backend: %s" % mod_name)
from .._ffi.base import load_tensor_adapter # imports DGL C library
from .._ffi.base import load_tensor_adapter # imports DGL C library
version = mod.__version__
load_tensor_adapter(mod_name, version)
logger.debug('Using backend: %s' % mod_name)
mod = importlib.import_module('.%s' % mod_name, __name__)
logger.debug("Using backend: %s" % mod_name)
mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__]
for api in backend.__dict__.keys():
if api.startswith('__'):
if api.startswith("__"):
# ignore python builtin attributes
continue
if api == 'data_type_dict':
if api == "data_type_dict":
# load data type
if api not in mod.__dict__:
raise ImportError('API "data_type_dict" is required but missing for'
' backend "%s".' % (mod_name))
raise ImportError(
'API "data_type_dict" is required but missing for'
' backend "%s".' % (mod_name)
)
data_type_dict = mod.__dict__[api]()
for name, dtype in data_type_dict.items():
setattr(thismod, name, dtype)
# override data type dict function
setattr(thismod, 'data_type_dict', data_type_dict)
setattr(thismod, "data_type_dict", data_type_dict)
# for data types with aliases, treat the first listed type as
# the true one
@@ -69,11 +80,9 @@ def load_backend(mod_name):
for k, v in data_type_dict.items():
if not v in rev_data_type_dict.keys():
rev_data_type_dict[v] = k
setattr(thismod,
'reverse_data_type_dict',
rev_data_type_dict)
setattr(thismod, "reverse_data_type_dict", rev_data_type_dict)
# log backend name
setattr(thismod, 'backend_name', mod_name)
setattr(thismod, "backend_name", mod_name)
else:
# load functions
if api in mod.__dict__:
@@ -82,28 +91,32 @@ def load_backend(mod_name):
else:
setattr(thismod, api, _gen_missing_api(api, mod_name))
def get_preferred_backend():
default_dir = None
if "DGLDEFAULTDIR" in os.environ:
default_dir = os.getenv('DGLDEFAULTDIR')
default_dir = os.getenv("DGLDEFAULTDIR")
else:
default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
config_path = os.path.join(default_dir, 'config.json')
default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
config_path = os.path.join(default_dir, "config.json")
backend_name = None
if "DGLBACKEND" in os.environ:
backend_name = os.getenv('DGLBACKEND')
backend_name = os.getenv("DGLBACKEND")
elif os.path.exists(config_path):
with open(config_path, "r") as config_file:
config_dict = json.load(config_file)
backend_name = config_dict.get('backend', '').lower()
backend_name = config_dict.get("backend", "").lower()
if (backend_name in ['tensorflow', 'mxnet', 'pytorch']):
if backend_name in ["tensorflow", "mxnet", "pytorch"]:
return backend_name
else:
print("DGL backend not selected or invalid. "
"Assuming PyTorch for now.", file=sys.stderr)
set_default_backend(default_dir, 'pytorch')
return 'pytorch'
print(
"DGL backend not selected or invalid. "
"Assuming PyTorch for now.",
file=sys.stderr,
)
set_default_backend(default_dir, "pytorch")
return "pytorch"
load_backend(get_preferred_backend())
@@ -124,8 +137,10 @@ def is_enabled(api):
"""
return api in _enabled_apis
def to_dgl_nd(data):
return zerocopy_to_dgl_ndarray(data)
def from_dgl_nd(data):
return zerocopy_from_dgl_ndarray(data)
@@ -16,6 +16,7 @@ that returns whether the interface is supported by the framework or not.
###############################################################################
# Tensor, data type and context interfaces
def data_type_dict():
"""Returns a dictionary from data type string to the data type.
@@ -52,10 +53,12 @@ def data_type_dict():
"""
pass
def cpu():
"""Return a context object for CPU device."""
pass
def tensor(data, dtype=None):
"""Create a tensor given the data and data type.
@@ -81,6 +84,7 @@ def tensor(data, dtype=None):
"""
pass
def as_scalar(data):
"""Returns a scalar whose value is copied from this array.
@@ -96,6 +100,7 @@ def as_scalar(data):
"""
pass
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
@@ -109,6 +114,7 @@ def get_preferred_sparse_format():
"""
pass
def sparse_matrix(data, index, shape, force_format=False):
"""Create a sparse matrix.
@@ -146,6 +152,7 @@ def sparse_matrix(data, index, shape, force_format=False):
"""
pass
def sparse_matrix_indices(spmat):
"""Return the indices of the given sparse matrix.
@@ -169,10 +176,12 @@ def sparse_matrix_indices(spmat):
"""
pass
def is_tensor(obj):
"""Returns true if the given object is a framework-specific tensor."""
pass
def shape(input):
"""Return the shape of the tensor.
@@ -188,6 +197,7 @@ def shape(input):
"""
pass
def dtype(input):
"""Return the data type of the tensor.
@@ -203,6 +213,7 @@ def dtype(input):
"""
pass
def ndim(input):
"""Return the number of dimensions of the tensor.
@@ -218,6 +229,7 @@ def ndim(input):
"""
pass
def context(input):
"""Return the context/device of the input tensor.
@@ -233,6 +245,7 @@ def context(input):
"""
pass
def device_type(ctx):
"""Return a str representing device type.
@@ -247,6 +260,7 @@ def device_type(ctx):
"""
pass
def device_id(ctx):
"""Return device index.
@@ -265,6 +279,7 @@ def device_id(ctx):
"""
pass
def to_backend_ctx(dglctx):
"""Convert a DGL context object to a backend context.
@@ -279,6 +294,7 @@ def to_backend_ctx(dglctx):
"""
pass
def astype(input, ty):
"""Convert the input tensor to the given data type.
@@ -296,6 +312,7 @@ def astype(input, ty):
"""
pass
def asnumpy(input):
"""Convert the input tensor to numpy array.
@@ -313,6 +330,7 @@ def asnumpy(input):
"""
pass
def copy_to(input, ctx, **kwargs):
"""Copy the given tensor to the context.
@@ -330,6 +348,7 @@ def copy_to(input, ctx, **kwargs):
"""
pass
def is_pinned(input):
"""Check whether the tensor is in pinned memory.
@@ -345,12 +364,14 @@ def is_pinned(input):
"""
pass
###############################################################################
# Tensor functions on feature data
# --------------------------------
# These functions are performance critical, so it's better to have efficient
# implementation in each framework.
def sum(input, dim, keepdims=False):
"""Reduce sum the input tensor along the given dim.
@@ -370,6 +391,7 @@ def sum(input, dim, keepdims=False):
"""
pass
def floor_div(in1, in2):
"""Element-wise integer division and rounds each quotient towards zero.
@@ -386,6 +408,7 @@ def floor_div(in1, in2):
A framework-specific tensor.
"""
def reduce_sum(input):
"""Returns the sum of all elements in the input tensor.
@@ -401,6 +424,7 @@ def reduce_sum(input):
"""
pass
def cumsum(input, dim):
"""Return the cumulative sum of the elements along a given axis.
@@ -418,6 +442,7 @@ def cumsum(input, dim):
"""
pass
def mean(input, dim):
"""Reduce average the input tensor along the given dim.
@@ -435,6 +460,7 @@ def mean(input, dim):
"""
pass
def reduce_mean(input):
"""Returns the average of all elements in the input tensor.
@@ -450,6 +476,7 @@ def reduce_mean(input):
"""
pass
def max(input, dim):
"""Reduce max the input tensor along the given dim.
@@ -467,6 +494,7 @@ def max(input, dim):
"""
pass
def reduce_max(input):
"""Returns the max of all elements in the input tensor.
@@ -482,6 +510,7 @@ def reduce_max(input):
"""
pass
def min(input, dim):
"""Reduce min the input tensor along the given dim.
@@ -499,6 +528,7 @@ def min(input, dim):
"""
pass
def reduce_min(input):
"""Returns the min of all elements in the input tensor.
@@ -533,6 +563,7 @@ def argsort(input, dim, descending):
A framework-specific tensor.
"""
def topk(input, k, dim, descending=True):
"""Return the k largest elements of the given input tensor along the given dimension.
@@ -551,6 +582,7 @@ def topk(input, k, dim, descending=True):
"""
pass
def argtopk(input, k, dim, descending=True):
"""Return the indices of the k largest elements of the given input tensor
along the given dimension.
@@ -570,6 +602,7 @@ def argtopk(input, k, dim, descending=True):
"""
pass
def exp(input):
"""Returns a new tensor with the exponential of the elements of the input tensor `input`.
@@ -585,6 +618,7 @@ def exp(input):
"""
pass
def inverse(input):
"""Returns the inverse matrix of a square matrix if it exists.
@@ -600,6 +634,7 @@ def inverse(input):
"""
pass
def sqrt(input):
"""Returns a new tensor with the square root of the elements of the input tensor `input`.
@@ -615,6 +650,7 @@ def sqrt(input):
"""
pass
def softmax(input, dim=-1):
"""Apply the softmax function on given dimension.
@@ -650,6 +686,7 @@ def cat(seq, dim):
"""
pass
def stack(seq, dim):
"""Stack the sequence of tensors along the given dimension.
@@ -667,6 +704,7 @@ def stack(seq, dim):
"""
pass
def split(input, sizes_or_sections, dim):
"""Split the input tensor into chunks.
@@ -692,6 +730,7 @@ def split(input, sizes_or_sections, dim):
"""
pass
def repeat(input, repeats, dim):
"""Repeats elements of an array.
@@ -711,6 +750,7 @@ def repeat(input, repeats, dim):
"""
pass
def gather_row(data, row_index):
"""Slice out the data given the row index.
@@ -728,6 +768,7 @@ def gather_row(data, row_index):
"""
pass
def slice_axis(data, axis, begin, end):
"""Slice along a given axis.
Returns an array slice along a given axis starting from :attr:`begin` index to :attr:`end` index.
@@ -749,6 +790,7 @@ def slice_axis(data, axis, begin, end):
"""
pass
def take(data, indices, dim):
"""Takes elements from an input array along the given dim.
@@ -763,6 +805,7 @@ def take(data, indices, dim):
"""
pass
def narrow_row(x, start, stop):
"""Narrow down the tensor along the first dimension.
@@ -786,6 +829,7 @@ def narrow_row(x, start, stop):
"""
pass
def scatter_row(data, row_index, value):
"""Write the value into the data tensor using the row index.
@@ -807,6 +851,7 @@ def scatter_row(data, row_index, value):
"""
pass
def index_add_inplace(data, row_idx, value):
"""Add the values into the data tensor using the row index inplace.
@@ -832,6 +877,7 @@ def index_add_inplace(data, row_idx, value):
"""
pass
def scatter_row_inplace(data, row_index, value):
"""Write the value into the data tensor using the row index inplace.
......@@ -848,6 +894,7 @@ def scatter_row_inplace(data, row_index, value):
"""
pass
def squeeze(input, dim):
"""Remove the given dimension of size 1.
......@@ -865,6 +912,7 @@ def squeeze(input, dim):
"""
pass
def unsqueeze(input, dim):
"""Add the given dimension of size 1.
......@@ -882,6 +930,7 @@ def unsqueeze(input, dim):
"""
pass
def reshape(input, shape):
"""Reshape the tensor.
......@@ -899,6 +948,7 @@ def reshape(input, shape):
"""
pass
def swapaxes(input, axis1, axis2):
"""Interchange the two given axes of a tensor.
......@@ -916,6 +966,7 @@ def swapaxes(input, axis1, axis2):
"""
pass
def zeros(shape, dtype, ctx):
"""Create a zero tensor.
......@@ -935,6 +986,7 @@ def zeros(shape, dtype, ctx):
"""
pass
def zeros_like(input):
"""Create a zero tensor with the same shape, dtype and context of the
given tensor.
......@@ -951,6 +1003,7 @@ def zeros_like(input):
"""
pass
def ones(shape, dtype, ctx):
"""Create a one tensor.
......@@ -970,6 +1023,7 @@ def ones(shape, dtype, ctx):
"""
pass
def uniform(shape, dtype, ctx, low, high):
"""Create a tensor with random value in a uniform
distribution between low (inclusive) and high (exclusive).
......@@ -990,6 +1044,7 @@ def uniform(shape, dtype, ctx, low, high):
"""
pass
def randint(shape, dtype, ctx, low, high):
"""Create a tensor with random value in a uniform integer
distribution between low (inclusive) and high (exclusive)
......@@ -1010,6 +1065,7 @@ def randint(shape, dtype, ctx, low, high):
"""
pass
def pad_packed_tensor(input, lengths, value, l_min=None):
r"""Pads a packed batch of variable length tensors with given value.
......@@ -1034,6 +1090,7 @@ def pad_packed_tensor(input, lengths, value, l_min=None):
"""
pass
def pack_padded_tensor(input, lengths):
r"""Packs a tensor containing padded sequence of variable length.
......@@ -1054,6 +1111,7 @@ def pack_padded_tensor(input, lengths):
"""
pass
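# A NumPy sketch of the pad/pack round trip (illustrative only, assuming
# pad value 0 and two sequences of lengths 2 and 1): packing removes the
# padding that pad_packed_tensor introduced.
import numpy as np
lengths = np.array([2, 1])
packed = np.array([[1.0], [2.0], [3.0]])   # concatenated sequences
max_len = int(lengths.max())
row = np.repeat(np.arange(2), lengths)     # batch index of each row
col = np.concatenate([np.arange(l) for l in lengths])
padded = np.zeros((2, max_len, 1))         # pad_packed_tensor result
padded[row, col] = packed
repacked = padded[row, col]                # pack_padded_tensor result
assert (repacked == packed).all()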
def boolean_mask(input, mask):
"""Selects elements in x according to the given mask from the first
dimension.
......@@ -1072,6 +1130,7 @@ def boolean_mask(input, mask):
"""
pass
def equal(x, y):
"""Compares whether the elements are equal.
......@@ -1087,6 +1146,7 @@ def equal(x, y):
"""
pass
def allclose(x, y, rtol=1e-4, atol=1e-4):
"""Compares whether all elements are close.
......@@ -1102,6 +1162,7 @@ def allclose(x, y, rtol=1e-4, atol=1e-4):
Absolute tolerance
"""
def logical_not(input):
"""Perform a logical not operation. Equivalent to np.logical_not
......@@ -1117,9 +1178,11 @@ def logical_not(input):
"""
pass
def logical_and(input1, input2):
"""Perform a logical and operation. Equivalent to np.logical_and."""
pass
def clone(input):
"""Return a clone of the input tensor.
......@@ -1135,6 +1198,7 @@ def clone(input):
"""
pass
def clamp(data, min_val, max_val):
"""Clamp all elements in :attr:`input` into the range [min_val, max_val]
and return a resulting tensor.
......@@ -1155,6 +1219,7 @@ def clamp(data, min_val, max_val):
"""
pass
def replace_inf_with_zero(x):
"""Returns a new tensor replacing infinity and negative infinity with zeros.
......@@ -1170,6 +1235,7 @@ def replace_inf_with_zero(x):
"""
pass
def count_nonzero(input):
"""Return the count of non-zero values in the tensor input.
......@@ -1185,6 +1251,7 @@ def count_nonzero(input):
"""
pass
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
......@@ -1193,6 +1260,7 @@ def count_nonzero(input):
# DGL should contain all the operations on index, so this set of operators
# should be gradually removed.
def unique(input, return_inverse=False, return_counts=False):
"""Returns the unique scalar elements in a tensor.
......@@ -1219,6 +1287,7 @@ def unique(input, return_inverse=False, return_counts=False):
"""
pass
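# The unique() contract mirrors numpy.unique (an assumption for this
# sketch, consistent with the MXNet backend's numpy fallback below): the
# inverse mapping reconstructs the input, counts give multiplicities.
import numpy as np
x = np.array([1, 3, 1, 2])
uniq, inv, cnt = np.unique(x, return_inverse=True, return_counts=True)
assert (uniq[inv] == x).all()    # uniq -> [1 2 3], cnt -> [2 1 1]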
def full_1d(length, fill_value, dtype, ctx):
"""Create a 1D tensor full of the fill_value.
......@@ -1240,6 +1309,7 @@ def full_1d(length, fill_value, dtype, ctx):
"""
pass
def nonzero_1d(input):
"""Return the nonzero index of the given 1D input.
......@@ -1255,6 +1325,7 @@ def nonzero_1d(input):
"""
pass
def sort_1d(input):
"""Sort a 1D tensor (in ascending order) and also return the original index.
......@@ -1272,6 +1343,7 @@ def sort_1d(input):
"""
pass
def arange(start, stop, dtype, ctx):
"""Create a 1D range int64 tensor.
......@@ -1293,6 +1365,7 @@ def arange(start, stop, dtype, ctx):
"""
pass
def rand_shuffle(arr):
"""Random shuffle the data in the first dimension of the array.
......@@ -1310,6 +1383,7 @@ def rand_shuffle(arr):
"""
pass
def zerocopy_to_dlpack(input):
"""Create a dlpack tensor that shares the input memory.
......@@ -1325,6 +1399,7 @@ def zerocopy_to_dlpack(input):
"""
pass
def zerocopy_from_dlpack(dlpack_tensor):
"""Create a tensor that shares the dlpack_tensor.
......@@ -1340,6 +1415,7 @@ def zerocopy_from_dlpack(dlpack_tensor):
"""
pass
def zerocopy_to_numpy(input):
"""Create a numpy ndarray that shares the input memory.
......@@ -1355,6 +1431,7 @@ def zerocopy_to_numpy(input):
"""
pass
def zerocopy_from_numpy(np_array):
"""Create a tensor that shares the numpy array.
......@@ -1370,6 +1447,7 @@ def zerocopy_from_numpy(np_array):
"""
pass
def zerocopy_to_dgl_ndarray(input):
"""Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray
......@@ -1383,6 +1461,7 @@ def zerocopy_to_dgl_ndarray(input):
"""
pass
def zerocopy_to_dgl_ndarray_for_write(input):
"""Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray
that is ready for write (required in MXNet).
......@@ -1412,7 +1491,6 @@ def zerocopy_from_dgl_ndarray(input):
pass
###############################################################################
# Custom Operators for graph level computations.
......@@ -1420,8 +1498,20 @@ def zerocopy_from_dgl_ndarray(input):
# kernels (see kernel.py), and plug into tensor framework using custom op
# extensions.
def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
out_size, lhs_map, rhs_map, out_map):
def binary_reduce(
reducer,
binary_op,
graph,
lhs,
rhs,
lhs_data,
rhs_data,
out_size,
lhs_map,
rhs_map,
out_map,
):
"""Perform binary operation between given data and reduce based on graph
structure.
......@@ -1458,6 +1548,7 @@ def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
"""
pass
def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
"""Copy target data and perform reduce based on graph structure.
......@@ -1486,8 +1577,9 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
"""
pass
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface.
r"""Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
(1) Computes messages by :attr:`op` source node and edge features.
(2) Aggregate the messages by :attr:`reduce_op` as the features on destination nodes.
......@@ -1523,8 +1615,9 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
"""
pass
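# A pure-NumPy sketch of the gspmm contract on a toy 3-node graph,
# assuming op='mul' and reduce_op='sum'; the real kernel fuses the two
# steps instead of materializing per-edge messages.
import numpy as np
src = np.array([0, 0, 1])                  # edges 0->1, 0->2, 1->2
dst = np.array([1, 2, 2])
u_feat = np.array([[1.0], [2.0], [0.0]])   # lhs_data on source nodes
e_feat = np.array([[2.0], [3.0], [4.0]])   # rhs_data on edges
msg = u_feat[src] * e_feat                 # step (1): op on src and edge
out = np.zeros((3, 1))
np.add.at(out, dst, msg)                   # step (2): sum per destination
# out -> [[0.], [2.], [11.]]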
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
r""" Generalized Sparse Matrix Multiplication interface on heterogenenous graph.
r"""Generalized Sparse Matrix Multiplication interface on heterogenenous graph.
All the relation types of the heterogeneous graph will be processed together.
It fuses two steps into one kernel.
(1) Computes messages by :attr:`op` source node and edge features.
......@@ -1564,8 +1657,9 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
"""
pass
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
r"""Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by :attr:`op` lhs features and rhs features.
.. math::
......@@ -1599,8 +1693,11 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
"""
pass
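# Toy NumPy sketch of gsddmm with op='dot', lhs_target='u',
# rhs_target='v' (assumed settings): each edge receives the dot product
# of its endpoint features.
import numpy as np
src = np.array([0, 1])
dst = np.array([1, 2])
u = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
edge_out = (u[src] * u[dst]).sum(1, keepdims=True)
# edge 0: <u0, u1> = 11;  edge 1: <u1, u2> = 39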
def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface on
def gsddmm_hetero(
g, op, lhs_len, lhs_target="u", rhs_target="v", *lhs_and_rhs_tuple
):
r"""Generalized Sampled-Dense-Dense Matrix Multiplication interface on
heterogeneous graph. All the relation types of the heterogeneous graph
will be processed together.
It computes edge features by :attr:`op` lhs features and rhs features.
......@@ -1639,6 +1736,7 @@ def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_t
"""
pass
def edge_softmax(gidx, logits, eids, norm_by):
r"""Compute edge softmax.
......@@ -1676,6 +1774,7 @@ def edge_softmax(gidx, logits, eids, norm_by):
"""
pass
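# Numerically stable edge softmax on a toy graph, normalized per
# destination node (norm_by='dst'); a sketch of the semantics, not the
# fused kernel. It follows the same max-subtract/exp/sum/div recipe as
# the MXNet EdgeSoftmax implementation further below.
import numpy as np
dst = np.array([0, 0, 1])                  # destination of each edge
logits = np.array([1.0, 2.0, 3.0])
score_max = np.full(2, -np.inf)
np.maximum.at(score_max, dst, logits)      # max per destination
score = np.exp(logits - score_max[dst])    # subtract max for stability
score_sum = np.zeros(2)
np.add.at(score_sum, dst, score)           # sum per destination
out = score / score_sum[dst]               # edges into each node sum to 1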
def edge_softmax_hetero(gidx, eids, norm_by, *logits):
r"""Compute edge softmax.
......@@ -1713,6 +1812,7 @@ def edge_softmax_hetero(gidx, eids, norm_by, *logits):
"""
pass
def segment_reduce(op, x, offsets):
"""Segment reduction operator.
......@@ -1741,6 +1841,7 @@ def segment_reduce(op, x, offsets):
"""
pass
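# Sketch of segment_reduce with op='sum' (assumed): offsets delimit
# consecutive segments along the first dimension of x.
import numpy as np
x = np.array([[1.0], [2.0], [3.0], [4.0]])
offsets = np.array([0, 1, 4])              # segments [0:1) and [1:4)
y = np.stack([x[s:e].sum(0) for s, e in zip(offsets[:-1], offsets[1:])])
# y -> [[1.], [9.]]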
def scatter_add(x, idx, m):
"""Scatter add (on first dimension) operator.
......@@ -1763,6 +1864,7 @@ def scatter_add(x, idx, m):
"""
pass
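# scatter_add in NumPy terms: rows of x are accumulated (not
# overwritten) into an m-row output at the positions given by idx,
# matching the np.add.at pattern used by the MXNet backend below.
import numpy as np
x = np.array([[1.0], [2.0], [3.0]])
idx = np.array([0, 2, 2])
y = np.zeros((3, 1))                       # m = 3
np.add.at(y, idx, x)                       # y -> [[1.], [0.], [5.]]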
def csrmm(A, A_weights, B, B_weights, num_vtypes):
"""Compute weighted adjacency matrix multiplication.
......@@ -1795,6 +1897,7 @@ def csrmm(A, A_weights, B, B_weights, num_vtypes):
"""
pass
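# A SciPy sketch of weighted adjacency multiplication (SciPy is assumed
# here purely for illustration; DGL's kernel consumes graph indices and
# returns a new graph together with its edge weights).
import numpy as np
from scipy.sparse import csr_matrix
A = csr_matrix((np.array([1.0, 2.0]),
(np.array([0, 1]), np.array([1, 0]))), shape=(2, 2))
B = csr_matrix((np.array([3.0]),
(np.array([1]), np.array([0]))), shape=(2, 2))
C = A @ B    # C's sparsity pattern = new graph, C.data = its weights
# A = [[0,1],[2,0]], B = [[0,0],[3,0]]  ->  C = [[3,0],[0,0]]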
def csrsum(gidxs, weights):
"""Compute weighted adjacency matrix summation.
......@@ -1821,6 +1924,7 @@ def csrsum(gidxs, weights):
"""
pass
def csrmask(A, A_weights, B):
"""Retrieve the values in the weighted adjacency matrix of graph :attr:`A` at the
non-zero positions of graph :attr:`B`'s adjacency matrix.
......@@ -1848,8 +1952,9 @@ def csrmask(A, A_weights, B):
"""
pass
def gather_mm(A, B, idx_a, idx_b):
r""" Dense Matrix Multiplication interface. It multiplies 2D dense tensor A
r"""Dense Matrix Multiplication interface. It multiplies 2D dense tensor A
and 3D dense tensor B according to their relation types. A is unsorted and
the relation type is fetched from idx_b.
......@@ -1871,8 +1976,9 @@ def gather_mm(A, B, idx_a, idx_b):
"""
pass
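# gather_mm semantics in NumPy (a sketch; per the docstring, A is
# unsorted and the relation type of row i comes from idx_b[i]): each row
# of A is multiplied by its relation's weight matrix in B.
import numpy as np
A = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # (N, D1)
B = np.stack([np.eye(2), 2 * np.eye(2)])             # (R, D1, D2)
idx_b = np.array([0, 1, 1])                          # relation per row
out = np.einsum("nd,nde->ne", A, B[idx_b])
# rows with idx_b == 1 are scaled by 2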
def segment_mm(A, B, seglen_A):
r""" Dense Matrix Multiplication interface. It multiplies dense tensor A
r"""Dense Matrix Multiplication interface. It multiplies dense tensor A
and dense tensor B according to relation types. A is sorted and concatenated
according to relation types.
......@@ -1900,6 +2006,7 @@ def segment_mm(A, B, seglen_A):
# These are not related to tensors. Some of them are temporary workarounds that
# should be included in DGL in the future.
def sync():
"""Synchronize computation.
......@@ -1909,33 +2016,35 @@ def sync():
"""
pass
def attach_grad(tensor):
""" Attach gradients to the input tensor
"""
"""Attach gradients to the input tensor"""
pass
def backward(x, head_gradient=None):
"""Invoke backward computation with an optional head gradient.
"""
"""Invoke backward computation with an optional head gradient."""
pass
def grad(x):
"""Fetches the gradient from the tensor after backward computation.
"""
"""Fetches the gradient from the tensor after backward computation."""
pass
def is_no_grad(x):
""" Test if the input tensor has gradient
"""
"""Test if the input tensor has gradient"""
pass
def is_recording():
""" Test if the execution is recording gradients.
"""
"""Test if the execution is recording gradients."""
pass
class record_grad(object):
"""Context manager that records the gradients"""
def __init__(self):
pass
......@@ -1948,6 +2057,7 @@ class record_grad(object):
class no_grad(object):
"""Context manager that explicitly disables gradient computation"""
def __init__(self):
pass
......@@ -1957,8 +2067,10 @@ class no_grad(object):
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
class NodeEmbedding(object):
"""Sparse node embeddings"""
def __init__(self):
pass
......
from .tensor import *
from .sparse import *
from .tensor import *
import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
from ...heterograph_index import create_unitgraph_from_csr
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']
from ...base import ALL, dgl_warning, is_all
from ...heterograph_index import create_unitgraph_from_csr
from ...sparse import (
_bwd_segment_cmp,
_csrmask,
_csrmm,
_csrsum,
_gsddmm,
_gspmm,
_scatter_add,
_segment_reduce,
)
from .tensor import (
asnumpy,
context,
copy_to,
to_backend_ctx,
zerocopy_from_numpy,
)
__all__ = [
"gspmm",
"gsddmm",
"edge_softmax",
"segment_reduce",
"scatter_add",
"csrmm",
"csrsum",
"csrmask",
]
def _scatter_nd(index, src, n_rows):
......@@ -26,7 +49,10 @@ def _scatter_nd(index, src, n_rows):
di = shp[i]
offset_i = np.arange(di, dtype=index.dtype)
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
(stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di
if ndim > 1:
new_idx = index * stride + sum(offsets)
......@@ -52,7 +78,10 @@ def _gather_nd(index, src):
di = shp[i]
offset_i = nd.arange(di, dtype=index.dtype)
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
(stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di
if ndim > 1:
new_idx = index * stride + copy_to(sum(offsets), ctx)
......@@ -107,11 +136,11 @@ def _need_reduce_last_dim(ufeat, efeat):
def _muldiv(op, x):
return 1. / x if op == 'div' else x
return 1.0 / x if op == "div" else x
def _addsub(op, x):
return -x if op == 'sub' else x
return -x if op == "sub" else x
def _expand(x, shape):
......@@ -134,45 +163,48 @@ class GSpMM(mx.autograd.Function):
ctx = context(dZ)
X, Y, argX, argY = self.saved_tensors
gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
if op != 'copy_rhs':
if op != "copy_rhs":
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
elif op in ['add', 'sub']:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
elif op == 'copy_lhs':
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
if reduce_op == "sum":
if op in ["mul", "div"]:
dX = _gspmm(g_rev, "mul", "sum", dZ, _muldiv(op, Y))[0]
elif op in ["add", "sub"]:
dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, Y)[0]
elif op == "copy_lhs":
dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, None)[0]
else:
if op in ['mul', 'div']:
if op in ["mul", "div"]:
dX = _scatter_nd(
argX,
_muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:]))) * dZ,
X.shape[0])
elif op in ['add', 'sub', 'copy_lhs']:
_muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))
* dZ,
X.shape[0],
)
elif op in ["add", "sub", "copy_lhs"]:
dX = _scatter_nd(argX, dZ, X.shape[0])
dX = _reduce_grad(dX, X.shape)
else:
dX = nd.zeros_like(X)
if op != 'copy_lhs':
if reduce_op == 'sum':
if op == 'mul' and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, 'dot', X, dZ)
elif op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ)
if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
if op != "copy_lhs":
if reduce_op == "sum":
if op == "mul" and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, "dot", X, dZ)
elif op in ["mul", "div"]:
dY = _gsddmm(gidx, "mul", X, dZ)
if op == "div":
dY = -dY / (Y**2)
elif op in ["add", "sub", "copy_rhs"]:
dY = _gsddmm(gidx, "copy_rhs", X, _addsub(op, dZ))
else:
if op in ['mul', 'div']:
if op in ["mul", "div"]:
dY = _scatter_nd(
argY,
_gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
Y.shape[0])
if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
Y.shape[0],
)
if op == "div":
dY = -dY / (Y**2)
elif op in ["add", "sub", "copy_rhs"]:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape)
else:
......@@ -207,7 +239,9 @@ class GSDDMM(mx.autograd.Function):
self.rhs_target = rhs_target
def forward(self, X, Y):
out = _gsddmm(self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target)
out = _gsddmm(
self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target
)
self.save_for_backward(X, Y)
return out
......@@ -216,47 +250,55 @@ class GSDDMM(mx.autograd.Function):
X, Y = self.saved_tensors
gidx, op = self.gidx, self.op
lhs_target, rhs_target = self.lhs_target, self.rhs_target
if op != 'copy_rhs':
if lhs_target in ['u', 'v']:
_gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
if op != "copy_rhs":
if lhs_target in ["u", "v"]:
_gidx = gidx if self.lhs_target == "v" else gidx.reverse()
if op in ["add", "sub", "copy_lhs"]:
dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0]
else: # mul, div, dot
if rhs_target == lhs_target:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
elif self.rhs_target == 'e':
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[
0
] * _muldiv(op, Y)
elif self.rhs_target == "e":
dX = _gspmm(
_gidx, "copy_rhs", "sum", None, dZ * _muldiv(op, Y)
)[0]
else: # rhs_target = !lhs_target
dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
dX = _gspmm(_gidx, "mul", "sum", _muldiv(op, Y), dZ)[0]
else: # lhs_target == 'e'
if op in ['add', 'sub', 'copy_lhs']:
if op in ["add", "sub", "copy_lhs"]:
dX = dZ
else: # mul, div, dot
dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
dX = _gsddmm(
gidx, "mul", dZ, _muldiv(op, Y), "e", rhs_target
)
dX = _reduce_grad(dX, X.shape)
else:
dX = nd.zeros_like(X)
if op != 'copy_lhs':
if self.rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
if op != "copy_lhs":
if self.rhs_target in ["u", "v"]:
_gidx = gidx if rhs_target == "v" else gidx.reverse()
if op in ["add", "sub", "copy_rhs"]:
dY = _gspmm(
_gidx, "copy_rhs", "sum", None, _addsub(op, dZ)
)[0]
else: # mul, div, dot
if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
elif self.lhs_target == 'e':
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0] * X
elif self.lhs_target == "e":
dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ * X)[0]
else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div':
dY = -dY / (Y ** 2)
dY = _gspmm(_gidx, "mul", "sum", X, dZ)[0]
if op == "div":
dY = -dY / (Y**2)
else:
if op in ['add', 'sub', 'copy_rhs']:
if op in ["add", "sub", "copy_rhs"]:
dY = _addsub(op, dZ)
else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div':
dY = -dY / (Y ** 2)
dY = _gsddmm(gidx, "mul", dZ, X, "e", lhs_target)
if op == "div":
dY = -dY / (Y**2)
dY = _reduce_grad(dY, Y.shape)
else:
dY = nd.zeros_like(Y)
......@@ -264,7 +306,7 @@ class GSDDMM(mx.autograd.Function):
return dX, dY
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
func = GSDDMM(gidx, op, lhs_target, rhs_target)
ctx = to_backend_ctx(gidx.ctx)
if lhs_data is None:
......@@ -279,7 +321,7 @@ class EdgeSoftmax(mx.autograd.Function):
super(EdgeSoftmax, self).__init__()
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src':
if norm_by == "src":
gidx = gidx.reverse()
self.gidx = gidx
......@@ -298,10 +340,10 @@ class EdgeSoftmax(mx.autograd.Function):
return out.data
"""
gidx = self.gidx
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = mx.nd.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
score_max = _gspmm(gidx, "copy_rhs", "max", None, score)[0]
score = mx.nd.exp(_gsddmm(gidx, "sub", score, score_max, "e", "v"))
score_sum = _gspmm(gidx, "copy_rhs", "sum", None, score)[0]
out = _gsddmm(gidx, "div", score, score_sum, "e", "v")
self.save_for_backward(out)
return out
......@@ -319,16 +361,16 @@ class EdgeSoftmax(mx.autograd.Function):
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions
"""
out, = self.saved_tensors
(out,) = self.saved_tensors
gidx = self.gidx
sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
accum = gspmm(gidx, "copy_rhs", "sum", None, sds)
grad_score = sds - gsddmm(gidx, "mul", out, accum, "e", "v")
self.save_tensors = None
return grad_score
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
softmax_op = EdgeSoftmax(gidx, eids, norm_by)
return softmax_op(logits)
......@@ -345,10 +387,10 @@ class SegmentReduce(mx.autograd.Function):
return y
def backward(self, dy):
arg, = self.saved_tensors
(arg,) = self.saved_tensors
offsets = self.offsets
m = offsets[-1].asscalar()
if self.op == 'sum':
if self.op == "sum":
offsets_np = asnumpy(offsets[1:])
indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
......@@ -374,7 +416,7 @@ class ScatterAdd(mx.autograd.Function):
def forward(self, x):
y = _scatter_add(x, self.idx, self.m)
return y
def backward(self, dy):
return dy[self.idx]
......@@ -392,36 +434,66 @@ class CSRMM(mx.autograd.Function):
self.num_vtypes = num_vtypes
def forward(self, A_weights, B_weights):
gidxC, C_weights = _csrmm(self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, False, 'csr')
gidxC, C_weights = _csrmm(
self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes
)
(
nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
self.save_for_backward(A_weights, B_weights)
nrows = nd.array([nrows], dtype='int64')
ncols = nd.array([ncols], dtype='int64')
nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
def backward(
self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful.
gidxC = self.backward_cache
A_weights, B_weights = self.saved_tensors
dgidxA, dA_weights = _csrmm(
gidxC, dC_weights, self.gidxB.reverse(), B_weights, self.gidxA.number_of_ntypes())
gidxC,
dC_weights,
self.gidxB.reverse(),
B_weights,
self.gidxA.number_of_ntypes(),
)
dgidxB, dB_weights = _csrmm(
self.gidxA.reverse(), A_weights, gidxC, dC_weights, self.gidxB.number_of_ntypes())
self.gidxA.reverse(),
A_weights,
gidxC,
dC_weights,
self.gidxB.number_of_ntypes(),
)
dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA)
dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB)
return dA_weights, dB_weights
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
op = CSRMM(gidxA, gidxB, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(A_weights, B_weights)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(
A_weights, B_weights
)
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
num_vtypes,
nrows.asscalar(),
ncols.asscalar(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights
class CSRSum(mx.autograd.Function):
def __init__(self, gidxs):
super().__init__()
......@@ -429,29 +501,44 @@ class CSRSum(mx.autograd.Function):
def forward(self, *weights):
gidxC, C_weights = _csrsum(self.gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
0, False, 'csr')
(
nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
nrows = nd.array([nrows], dtype='int64')
ncols = nd.array([ncols], dtype='int64')
nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
def backward(
self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful.
gidxC = self.backward_cache
return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs)
def csrsum(gidxs, weights):
op = CSRSum(gidxs)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights)
num_vtypes = gidxs[0].number_of_ntypes()
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
num_vtypes,
nrows.asscalar(),
ncols.asscalar(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights
class CSRMask(mx.autograd.Function):
def __init__(self, gidxA, gidxB):
super().__init__()
......@@ -464,6 +551,7 @@ class CSRMask(mx.autograd.Function):
def backward(self, dB_weights):
return _csrmask(self.gidxB, dB_weights, self.gidxA)
def csrmask(gidxA, A_weights, gidxB):
op = CSRMask(gidxA, gidxB)
return op(A_weights)
"""Sparse optimizer is not supported for mxnet"""
\ No newline at end of file
"""Sparse optimizer is not supported for mxnet"""
from __future__ import absolute_import
import builtins
import numbers
import os
from distutils.version import LooseVersion
import os
import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
import numbers
import builtins
import numpy as np
from ... import ndarray as dglnd
from ..._deprecate import kernel as K
from ...function.base import TargetCode
......@@ -17,26 +18,31 @@ if LooseVersion(mx.__version__) < LooseVersion("1.6.0"):
# After MXNet 1.5, empty tensors aren't supported by default.
# After we turn on the numpy compatible flag, MXNet supports empty NDArray.
mx.set_np_shape(bool(os.environ.get('DGL_MXNET_SET_NP_SHAPE', True)))
mx.set_np_shape(bool(os.environ.get("DGL_MXNET_SET_NP_SHAPE", True)))
def data_type_dict():
return {'float16' : np.float16,
'float32' : np.float32,
'float64' : np.float64,
'uint8' : np.uint8,
'int8' : np.int8,
'int16' : np.int16,
'int32' : np.int32,
'int64' : np.int64,
'bool' : np.bool} # mxnet does not support bool
return {
"float16": np.float16,
"float32": np.float32,
"float64": np.float64,
"uint8": np.uint8,
"int8": np.int8,
"int16": np.int16,
"int32": np.int32,
"int64": np.int64,
"bool": np.bool,
} # mxnet does not support bool
def cpu():
return mx.cpu()
def tensor(data, dtype=None):
if dtype == np.bool:
# mxnet doesn't support bool
dtype = np.int32
if isinstance(data, nd.NDArray):
if dtype is None or data.dtype == dtype:
return data
......@@ -51,9 +57,14 @@ def tensor(data, dtype=None):
elif len(data) == 0:
dtype = np.int64
else:
dtype = np.int64 if isinstance(data[0], numbers.Integral) else np.float32
dtype = (
np.int64
if isinstance(data[0], numbers.Integral)
else np.float32
)
return nd.array(data, dtype=dtype)
def as_scalar(data):
if data.size != 1:
raise ValueError("The current array is not a scalar")
......@@ -61,6 +72,7 @@ def as_scalar(data):
data = data.expand_dims(axis=0)
return data.asscalar()
def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend.
......@@ -69,60 +81,79 @@ def get_preferred_sparse_format():
"""
return "csr"
def sparse_matrix(data, index, shape, force_format=False):
fmt = index[0]
if fmt == 'coo':
if fmt == "coo":
if force_format:
raise TypeError('MXNet backend only supports CSR format,'
' but COO format is forced.')
raise TypeError(
"MXNet backend only supports CSR format,"
" but COO format is forced."
)
coord = index[1]
# generate convert idx
# FIXME: cannot use int64
tmp_data = nd.arange(len(coord[0]), dtype=data.dtype, ctx=coord[0].context)
tmp_spmat = nd.sparse.csr_matrix((tmp_data, (coord[0], coord[1])),
tuple(shape), ctx=data.context)
convert_idx = nd.cast(tmp_spmat.data, dtype='int64')
tmp_data = nd.arange(
len(coord[0]), dtype=data.dtype, ctx=coord[0].context
)
tmp_spmat = nd.sparse.csr_matrix(
(tmp_data, (coord[0], coord[1])), tuple(shape), ctx=data.context
)
convert_idx = nd.cast(tmp_spmat.data, dtype="int64")
# shuffle the data
data = data[convert_idx]
spmat = nd.sparse.csr_matrix((data, tmp_spmat.indices, tmp_spmat.indptr),
tuple(shape), ctx=data.context)
spmat = nd.sparse.csr_matrix(
(data, tmp_spmat.indices, tmp_spmat.indptr),
tuple(shape),
ctx=data.context,
)
return spmat, convert_idx
elif fmt == 'csr':
elif fmt == "csr":
indices = index[1]
indptr = index[2]
spmat = nd.sparse.csr_matrix((data, indices, indptr),
tuple(shape), ctx=data.context)
spmat = nd.sparse.csr_matrix(
(data, indices, indptr), tuple(shape), ctx=data.context
)
# No conversion is required.
return spmat, None
else:
raise TypeError('Invalid format: %s.' % fmt)
raise TypeError("Invalid format: %s." % fmt)
def sparse_matrix_indices(spmat):
return ('csr', spmat.indices, spmat.indptr)
return ("csr", spmat.indices, spmat.indptr)
def is_tensor(obj):
return isinstance(obj, nd.NDArray)
def shape(input):
# NOTE: the input cannot be a symbol
return input.shape
def dtype(input):
# NOTE: the input cannot be a symbol
return input.dtype
def ndim(input):
return input.ndim
def context(input):
return input.context
def device_type(ctx):
return ctx.device_type
def device_id(ctx):
return ctx.device_id
def to_backend_ctx(dglctx):
dev_type = dglctx.device_type
if dev_type == 1:
......@@ -130,84 +161,110 @@ def to_backend_ctx(dglctx):
elif dev_type == 2:
return mx.gpu(dglctx.device_id)
else:
raise ValueError('Unsupported DGL device context:', dglctx)
raise ValueError("Unsupported DGL device context:", dglctx)
def astype(input, ty):
if ty == np.bool:
ty = np.int32
return input.astype(ty)
def asnumpy(input):
return input.asnumpy()
def copy_to(input, ctx, **kwargs):
return input.as_in_context(ctx)
def is_pinned(input):
return input.context == mx.cpu_pinned()
def sum(input, dim, keepdims=False):
if len(input) == 0:
return nd.array([0.], dtype=input.dtype, ctx=input.context)
return nd.array([0.0], dtype=input.dtype, ctx=input.context)
return nd.sum(input, axis=dim, keepdims=keepdims)
def floor_div(in1, in2):
return in1 / in2
def reduce_sum(input):
return input.sum()
def cumsum(input, dim):
return nd.cumsum(input, axis=dim)
def mean(input, dim):
return nd.mean(input, axis=dim)
def reduce_mean(input):
return input.mean()
def max(input, dim):
return nd.max(input, axis=dim)
def reduce_max(input):
return input.max()
def min(input, dim):
return nd.min(input, axis=dim)
def reduce_min(input):
return input.min()
def topk(input, k, dim, descending=True):
return nd.topk(input, axis=dim, k=k, ret_typ='value', is_ascend=not descending)
return nd.topk(
input, axis=dim, k=k, ret_typ="value", is_ascend=not descending
)
def argtopk(input, k, dim, descending=True):
idx = nd.argsort(input, dim, is_ascend=not descending)
return nd.slice_axis(idx, dim, 0, k)
def argsort(input, dim, descending):
idx = nd.argsort(input, dim, is_ascend=not descending)
idx = nd.cast(idx, dtype='int64')
idx = nd.cast(idx, dtype="int64")
return idx
def exp(input):
return nd.exp(input)
def inverse(input):
return nd.linalg_inverse(input)
def sqrt(input):
return nd.sqrt(input)
def softmax(input, dim=-1):
return nd.softmax(input, axis=dim)
def cat(seq, dim):
return nd.concat(*seq, dim=dim)
def stack(seq, dim):
return nd.stack(*seq, axis=dim)
def split(x, sizes_or_sections, dim):
if isinstance(sizes_or_sections, list) and len(sizes_or_sections) == 1:
assert len(x) == sizes_or_sections[0]
......@@ -217,13 +274,18 @@ def split(x, sizes_or_sections, dim):
sizes_or_sections1 = tuple(np.cumsum(sizes_or_sections)[:-1])
return nd.split_v2(x, sizes_or_sections1, axis=dim)
def repeat(input, repeats, dim):
if isinstance(repeats, nd.NDArray):
return nd.array(np.repeat(input.asnumpy(), repeats.asnumpy(), axis=dim),
ctx=input.context, dtype=input.dtype)
return nd.array(
np.repeat(input.asnumpy(), repeats.asnumpy(), axis=dim),
ctx=input.context,
dtype=input.dtype,
)
else:
return nd.repeat(input, repeats, axis=dim)
def gather_row(data, row_index):
# MXNet workaround for empty row index
if len(row_index) == 0:
......@@ -235,7 +297,10 @@ def gather_row(data, row_index):
if isinstance(row_index, nd.NDArray):
return nd.take(data, row_index)
else:
return data[row_index,]
return data[
row_index,
]
def slice_axis(data, axis, begin, end):
dim = data.shape[axis]
......@@ -245,49 +310,64 @@ def slice_axis(data, axis, begin, end):
end += dim
return nd.slice_axis(data, axis, begin, end)
def take(data, indices, dim):
return nd.take(data, indices, dim)
def narrow_row(data, start, stop):
return data[start:stop]
def index_add_inplace(data, row_idx, value):
raise NotImplementedError("MXNet doesn't support inplace index_add")
def scatter_row(data, row_index, value):
return mx.nd.contrib.index_copy(data, row_index, value)
def scatter_row_inplace(data, row_index, value):
data[row_index] = value
def squeeze(input, dim):
return nd.squeeze(input, axis=dim)
def unsqueeze(input, dim):
return nd.expand_dims(input, axis=dim)
def reshape(input, shape):
# NOTE: the input cannot be a symbol
return nd.reshape(input ,shape)
return nd.reshape(input, shape)
def swapaxes(input, axis1, axis2):
return nd.swapaxes(input, axis1, axis2)
def zeros(shape, dtype, ctx):
return nd.zeros(shape, dtype=dtype, ctx=ctx)
def zeros_like(input):
return nd.zeros_like(input)
def ones(shape, dtype, ctx):
return nd.ones(shape, dtype=dtype, ctx=ctx)
def uniform(shape, dtype, ctx, low, high):
return nd.random.uniform(low, high, ctx=ctx, dtype=dtype, shape=shape)
def randint(shape, dtype, ctx, low, high):
return nd.random.randint(low, high, ctx=ctx, dtype=dtype, shape=shape)
def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape
if isinstance(lengths, nd.NDArray):
......@@ -300,12 +380,17 @@ def pad_packed_tensor(input, lengths, value, l_min=None):
batch_size = len(lengths)
ctx = input.context
dtype = input.dtype
x = nd.full((batch_size * max_len, *old_shape[1:]), value, ctx=ctx, dtype=dtype)
x = nd.full(
(batch_size * max_len, *old_shape[1:]), value, ctx=ctx, dtype=dtype
)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return scatter_row(x, index, input).reshape(batch_size, max_len, *old_shape[1:])
return scatter_row(x, index, input).reshape(
batch_size, max_len, *old_shape[1:]
)
def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2]
......@@ -316,46 +401,60 @@ def pack_padded_tensor(input, lengths):
index = nd.array(index, ctx=ctx)
return gather_row(input.reshape(batch_size * max_len, -1), index)
def boolean_mask(input, mask):
return mx.contrib.nd.boolean_mask(input, mask)
def equal(x, y):
return x == y
def allclose(x, y, rtol=1e-4, atol=1e-4):
return np.allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol)
def logical_not(input):
return nd.logical_not(input)
def logical_and(input1, input2):
return nd.logical_and(input1, input2)
def clone(input):
return input.copy()
def clamp(data, min_val, max_val):
return nd.clip(data, min_val, max_val)
def replace_inf_with_zero(x):
return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)
def count_nonzero(input):
# TODO: fallback to numpy is unfortunate
tmp = input.asnumpy()
return np.count_nonzero(tmp)
def unique(input, return_inverse=False, return_counts=False):
# TODO: fallback to numpy is unfortunate
tmp = input.asnumpy()
if return_inverse and return_counts:
tmp, inv, count = np.unique(tmp, return_inverse=True, return_counts=True)
tmp, inv, count = np.unique(
tmp, return_inverse=True, return_counts=True
)
tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)
inv = nd.array(inv, ctx=input.context)
count = nd.array(count, ctx=input.context)
return tmp, inv, count
elif return_inverse or return_counts:
tmp, tmp2 = np.unique(tmp, return_inverse=return_inverse, return_counts=return_counts)
tmp, tmp2 = np.unique(
tmp, return_inverse=return_inverse, return_counts=return_counts
)
tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)
tmp2 = nd.array(tmp2, ctx=input.context)
return tmp, tmp2
......@@ -363,9 +462,11 @@ def unique(input, return_inverse=False, return_counts=False):
tmp = np.unique(tmp)
return nd.array(tmp, ctx=input.context, dtype=input.dtype)
def full_1d(length, fill_value, dtype, ctx):
return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)
def nonzero_1d(input):
# TODO: fallback to numpy is unfortunate
tmp = input.asnumpy()
......@@ -373,50 +474,70 @@ def nonzero_1d(input):
r = nd.array(tmp, ctx=input.context, dtype=tmp.dtype)
return r
def sort_1d(input):
# TODO: this isn't an ideal implementation.
val = nd.sort(input, axis=None, is_ascend=True)
idx = nd.argsort(input, is_ascend=True)
idx = nd.cast(idx, dtype='int64')
idx = nd.cast(idx, dtype="int64")
return val, idx
def arange(start, stop, dtype=np.int64, ctx=None):
if start >= stop:
return nd.array([], dtype=dtype, ctx=ctx)
else:
return nd.arange(start, stop, dtype=dtype, ctx=ctx)
def rand_shuffle(arr):
return mx.nd.random.shuffle(arr)
def zerocopy_to_dlpack(arr):
return arr.to_dlpack_for_read()
def zerocopy_from_dlpack(dlpack_arr):
return nd.from_dlpack(dlpack_arr)
def zerocopy_to_numpy(arr):
# NOTE: not zerocopy
return arr.asnumpy()
def zerocopy_from_numpy(np_data):
np_data = np.asarray(np_data, order='C')
np_data = np.asarray(np_data, order="C")
return mx.nd.from_numpy(np_data, zero_copy=True)
def zerocopy_to_dgl_ndarray(arr):
return dglnd.from_dlpack(arr.to_dlpack_for_read())
def zerocopy_to_dgl_ndarray_for_write(arr):
return dglnd.from_dlpack(arr.to_dlpack_for_write())
def zerocopy_from_dgl_ndarray(arr):
return nd.from_dlpack(arr.to_dlpack())
class BinaryReduce(mx.autograd.Function):
def __init__(self, reducer, binary_op, graph, lhs, rhs, out_size, lhs_map,
rhs_map, out_map):
def __init__(
self,
reducer,
binary_op,
graph,
lhs,
rhs,
out_size,
lhs_map,
rhs_map,
out_map,
):
super(BinaryReduce, self).__init__()
self.reducer = reducer
self.binary_op = binary_op
......@@ -431,23 +552,37 @@ class BinaryReduce(mx.autograd.Function):
def forward(self, lhs_data, rhs_data):
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
feat_shape = K.infer_binary_feature_shape(self.binary_op, lhs_data_nd, rhs_data_nd)
feat_shape = K.infer_binary_feature_shape(
self.binary_op, lhs_data_nd, rhs_data_nd
)
out_shape = feat_shape
if self.binary_op == 'dot':
if self.binary_op == "dot":
out_shape = feat_shape[:-1]
out_data = nd.empty((self.out_size,) + out_shape,
ctx=lhs_data.context, dtype=lhs_data.dtype)
out_data = nd.empty(
(self.out_size,) + out_shape,
ctx=lhs_data.context,
dtype=lhs_data.dtype,
)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.binary_op_reduce(
self.reducer if self.reducer != 'mean' else 'sum',
self.binary_op, self.graph, self.lhs, self.rhs,
lhs_data_nd, rhs_data_nd, out_data_nd, self.lhs_map[0],
self.rhs_map[0], self.out_map[0])
self.reducer if self.reducer != "mean" else "sum",
self.binary_op,
self.graph,
self.lhs,
self.rhs,
lhs_data_nd,
rhs_data_nd,
out_data_nd,
self.lhs_map[0],
self.rhs_map[0],
self.out_map[0],
)
# normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if self.reducer == 'mean':
degs = nd.empty((out_data.shape[0],),
ctx=out_data.context, dtype=out_data.dtype)
if self.reducer == "mean":
degs = nd.empty(
(out_data.shape[0],), ctx=out_data.context, dtype=out_data.dtype
)
degs_nd = zerocopy_to_dgl_ndarray(degs)
if self.lhs != TargetCode.DST:
target = self.lhs
......@@ -460,49 +595,100 @@ class BinaryReduce(mx.autograd.Function):
in_ones = nd.ones((n,), ctx=lhs_data.context, dtype=lhs_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce(
'sum', self.graph, target, in_ones_nd, degs_nd,
in_map, self.out_map[0])
"sum",
self.graph,
target,
in_ones_nd,
degs_nd,
in_map,
self.out_map[0],
)
# reshape
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
degs = degs.reshape(
(out_data.shape[0],) + (1,) * (out_data.ndim - 1)
).clip(1, float("inf"))
out_data = out_data / degs
else:
degs = None
self.save_for_backward(lhs_data_nd, rhs_data_nd, out_data_nd,
feat_shape, degs)
self.save_for_backward(
lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs
)
return out_data
def backward(self, grad_out):
lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs = self.saved_tensors
if self.reducer == 'mean':
(
lhs_data_nd,
rhs_data_nd,
out_data_nd,
feat_shape,
degs,
) = self.saved_tensors
if self.reducer == "mean":
grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
grad_lhs = nd.empty((lhs_data_nd.shape[0],) + feat_shape,
ctx=grad_out.context, dtype=grad_out.dtype)
grad_lhs = nd.empty(
(lhs_data_nd.shape[0],) + feat_shape,
ctx=grad_out.context,
dtype=grad_out.dtype,
)
K.backward_lhs_binary_op_reduce(
self.reducer if self.reducer != 'mean' else 'sum',
self.binary_op, self.graph, self.lhs, self.rhs,
lhs_data_nd, rhs_data_nd, out_data_nd, grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_lhs), self.lhs_map[1],
self.rhs_map[1], self.out_map[1])
self.reducer if self.reducer != "mean" else "sum",
self.binary_op,
self.graph,
self.lhs,
self.rhs,
lhs_data_nd,
rhs_data_nd,
out_data_nd,
grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_lhs),
self.lhs_map[1],
self.rhs_map[1],
self.out_map[1],
)
grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)
grad_rhs = nd.empty((rhs_data_nd.shape[0],) + feat_shape,
ctx=grad_out.context, dtype=grad_out.dtype)
grad_rhs = nd.empty(
(rhs_data_nd.shape[0],) + feat_shape,
ctx=grad_out.context,
dtype=grad_out.dtype,
)
K.backward_rhs_binary_op_reduce(
self.reducer if self.reducer != 'mean' else 'sum',
self.binary_op, self.graph, self.lhs, self.rhs,
lhs_data_nd, rhs_data_nd, out_data_nd, grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_rhs), self.lhs_map[1],
self.rhs_map[1], self.out_map[1])
self.reducer if self.reducer != "mean" else "sum",
self.binary_op,
self.graph,
self.lhs,
self.rhs,
lhs_data_nd,
rhs_data_nd,
out_data_nd,
grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_rhs),
self.lhs_map[1],
self.rhs_map[1],
self.out_map[1],
)
grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
# clear saved tensors explicitly
self.saved_tensors = None
return grad_lhs, grad_rhs
def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
out_size, lhs_map=(None, None), rhs_map=(None, None), out_map=(None, None)):
func = BinaryReduce(reducer, binary_op, graph, lhs, rhs, out_size, lhs_map,
rhs_map, out_map)
def binary_reduce(
reducer,
binary_op,
graph,
lhs,
rhs,
lhs_data,
rhs_data,
out_size,
lhs_map=(None, None),
rhs_map=(None, None),
out_map=(None, None),
):
func = BinaryReduce(
reducer, binary_op, graph, lhs, rhs, out_size, lhs_map, rhs_map, out_map
)
return func(lhs_data, rhs_data)
......@@ -518,28 +704,46 @@ class CopyReduce(mx.autograd.Function):
def forward(self, in_data):
feat_shape = in_data.shape[1:]
out_data = nd.empty((self.out_size,) + feat_shape,
ctx=in_data.context, dtype=in_data.dtype)
out_data = nd.empty(
(self.out_size,) + feat_shape,
ctx=in_data.context,
dtype=in_data.dtype,
)
in_data_nd = zerocopy_to_dgl_ndarray(in_data)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.copy_reduce(
self.reducer if self.reducer != 'mean' else 'sum',
self.graph, self.target, in_data_nd, out_data_nd,
self.in_map[0], self.out_map[0])
self.reducer if self.reducer != "mean" else "sum",
self.graph,
self.target,
in_data_nd,
out_data_nd,
self.in_map[0],
self.out_map[0],
)
# normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if self.reducer == 'mean':
in_ones = nd.ones((in_data.shape[0],),
ctx=in_data.context, dtype=in_data.dtype)
degs = nd.empty((out_data.shape[0],),
ctx=out_data.context, dtype=out_data.dtype)
if self.reducer == "mean":
in_ones = nd.ones(
(in_data.shape[0],), ctx=in_data.context, dtype=in_data.dtype
)
degs = nd.empty(
(out_data.shape[0],), ctx=out_data.context, dtype=out_data.dtype
)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
degs_nd = zerocopy_to_dgl_ndarray(degs)
K.copy_reduce(
'sum', self.graph, self.target, in_ones_nd, degs_nd,
self.in_map[0], self.out_map[0])
"sum",
self.graph,
self.target,
in_ones_nd,
degs_nd,
self.in_map[0],
self.out_map[0],
)
# reshape
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
degs = degs.reshape(
(out_data.shape[0],) + (1,) * (out_data.ndim - 1)
).clip(1, float("inf"))
out_data = out_data / degs
else:
degs = None
......@@ -548,23 +752,37 @@ class CopyReduce(mx.autograd.Function):
def backward(self, grad_out):
in_data_nd, out_data_nd, degs = self.saved_tensors
grad_in = nd.empty(in_data_nd.shape, ctx=grad_out.context,
dtype=grad_out.dtype)
if self.reducer == 'mean':
grad_in = nd.empty(
in_data_nd.shape, ctx=grad_out.context, dtype=grad_out.dtype
)
if self.reducer == "mean":
grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
K.backward_copy_reduce(
self.reducer if self.reducer != 'mean' else 'sum',
self.graph, self.target, in_data_nd, out_data_nd,
grad_out_nd, zerocopy_to_dgl_ndarray_for_write(grad_in),
self.in_map[1], self.out_map[1])
self.reducer if self.reducer != "mean" else "sum",
self.graph,
self.target,
in_data_nd,
out_data_nd,
grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_in),
self.in_map[1],
self.out_map[1],
)
# clear saved tensors explicitly
self.saved_tensors = None
return grad_in
def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
out_map=(None, None)):
def copy_reduce(
reducer,
graph,
target,
in_data,
out_size,
in_map=(None, None),
out_map=(None, None),
):
func = CopyReduce(reducer, graph, target, out_size, in_map, out_map)
return func(in_data)
......@@ -600,6 +818,7 @@ def _reduce_grad(grad, shape):
grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)
return grad.reshape(shape)
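# Why _reduce_grad sums over broadcast axes (a NumPy illustration): if
# an input of shape (1, 2) was broadcast to (3, 2) in the forward pass,
# its gradient must collapse the broadcast axis back.
import numpy as np
grad = np.ones((3, 2))                     # grad w.r.t. broadcast output
grad_reduced = grad.sum(axis=0, keepdims=True)
assert grad_reduced.shape == (1, 2)        # matches the input shape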
def sync():
"""Synchronize computation.
......@@ -609,24 +828,31 @@ def sync():
"""
mx.nd.waitall()
def attach_grad(tensor):
tensor.attach_grad()
return tensor
def backward(x, head_gradient=None):
x.backward(head_gradient)
def grad(x):
return x.grad
def is_no_grad(x):
return (x != 0).sum() == 0
def is_recording():
return mx.autograd.is_recording()
record_grad = mx.autograd.record
class no_grad(object):
def __init__(self):
pass
......
from .tensor import *
from .sparse import *
from .tensor import *
import torch as th
from torch.cuda.amp import custom_fwd, custom_bwd
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...sparse import _gather_mm, _gather_mm_scatter, _segment_mm, _segment_mm_backward_B
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr
from torch.cuda.amp import custom_bwd, custom_fwd
__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'edge_softmax_hetero',
'segment_reduce', 'scatter_add', 'csrmm', 'csrsum', 'csrmask', 'gather_mm', 'segment_mm']
from ...base import ALL, is_all
from ...heterograph_index import create_unitgraph_from_csr
from ...sparse import (
_bwd_segment_cmp,
_csrmask,
_csrmm,
_csrsum,
_edge_softmax_backward,
_edge_softmax_forward,
_gather_mm,
_gather_mm_scatter,
_gsddmm,
_gsddmm_hetero,
_gspmm,
_gspmm_hetero,
_scatter_add,
_segment_mm,
_segment_mm_backward_B,
_segment_reduce,
_update_grad_minmax_hetero,
)
__all__ = [
"gspmm",
"gsddmm",
"gspmm_hetero",
"gsddmm_hetero",
"edge_softmax",
"edge_softmax_hetero",
"segment_reduce",
"scatter_add",
"csrmm",
"csrsum",
"csrmask",
"gather_mm",
"segment_mm",
]
def _reduce_grad(grad, shape):
......@@ -38,7 +65,9 @@ def _reduce_grad(grad, shape):
num_to_squeeze = len(grad_shape) - len(in_shape)
# pad inshape
in_shape = (1,) * num_to_squeeze + in_shape
reduce_idx = th.nonzero(th.tensor(grad_shape) - th.tensor(in_shape), as_tuple=False)
reduce_idx = th.nonzero(
th.tensor(grad_shape) - th.tensor(in_shape), as_tuple=False
)
reduce_idx += 1 # skip batch dim
if len(reduce_idx) > 0:
grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)
......@@ -62,23 +91,23 @@ def _expand(x, shape):
def spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache X in SpMM forward stage."""
if binary_op != 'copy_lhs' and req_grad_Y:
if reduce_op == 'sum':
if binary_op != "copy_lhs" and req_grad_Y:
if reduce_op == "sum":
return True
else:
if binary_op == 'mul':
if binary_op == "mul":
return True
return False
def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache Y in SpMM forward stage."""
if binary_op != 'copy_rhs' and req_grad_X:
if reduce_op == 'sum':
if binary_op in ['mul', 'add']:
if binary_op != "copy_rhs" and req_grad_X:
if reduce_op == "sum":
if binary_op in ["mul", "add"]:
return True
else:
if binary_op == 'mul':
if binary_op == "mul":
return True
return False
......@@ -86,7 +115,7 @@ def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argX in SpMM forward stage."""
if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']:
if reduce_op in ["min", "max"]:
return True
return False
......@@ -94,7 +123,7 @@ def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argY in SpMM forward stage."""
if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']:
if reduce_op in ["min", "max"]:
return True
return False
......@@ -109,7 +138,16 @@ class GSpMM(th.autograd.Function):
Y_shape = Y.shape if Y is not None else None
dtype = X.dtype if X is not None else Y.dtype
device = X.device if X is not None else Y.device
ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last
ctx.backward_cache = (
gidx,
op,
reduce_op,
X_shape,
Y_shape,
dtype,
device,
reduce_last,
)
req_grad_X = X.requires_grad if X is not None else False
req_grad_Y = Y.requires_grad if Y is not None else False
if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y):
......@@ -126,45 +164,54 @@ class GSpMM(th.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx, dZ):
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
(
gidx,
op,
reduce_op,
X_shape,
Y_shape,
dtype,
device,
reduce_last,
) = ctx.backward_cache
X, Y, argX, argY = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[3]:
if op != "copy_rhs" and ctx.needs_input_grad[3]:
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op == 'mul':
dX = gspmm(g_rev, 'mul', 'sum', dZ, Y)
elif op == 'add':
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)
elif op == 'copy_lhs':
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)
if reduce_op == "sum":
if op == "mul":
dX = gspmm(g_rev, "mul", "sum", dZ, Y)
elif op == "add":
dX = gspmm(g_rev, "copy_lhs", "sum", dZ, Y)
elif op == "copy_lhs":
dX = gspmm(g_rev, "copy_lhs", "sum", dZ, None)
else: # max/min
dX = th.zeros((X_shape[0],) + dZ.shape[1:],
dtype=dtype, device=device)
if op == 'mul':
grad = _expand(Y, dZ.shape[1:]).gather(
0, argY.long()) * dZ
dX = th.zeros(
(X_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
)
if op == "mul":
grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ
dX.scatter_add_(0, argX.long(), grad)
elif op in ['add', 'copy_lhs']:
elif op in ["add", "copy_lhs"]:
dX.scatter_add_(0, argX.long(), dZ)
dX = _reduce_grad(dX, X_shape)
else: # X has no gradient
dX = None
if op != 'copy_lhs' and ctx.needs_input_grad[4]:
if reduce_op == 'sum':
if op == 'mul' and reduce_last:
dY = gsddmm(gidx, 'dot', X, dZ)
elif op == 'mul':
dY = gsddmm(gidx, 'mul', X, dZ)
elif op in ['add', 'copy_rhs']:
dY = gsddmm(gidx, 'copy_rhs', X, dZ)
if op != "copy_lhs" and ctx.needs_input_grad[4]:
if reduce_op == "sum":
if op == "mul" and reduce_last:
dY = gsddmm(gidx, "dot", X, dZ)
elif op == "mul":
dY = gsddmm(gidx, "mul", X, dZ)
elif op in ["add", "copy_rhs"]:
dY = gsddmm(gidx, "copy_rhs", X, dZ)
else: # max/min
dY = th.zeros((Y_shape[0],) + dZ.shape[1:],
dtype=dtype, device=device)
if op == 'mul':
grad = _expand(X, dZ.shape[1:]).gather(
0, argX.long()) * dZ
dY = th.zeros(
(Y_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
)
if op == "mul":
grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ
dY.scatter_add_(0, argY.long(), grad)
elif op in ['add', 'copy_rhs']:
elif op in ["add", "copy_rhs"]:
dY.scatter_add_(0, argY.long(), dZ)
dY = _reduce_grad(dY, Y_shape)
else: # Y has no gradient
......@@ -175,94 +222,178 @@ class GSpMM(th.autograd.Function):
class GSpMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
out, (argX, argY, argX_ntype, argY_etype) = _gspmm_hetero(gidx, op, reduce_op, X_len, feats)
def forward(
ctx, gidx, op, reduce_op, X_len, *feats
): # feats = lhs_data + rhs_data
out, (argX, argY, argX_ntype, argY_etype) = _gspmm_hetero(
gidx, op, reduce_op, X_len, feats
)
X, Y = feats[:X_len], feats[X_len:]
# TODO (Israt): check target to decide src_id/dst_id?
src_id, dst_id = gidx.metagraph.find_edge(0)
reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(X_len)])
Y_shape = tuple([Y[i].shape if Y[i] is not None else None
for i in range(len(Y))])
X_shape = tuple(
[X[i].shape if X[i] is not None else None for i in range(X_len)]
)
Y_shape = tuple(
[Y[i].shape if Y[i] is not None else None for i in range(len(Y))]
)
dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype
device = X[src_id].device if X[src_id] is not None else Y[dst_id].device
ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(X_len)])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))])
ctx.backward_cache = (
gidx,
op,
reduce_op,
X_shape,
Y_shape,
dtype,
device,
reduce_last,
X_len,
)
req_grad_X = tuple(
[
X[i].requires_grad if X[i] is not None else False
for i in range(X_len)
]
)
req_grad_Y = tuple(
[
Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))
]
)
# checking the first relation to decide for all the relations
if not spmm_cache_argX(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
if not spmm_cache_argX(
op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]
):
argX = tuple([None] * len(X))
if not spmm_cache_argY(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
if not spmm_cache_argY(
op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]
):
argY = tuple([None] * len(X))
ctx.save_for_backward(*feats, *argX, *argX_ntype, *argY, *argY_etype )
ctx.save_for_backward(*feats, *argX, *argX_ntype, *argY, *argY_etype)
return out
@staticmethod
@custom_bwd
def backward(ctx, *dZ):
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
(
gidx,
op,
reduce_op,
X_shape,
Y_shape,
dtype,
device,
reduce_last,
X_len,
) = ctx.backward_cache
num_ntypes = gidx.number_of_ntypes()
feats = ctx.saved_tensors[:-(4 * num_ntypes)]
argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)]
argX_ntype = ctx.saved_tensors[-(3 * num_ntypes):-(2 * num_ntypes)]
argY = ctx.saved_tensors[-(2 * num_ntypes):- num_ntypes]
feats = ctx.saved_tensors[: -(4 * num_ntypes)]
argX = ctx.saved_tensors[-(4 * num_ntypes) : -(3 * num_ntypes)]
argX_ntype = ctx.saved_tensors[-(3 * num_ntypes) : -(2 * num_ntypes)]
argY = ctx.saved_tensors[-(2 * num_ntypes) : -num_ntypes]
argY_etype = ctx.saved_tensors[-num_ntypes:]
X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]):
if op != "copy_rhs" and any([x is not None for x in X]):
g_rev = gidx.reverse()
if reduce_op == "sum":
if op == "mul":
dX = gspmm_hetero(
g_rev, "mul", "sum", len(X), *tuple(dZ + Y)
)
elif op == "add":
dX = gspmm_hetero(
g_rev, "copy_lhs", "sum", len(X), *tuple(dZ + Y)
)
elif op == "copy_lhs":
tpl_None = tuple([None] * len(Y))
dX = gspmm_hetero(
g_rev, "copy_lhs", "sum", len(X), *tuple(dZ + tpl_None)
)
else: # max/min
# Assuming that the features are of the same dimension (enforced by the forward function)
src_id, dst_id = gidx.metagraph.find_edge(0)
dX = tuple(
[
th.zeros(
(X_shape[i][0],) + dZ[dst_id].shape[1:],
dtype=dtype,
device=device,
)
if X[i] is not None
else None
for i in range(len(X))
]
)
if op == "mul":
grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ
dX.scatter_add_(0, argX.long(), grad)
elif op in ["add", "copy_lhs"]:
dX = _update_grad_minmax_hetero(
g_rev, op, dZ, argX, argX_ntype, dX
)
dX = tuple(
[
_reduce_grad(dX[i], X_shape[i])
if X[i] is not None
else None
for i in range(len(X))
]
)
        else:  # X has no gradient
dX = tuple([None] * len(X))
if op != "copy_lhs" and any([y is not None for y in Y]):
# TODO(Israt): implement other combinations of reduce functions
if reduce_op == "sum":
tpl_dZ = tuple(
[
dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))
]
)
tpl_X_dZ = tuple(X + tpl_dZ)
if op == "mul" and reduce_last:
dY = gsddmm_hetero(gidx, "dot", X_len, "u", "v", *tpl_X_dZ)
elif op == "mul":
dY = gsddmm_hetero(gidx, "mul", X_len, "u", "v", *tpl_X_dZ)
elif op in ["add", "copy_rhs"]:
dY = gsddmm_hetero(
gidx, "copy_rhs", X_len, "u", "v", *tpl_X_dZ
)
else: # max/min
src_id, dst_id = gidx.metagraph.find_edge(0)
dY = tuple(
[
th.zeros(
(Y_shape[i][0],) + dZ[dst_id].shape[1:],
dtype=dtype,
device=device,
)
if Y[i] is not None
else None
for i in range(len(Y))
]
)
if op == "mul":
grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ
dY.scatter_add_(0, argY.long(), grad)
elif op in ["add", "copy_rhs"]:
dY = _update_grad_minmax_hetero(
gidx.reverse(), op, dZ, argY, argY_etype, dY
)
dY = tuple(
[
_reduce_grad(dY[i], Y_shape[i])
if dY[i] is not None
else None
for i in range(len(dY))
]
)
else: # Y has no gradient
dY = tuple([None] * len(Y))
return (None, None, None, None) + dX + dY
@@ -270,14 +401,14 @@ class GSpMM_hetero(th.autograd.Function):
def sddmm_cache_X(op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache X in SDDMM forward stage."""
if op in ["mul", "dot"] and req_grad_Y:
return True
return False
def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache Y in SDDMM forward stage."""
if op in ["mul", "dot"] and req_grad_X:
return True
return False
@@ -304,43 +435,43 @@ class GSDDMM(th.autograd.Function):
def backward(ctx, dZ):
gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
X, Y = ctx.saved_tensors
if op != "copy_rhs" and ctx.needs_input_grad[2]:
if lhs_target in ["u", "v"]:
_gidx = gidx if lhs_target == "v" else gidx.reverse()
if op in ["add", "copy_lhs"]:
dX = gspmm(_gidx, "copy_rhs", "sum", None, dZ)
else: # mul, dot
if rhs_target == lhs_target:
dX = gspmm(_gidx, "copy_rhs", "sum", None, dZ) * Y
elif rhs_target == "e":
dX = gspmm(_gidx, "copy_rhs", "sum", None, dZ * Y)
else: # rhs_target = !lhs_target
dX = gspmm(_gidx, "mul", "sum", Y, dZ)
else: # lhs_target == 'e'
if op in ["add", "copy_lhs"]:
dX = dZ
else: # mul, dot
dX = gsddmm(gidx, "mul", dZ, Y, "e", rhs_target)
dX = _reduce_grad(dX, X_shape)
else:
dX = None
if op != "copy_lhs" and ctx.needs_input_grad[3]:
if rhs_target in ["u", "v"]:
_gidx = gidx if rhs_target == "v" else gidx.reverse()
if op in ["add", "copy_rhs"]:
dY = gspmm(_gidx, "copy_rhs", "sum", None, dZ)
else: # mul, dot
if lhs_target == rhs_target:
dY = gspmm(_gidx, "copy_rhs", "sum", None, dZ) * X
elif lhs_target == "e":
dY = gspmm(_gidx, "copy_rhs", "sum", None, dZ * X)
else: # rhs_target = !lhs_target
dY = gspmm(_gidx, "mul", "sum", X, dZ)
else:
if op in ["add", "copy_rhs"]:
dY = dZ
else: # mul, dot
dY = gsddmm(gidx, "mul", dZ, X, "e", lhs_target)
dY = _reduce_grad(dY, Y_shape)
else:
dY = None
@@ -350,18 +481,38 @@ class GSDDMM(th.autograd.Function):
class GSDDMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(
ctx, gidx, op, X_len, lhs_target, rhs_target, *feats
): # feats = X+Y
out = _gsddmm_hetero(gidx, op, X_len, lhs_target, rhs_target, feats)
X, Y = feats[:X_len], feats[X_len:]
X_shape = tuple(
[X[i].shape if X[i] is not None else None for i in range(len(X))]
)
Y_shape = tuple(
[Y[i].shape if Y[i] is not None else None for i in range(len(Y))]
)
ctx.backward_cache = (
gidx,
op,
lhs_target,
rhs_target,
X_shape,
Y_shape,
X_len,
)
req_grad_X = tuple(
[
X[i].requires_grad if X[i] is not None else False
for i in range(len(X))
]
)
req_grad_Y = tuple(
[
Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))
]
)
ctx.save_for_backward(*feats)
return out
@@ -369,58 +520,140 @@ class GSDDMM_hetero(th.autograd.Function):
@custom_bwd
# TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ):
(
gidx,
op,
lhs_target,
rhs_target,
X_shape,
Y_shape,
X_len,
) = ctx.backward_cache
feats = ctx.saved_tensors
X, Y = feats[:X_len], feats[X_len:]
if op != "copy_rhs" and any([x is not None for x in X]):
if lhs_target in ["u", "v"]:
_gidx = gidx if lhs_target == "v" else gidx.reverse()
tpl_of_None = tuple([None] * len(X))
if op in ["add", "copy_lhs"]:
dX = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
else: # mul, dot
if rhs_target == lhs_target:
dX = (
gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
* Y
)
elif rhs_target == "e":
dZ_mul_Y = tuple(
[
dZ[i] * Y[i] if dZ[i] is not None else None
for i in range(len(Y))
]
)
dX = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ_mul_Y))
)
else: # rhs_target = !lhs_target
dX = gspmm_hetero(
_gidx, "mul", "sum", len(X), *tuple(Y + dZ)
)
else: # lhs_target == 'e'
if op in ["add", "copy_lhs"]:
dX = dZ
else: # mul, dot
num_etype = gidx.number_of_etypes()
dX = gsddmm_hetero(
gidx, "mul", num_etype, "e", rhs_target, *tuple(dZ + Y)
)
dX = tuple(
[
_reduce_grad(dX[i], X_shape[i])
if X[i] is not None
else None
for i in range(len(X))
]
)
else:
dX = tuple([None] * len(X))
if op != "copy_lhs" and any([y is not None for y in Y]):
if rhs_target in ["u", "v"]:
_gidx = gidx if rhs_target == "v" else gidx.reverse()
tpl_of_None = tuple([None] * len(X))
if op in ["add", "copy_rhs"]:
dY = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
else: # mul, dot
if lhs_target == rhs_target:
dY = (
gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
* X
)
elif lhs_target == "e":
dZ_mul_X = tuple(
[
dZ[i] * X[i] if dZ[i] is not None else None
for i in range(len(X))
]
)
dY = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ_mul_X))
)
else: # rhs_target = !lhs_target
dY = gspmm_hetero(
_gidx, "mul", "sum", len(X), *tuple(X + dZ)
)
else:
if op in ["add", "copy_rhs"]:
dY = tuple(
[
dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))
]
)
else: # mul, dot
num_etype = gidx.number_of_etypes()
dY = gsddmm_hetero(
gidx, "mul", num_etype, "e", lhs_target, *tuple(dZ + X)
)
dY = tuple(
[
_reduce_grad(dY[i], Y_shape[i])
if Y[i] is not None
else None
for i in range(len(Y))
]
)
else:
dY = tuple([None] * len(Y))
return (None, None, None, None, None) + dX + dY
@@ -447,17 +680,17 @@ class EdgeSoftmax(th.autograd.Function):
# a local variable
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == "src":
gidx = gidx.reverse()
# Note: Now _edge_softmax_forward op only supports CPU
# TODO(Zhejiang): We will support GPU in the future
if score.is_cuda:
score_max = _gspmm(gidx, "copy_rhs", "max", None, score)[0]
score = th.exp(_gsddmm(gidx, "sub", score, score_max, "e", "v"))
score_sum = _gspmm(gidx, "copy_rhs", "sum", None, score)[0]
out = _gsddmm(gidx, "div", score, score_sum, "e", "v")
else:
out = _edge_softmax_forward(gidx, score, "copy_rhs")
ctx.backward_cache = gidx
ctx.save_for_backward(out)
return out
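    # Editor's note: a hedged sketch, not part of the library. The CUDA branch
    # above is the standard numerically stable softmax, applied per destination
    # node over its incoming edges. The same three steps on a dense score
    # vector for a single node would be:
    #
    #     import torch as th
    #     s = th.randn(5)              # scores of the edges into one node
    #     s = th.exp(s - s.max())      # shift by the max, then exponentiate
    #     out = s / s.sum()            # normalize: out sums to 1
    #
    # Shifting by the max avoids overflow in exp() and leaves the result
    # unchanged, since exp(a - m) / sum_j exp(b_j - m) == exp(a) / sum_j exp(b_j).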
@@ -480,14 +713,14 @@ class EdgeSoftmax(th.autograd.Function):
return grad_score.data
"""
gidx = ctx.backward_cache
(out,) = ctx.saved_tensors
sds = out * grad_out
# Note: Now _edge_softmax_backward op only supports CPU
# TODO(Zhejiang): We will support GPU in the future
if out.is_cuda:
accum = gspmm(gidx, "copy_rhs", "sum", None, sds)
grad_score = sds - gsddmm(gidx, "mul", out, accum, "e", "v")
else:
grad_score = _edge_softmax_backward(gidx, out, sds)
return None, grad_score, None, None
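    # Editor's note: a hedged sketch of the identity the backward implements.
    # With sds = out * grad_out, the code computes, per destination node,
    #     grad_score = sds - out * sum(sds)
    # which is the usual softmax Jacobian-vector product. Dense single-node check:
    #
    #     import torch as th
    #     s = th.randn(4, requires_grad=True)
    #     out = th.softmax(s, dim=0)
    #     g = th.randn(4)
    #     out.backward(g)
    #     sds = out * g
    #     assert th.allclose(s.grad, sds - out * sds.sum())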
@@ -514,18 +747,28 @@ class EdgeSoftmax_hetero(th.autograd.Function):
# a local variable
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == "src":
gidx = gidx.reverse()
u_len = gidx.number_of_ntypes()
e_len = gidx.number_of_etypes()
lhs = [None] * u_len
feats = tuple(lhs + list(score))
score_max = _gspmm_hetero(gidx, "copy_rhs", "max", u_len, feats)[0]
out_tmp = _gsddmm_hetero(
gidx, "sub", e_len, "e", "v", tuple(list(score) + list(score_max))
)
score = tuple(
[
th.exp(out_tmp[i]) if out_tmp[i] is not None else None
for i in range(len(out_tmp))
]
)
score_sum = _gspmm_hetero(
gidx, "copy_rhs", "sum", u_len, tuple(lhs + list(score))
)[0]
out = _gsddmm_hetero(
gidx, "div", e_len, "e", "v", tuple(list(score) + list(score_sum))
)
ctx.backward_cache = gidx
ctx.save_for_backward(*out)
return out
@@ -552,12 +795,14 @@ class EdgeSoftmax_hetero(th.autograd.Function):
e_len = gidx.number_of_etypes()
lhs = [None] * u_len
out = ctx.saved_tensors
sds = tuple([out[i] * grad_out[i] for i in range(len(out))])
accum = _gspmm_hetero(
gidx, "copy_rhs", "sum", u_len, tuple(lhs + list(sds))
)[0]
out_sddmm = _gsddmm_hetero(
gidx, "mul", e_len, "e", "v", tuple(list(out) + list(accum))
)
grad_score = tuple([sds[i] - out_sddmm[i] for i in range(len(sds))])
return (None, None, None) + grad_score
@@ -576,12 +821,13 @@ class SegmentReduce(th.autograd.Function):
op = ctx.backward_cache
arg, offsets = ctx.saved_tensors
m = offsets[-1].item()
if op == "sum":
offsets = offsets[1:]
# To address the issue of trailing zeros, related issue:
# https://github.com/dmlc/dgl/pull/2610
indices = th.zeros(
(m + 1,), device=offsets.device, dtype=offsets.dtype
)
indices.scatter_add_(0, offsets, th.ones_like(offsets))
indices = th.cumsum(indices, -1)[:-1]
dx = dy[indices]
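            # Editor's note: a hedged sketch of the index trick above. The
            # scatter_add_/cumsum pair builds, for every input row, the id of
            # the segment it belongs to, so the sum-gradient is just a gather
            # of dy. For offsets [0, 2, 5] (segments of lengths 2 and 3, m = 5):
            #
            #     import torch as th
            #     offsets = th.tensor([0, 2, 5])[1:]     # drop the leading zero
            #     indices = th.zeros(6, dtype=th.long)   # m + 1 slots
            #     indices.scatter_add_(0, offsets, th.ones_like(offsets))
            #     indices = th.cumsum(indices, -1)[:-1]  # -> [0, 0, 1, 1, 1]
            #
            # The extra (m + 1)-th slot absorbs the final offset, so it never
            # perturbs the first m segment ids.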
@@ -608,23 +854,50 @@ class ScatterAdd(th.autograd.Function):
class CSRMM(th.autograd.Function):
@staticmethod
def forward(ctx, gidxA, A_weights, gidxB, B_weights, num_vtypes):
gidxC, C_weights = _csrmm(
gidxA, A_weights, gidxB, B_weights, num_vtypes
)
(
nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
ctx.backward_cache = gidxA, gidxB, gidxC
ctx.save_for_backward(A_weights, B_weights)
return (
th.tensor(nrows),
th.tensor(ncols),
C_indptr,
C_indices,
C_eids,
C_weights,
)
@staticmethod
def backward(
ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful.
gidxA, gidxB, gidxC = ctx.backward_cache
A_weights, B_weights = ctx.saved_tensors
dgidxA, dA_weights = csrmm(
gidxC,
dC_weights,
gidxB.reverse(),
B_weights,
gidxA.number_of_ntypes(),
)
dgidxB, dB_weights = csrmm(
gidxA.reverse(),
A_weights,
gidxC,
dC_weights,
gidxB.number_of_ntypes(),
)
dA_weights = csrmask(dgidxA, dA_weights, gidxA)
dB_weights = csrmask(dgidxB, dB_weights, gidxB)
return None, dA_weights, None, dB_weights, None
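    # Editor's note: a hedged sketch of the gradient rule used above. For
    # C = A @ B the dense gradients are dA = dC @ B^T and dB = A^T @ dC; the
    # backward computes them with sparse products (csrmm against the reversed
    # graphs), then csrmask restricts each result to the sparsity pattern of
    # the corresponding input. Dense analogue of the masking step:
    #
    #     import torch as th
    #     A = th.tensor([[1.0, 0.0], [0.0, 2.0]])  # nonzero pattern: diagonal
    #     B, dC = th.rand(2, 2), th.ones(2, 2)
    #     dA = (dC @ B.t()) * (A != 0)             # keep only A's nonzeros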
@@ -635,18 +908,34 @@ class CSRSum(th.autograd.Function):
def forward(ctx, gidxs, *weights):
# PyTorch tensors must be explicit arguments of the forward function
gidxC, C_weights = _csrsum(gidxs, weights)
(
nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
ctx.backward_cache = gidxs, gidxC
return (
th.tensor(nrows),
th.tensor(ncols),
C_indptr,
C_indices,
C_eids,
C_weights,
)
@staticmethod
def backward(
ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful.
gidxs, gidxC = ctx.backward_cache
return (None,) + tuple(
csrmask(gidxC, dC_weights, gidx) for gidx in gidxs
)
class CSRMask(th.autograd.Function):
@@ -692,7 +981,9 @@ class GATHERMM(th.autograd.Function):
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, idx_a, idx_b):
if B.dim() != 3:
raise ValueError("Expected dimension of B is 3. Got " + str(B.dim()))
raise ValueError(
"Expected dimension of B is 3. Got " + str(B.dim())
)
N = len(idx_b) if idx_a is None else len(idx_a)
C = th.zeros((N, B.shape[2]), device=A.device, dtype=A.dtype)
C = _gather_mm(A, B, C, idx_a, idx_b)
@@ -706,103 +997,158 @@ class GATHERMM(th.autograd.Function):
if ctx.needs_input_grad[0]:
# Compute A_grad = Out_grad * B^T
A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
A_grad = _gather_mm_scatter(
dZ, B.transpose(1, 2), A_grad, idx_b=idx_b, idx_c=idx_a
)
if ctx.needs_input_grad[1]:
# Compute B_grad = A^T * Out_grad
B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
B_grad = _gather_mm_scatter(A, dZ, B_grad, idx_a=idx_a, idx_c=idx_b)
return A_grad, B_grad, None, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
if op == "sub":
op = "add"
rhs_data = -rhs_data
if op == "div":
op = "mul"
rhs_data = 1.0 / rhs_data
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
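# Editor's note: a hedged sketch, not library code. Only a small set of ops
# ('add', 'mul', 'copy_lhs', 'copy_rhs', ...) need dedicated kernels; 'sub'
# and 'div' are rewritten algebraically before dispatch, since
# x - y == x + (-y) and x / y == x * (1 / y). A quick dense check:


def _check_sub_div_rewrite():
    import torch as th

    x, y = th.randn(3), th.rand(3) + 0.1  # keep y away from zero for 'div'
    assert th.allclose(x - y, x + (-y))
    assert th.allclose(x / y, x * (1.0 / y))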
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
if op == "sub":
op = "add"
rhs_data = -rhs_data
if op == "div":
op = "mul"
rhs_data = 1.0 / rhs_data
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
lhs_tuple, rhs_tuple = (
lhs_and_rhs_tuple[:lhs_len],
lhs_and_rhs_tuple[lhs_len:],
)
if op == "sub":
op = "add"
rhs_tuple = tuple(
[
-rhs_tuple[i] if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))
]
)
if op == "div":
op = "mul"
rhs_tuple = tuple(
[
(1.0 / rhs_tuple[i]) if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))
]
)
if op in ["add", "mul"]:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
return GSpMM_hetero.apply(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
def gsddmm_hetero(
g, op, lhs_len, lhs_target="u", rhs_target="v", *lhs_and_rhs_tuple
):
lhs_tuple, rhs_tuple = (
lhs_and_rhs_tuple[:lhs_len],
lhs_and_rhs_tuple[lhs_len:],
)
if op == "sub":
op = "add"
rhs_tuple = tuple(
[
-rhs_tuple[i] if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))
]
)
if op == "div":
op = "mul"
rhs_tuple = tuple(
[
(1.0 / rhs_tuple[i]) if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))
]
)
if op in ["add", "mul"]:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
return GSDDMM_hetero.apply(
g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple
)
def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
def edge_softmax_hetero(gidx, eids=ALL, norm_by="dst", *logits):
return EdgeSoftmax_hetero.apply(gidx, eids, norm_by, *logits)
def segment_reduce(op, x, offsets):
return SegmentReduce.apply(op, x, offsets)
def scatter_add(x, idx, m):
return ScatterAdd.apply(x, idx, m)
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRMM.apply(
gidxA, A_weights, gidxB, B_weights, num_vtypes
)
gidxC = create_unitgraph_from_csr(
num_vtypes,
nrows.item(),
ncols.item(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights
def csrsum(gidxs, weights):
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRSum.apply(
gidxs, *weights
)
gidxC = create_unitgraph_from_csr(
gidxs[0].number_of_ntypes(),
nrows.item(),
ncols.item(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights
def csrmask(gidxA, A_weights, gidxB):
return CSRMask.apply(gidxA, A_weights, gidxB)
def segment_mm(A, B, seglen_A):
if A.device.type == "cpu":
C = []
off = 0
for i in range(B.shape[0]):
C.append(A[off : off + seglen_A[i]] @ B[i])
off += seglen_A[i]
return th.cat(C)
else:
return SEGMENTMM.apply(A, B, seglen_A)
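# Editor's note: a hedged usage sketch. segment_mm multiplies consecutive row
# segments of A with the matching matrix in the batch B, assuming seglen_A
# sums to A.shape[0]:


def _segment_mm_example():
    import torch as th

    A = th.randn(5, 4)              # 5 rows, split into segments of 2 and 3
    B = th.randn(2, 4, 6)           # one (4, 6) weight matrix per segment
    seglen_A = th.tensor([2, 3])
    C = segment_mm(A, B, seglen_A)  # rows 0-1 use B[0], rows 2-4 use B[1]
    assert C.shape == (5, 6)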
def gather_mm(A, B, idx_A=None, idx_B=None):
if A.device.type == "cpu":
A = A[idx_A] if idx_A is not None else A
B = B[idx_B] if idx_B is not None else B
return th.bmm(A.unsqueeze(1), B).squeeze(1)
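# Editor's note: a hedged usage sketch. gather_mm picks one matrix from the
# batch B per row of A (e.g. a per-edge-type weight); the CPU path simply
# materializes the gather and falls back to th.bmm:


def _gather_mm_example():
    import torch as th

    A = th.randn(4, 3)                # 4 rows, e.g. features of 4 edges
    B = th.randn(2, 3, 5)             # 2 candidate weight matrices
    idx_B = th.tensor([0, 1, 1, 0])   # which matrix each row multiplies
    C = gather_mm(A, B, idx_B=idx_B)  # row i -> A[i] @ B[idx_B[i]]
    assert C.shape == (4, 5)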