Unverified Commit a208e886 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
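Note: per the commit message, the diff below comes from running the Black auto-formatter over the FFI modules, so the changes are formatting-only: call sites and collection literals are rewrapped to one argument per line with a trailing comma, string literals switch to double quotes, and imports are regrouped (the import reordering looks like isort-style sorting, presumably applied alongside Black; the exact tool configuration is not recorded in this commit). A minimal, self-contained sketch of that wrapping style, using invented names that are not part of the DGL codebase:

    # Hypothetical example (names invented) of the layout Black settles on once a
    # call no longer fits on one line: one argument per line, trailing comma,
    # double-quoted strings.
    def format_context(type_name, device_type, device_id, pinned):
        return "%s(device_type=%d, device_id=%d, pinned=%s)" % (
            type_name,
            device_type,
            device_id,
            pinned,
        )

    print(format_context("DGLContext", 1, 0, False))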
@@ -5,33 +5,41 @@ from __future__ import absolute_import
import ctypes import ctypes
import traceback import traceback
from numbers import Number, Integral from numbers import Integral, Number
from ..base import _LIB, check_call from ..base import _LIB, c_str, check_call, string_types
from ..base import c_str, string_types from ..object_generic import ObjectGeneric, convert_to_object
from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType
from ..runtime_ctypes import DGLDataType, DGLByteArray, DGLContext
from . import ndarray as _nd from . import ndarray as _nd
from . import object as _object
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import DGLValue, TypeCode
from .types import DGLPackedCFunc, DGLCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .object import ObjectBase from .object import ObjectBase
from . import object as _object from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLCFuncFinalizer,
DGLPackedCFunc,
DGLValue,
TypeCode,
_wrap_arg_func,
)
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p
DGLRetValueHandle = ctypes.c_void_p DGLRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle): def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed.""" """callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object) pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj) ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive # Global callback that is always alive
DGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource) DGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ)) ctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ))
def convert_to_dgl_func(pyfunc): def convert_to_dgl_func(pyfunc):
"""Convert a python function to DGL function """Convert a python function to DGL function
@@ -46,10 +54,15 @@ def convert_to_dgl_func(pyfunc):
The converted dgl function. The converted dgl function.
""" """
local_pyfunc = pyfunc local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _): def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """ """ctypes function"""
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args num_args = (
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)) num_args.value if isinstance(num_args, ctypes.c_int) else num_args
)
pyargs = (
C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)
)
# pylint: disable=broad-except # pylint: disable=broad-except
try: try:
rv = local_pyfunc(*pyargs) rv = local_pyfunc(*pyargs)
@@ -60,12 +73,16 @@ def convert_to_dgl_func(pyfunc):
if rv is not None: if rv is not None:
if isinstance(rv, tuple): if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value") raise ValueError(
"PackedFunction can only support one return value"
)
temp_args = [] temp_args = []
values, tcodes, _ = _make_dgl_args((rv,), temp_args) values, tcodes, _ = _make_dgl_args((rv,), temp_args)
if not isinstance(ret, DGLRetValueHandle): if not isinstance(ret, DGLRetValueHandle):
ret = DGLRetValueHandle(ret) ret = DGLRetValueHandle(ret)
check_call(_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))) check_call(
_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))
)
_ = temp_args _ = temp_args
_ = rv _ = rv
return 0 return 0
@@ -76,8 +93,11 @@ def convert_to_dgl_func(pyfunc):
# DGL_FREE_PYOBJ will be called after it is no longer needed. # DGL_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f) pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj) ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.DGLFuncCreateFromCFunc( check_call(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle))) _LIB.DGLFuncCreateFromCFunc(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)
)
)
return _CLASS_FUNCTION(handle, False) return _CLASS_FUNCTION(handle, False)
@@ -104,8 +124,11 @@ def _make_dgl_args(args, temp_args):
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_CONTAINER type_codes[i] = (
if not arg.is_view else TypeCode.ARRAY_HANDLE) TypeCode.NDARRAY_CONTAINER
if not arg.is_view
else TypeCode.ARRAY_HANDLE
)
elif isinstance(arg, _nd._DGL_COMPATS): elif isinstance(arg, _nd._DGL_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._dgl_handle) values[i].v_handle = ctypes.c_void_p(arg._dgl_handle)
type_codes[i] = arg.__class__._dgl_tcode type_codes[i] = arg.__class__._dgl_tcode
@@ -125,7 +148,8 @@ def _make_dgl_args(args, temp_args):
arr = DGLByteArray() arr = DGLByteArray()
arr.data = ctypes.cast( arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg), (ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte)) ctypes.POINTER(ctypes.c_byte),
)
arr.size = len(arg) arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr) temp_args.append(arr)
@@ -134,7 +158,7 @@ def _make_dgl_args(args, temp_args):
values[i].v_str = c_str(arg) values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
# NOTE(minjie): module is not used in DGL # NOTE(minjie): module is not used in DGL
#elif isinstance(arg, _CLASS_MODULE): # elif isinstance(arg, _CLASS_MODULE):
# values[i].v_handle = arg.handle # values[i].v_handle = arg.handle
# type_codes[i] = TypeCode.MODULE_HANDLE # type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase): elif isinstance(arg, FunctionBase):
@@ -155,6 +179,7 @@ def _make_dgl_args(args, temp_args):
class FunctionBase(object): class FunctionBase(object):
"""Function base.""" """Function base."""
__slots__ = ["handle", "is_global"] __slots__ = ["handle", "is_global"]
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle, is_global): def __init__(self, handle, is_global):
@@ -185,9 +210,16 @@ class FunctionBase(object):
values, tcodes, num_args = _make_dgl_args(args, temp_args) values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue() ret_val = DGLValue()
ret_tcode = ctypes.c_int() ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall( check_call(
self.handle, values, tcodes, ctypes.c_int(num_args), _LIB.DGLFuncCall(
ctypes.byref(ret_val), ctypes.byref(ret_tcode))) self.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args _ = temp_args
_ = args _ = args
return RETURN_SWITCH[ret_tcode.value](ret_val) return RETURN_SWITCH[ret_tcode.value](ret_val)
@@ -199,9 +231,16 @@ def __init_handle_by_constructor__(fconstructor, args):
values, tcodes, num_args = _make_dgl_args(args, temp_args) values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue() ret_val = DGLValue()
ret_tcode = ctypes.c_int() ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall( check_call(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args), _LIB.DGLFuncCall(
ctypes.byref(ret_val), ctypes.byref(ret_tcode))) fconstructor.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args _ = temp_args
_ = args _ = args
assert ret_tcode.value == TypeCode.OBJECT_HANDLE assert ret_tcode.value == TypeCode.OBJECT_HANDLE
@@ -216,6 +255,7 @@ def _return_module(x):
handle = ModuleHandle(handle) handle = ModuleHandle(handle)
return _CLASS_MODULE(handle) return _CLASS_MODULE(handle)
def _handle_return_func(x): def _handle_return_func(x):
"""Return function""" """Return function"""
handle = x.v_handle handle = x.v_handle
@@ -228,22 +268,32 @@ def _handle_return_func(x):
_object.__init_by_constructor__ = __init_handle_by_constructor__ _object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE) _handle_return_func, TypeCode.FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE) _return_module, TypeCode.MODULE_HANDLE
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True) )
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(
x.v_handle, True
)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
global _CLASS_MODULE global _CLASS_MODULE
_CLASS_MODULE = module_class _CLASS_MODULE = module_class
def _set_class_function(func_class): def _set_class_function(func_class):
global _CLASS_FUNCTION global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class _CLASS_FUNCTION = func_class
@@ -3,18 +3,23 @@
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import DGLArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
from ..base import _LIB, c_str, check_call
from ..runtime_ctypes import DGLArrayHandle
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
_return_handle,
_wrap_arg_func,
)
DGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) DGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor') _c_str_dltensor = c_str("dltensor")
_c_str_used_dltensor = c_str('used_dltensor') _c_str_used_dltensor = c_str("used_dltensor")
# used for PyCapsule manipulation # used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'): if hasattr(ctypes, "pythonapi"):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
@@ -31,9 +36,13 @@ def _from_dlpack(dltensor):
handle = DGLArrayHandle() handle = DGLArrayHandle()
check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle))) check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, DGLPyCapsuleDestructor(0)) ctypes.pythonapi.PyCapsule_SetDestructor(
dltensor, DGLPyCapsuleDestructor(0)
)
return _make_array(handle, False) return _make_array(handle, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once") raise ValueError(
"Expect a dltensor field, PyCapsule can only be consumed once"
)
def _dlpack_deleter(pycapsule): def _dlpack_deleter(pycapsule):
@@ -45,13 +54,17 @@ def _dlpack_deleter(pycapsule):
# work out always. # work out always.
ptr = ctypes.cast(ptr, ctypes.c_void_p) ptr = ctypes.cast(ptr, ctypes.c_void_p)
_LIB.DGLDLManagedTensorCallDeleter(ptr) _LIB.DGLDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, DGLPyCapsuleDestructor(0)) ctypes.pythonapi.PyCapsule_SetDestructor(
pycapsule, DGLPyCapsuleDestructor(0)
)
_c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter) _c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object): class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"] __slots__ = ["handle", "is_view"]
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle, is_view=False): def __init__(self, handle, is_view=False):
@@ -89,27 +102,36 @@ class NDArrayBase(object):
dlpack : DLPack tensor view of the array data dlpack : DLPack tensor view of the array data
""" """
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_call(_LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment)) check_call(
return ctypes.pythonapi.PyCapsule_New(ptr, _c_str_dltensor, _c_dlpack_deleter) _LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment)
)
return ctypes.pythonapi.PyCapsule_New(
ptr, _c_str_dltensor, _c_dlpack_deleter
)
def _make_array(handle, is_view): def _make_array(handle, is_view):
handle = ctypes.cast(handle, DGLArrayHandle) handle = ctypes.cast(handle, DGLArrayHandle)
return _CLASS_NDARRAY(handle, is_view) return _CLASS_NDARRAY(handle, is_view)
_DGL_COMPATS = () _DGL_COMPATS = ()
def _reg_extension(cls, fcreate): def _reg_extension(cls, fcreate):
global _DGL_COMPATS global _DGL_COMPATS
_DGL_COMPATS += (cls,) _DGL_COMPATS += (cls,)
if fcreate: if fcreate:
fret = lambda x: fcreate(_return_handle(x)) fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._dgl_tcode] = fret RETURN_SWITCH[cls._dgl_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(fret, cls._dgl_tcode) C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(
fret, cls._dgl_tcode
)
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
def _set_class_ndarray(cls): def _set_class_ndarray(cls):
global _CLASS_NDARRAY global _CLASS_NDARRAY
_CLASS_NDARRAY = cls _CLASS_NDARRAY = cls
@@ -2,10 +2,16 @@
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str
from ..base import _LIB, c_str, check_call
from ..object_generic import _set_class_object_base from ..object_generic import _set_class_object_base
from .types import DGLValue, TypeCode from .types import (
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLValue,
TypeCode,
_wrap_arg_func,
)
ObjectHandle = ctypes.c_void_p ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None __init_by_constructor__ = None
@@ -13,10 +19,12 @@ __init_by_constructor__ = None
"""Maps object type to its constructor""" """Maps object type to its constructor"""
OBJECT_TYPE = {} OBJECT_TYPE = {}
def _register_object(index, cls): def _register_object(index, cls):
"""register object class in python""" """register object class in python"""
OBJECT_TYPE[index] = cls OBJECT_TYPE[index] = cls
def _return_object(x): def _return_object(x):
"""Construct a object object from the given DGLValue object""" """Construct a object object from the given DGLValue object"""
handle = x.v_handle handle = x.v_handle
@@ -34,32 +42,41 @@ def _return_object(x):
RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE) _return_object, TypeCode.OBJECT_HANDLE
)
class ObjectBase(object): class ObjectBase(object):
"""Object base class""" """Object base class"""
__slots__ = ["handle"] __slots__ = ["handle"]
# pylint: disable=no-member # pylint: disable=no-member
def __del__(self): def __del__(self):
if _LIB is not None and hasattr(self, 'handle'): if _LIB is not None and hasattr(self, "handle"):
check_call(_LIB.DGLObjectFree(self.handle)) check_call(_LIB.DGLObjectFree(self.handle))
def __getattr__(self, name): def __getattr__(self, name):
if name == 'handle': if name == "handle":
raise AttributeError("'handle' is a reserved attribute name that should not be used") raise AttributeError(
"'handle' is a reserved attribute name that should not be used"
)
ret_val = DGLValue() ret_val = DGLValue()
ret_type_code = ctypes.c_int() ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int() ret_success = ctypes.c_int()
check_call(_LIB.DGLObjectGetAttr( check_call(
self.handle, c_str(name), _LIB.DGLObjectGetAttr(
ctypes.byref(ret_val), self.handle,
ctypes.byref(ret_type_code), c_str(name),
ctypes.byref(ret_success))) ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success),
)
)
if not ret_success.value: if not ret_success.value:
raise AttributeError( raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name)) "'%s' object has no attribute '%s'" % (str(type(self)), name)
)
return RETURN_SWITCH[ret_type_code.value](ret_val) return RETURN_SWITCH[ret_type_code.value](ret_val)
def __init_handle_by_constructor__(self, fconstructor, *args): def __init_handle_by_constructor__(self, fconstructor, *args):
@@ -81,9 +98,12 @@ class ObjectBase(object):
""" """
# assign handle first to avoid error raising # assign handle first to avoid error raising
self.handle = None self.handle = None
handle = __init_by_constructor__(fconstructor, args) # pylint: disable=not-callable handle = __init_by_constructor__(
fconstructor, args
) # pylint: disable=not-callable
if not isinstance(handle, ObjectHandle): if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle) handle = ObjectHandle(handle)
self.handle = handle self.handle = handle
_set_class_object_base(ObjectBase) _set_class_object_base(ObjectBase)
@@ -3,17 +3,22 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import ctypes import ctypes
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLDataType, DGLContext from ..base import _LIB, check_call, py_str
from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType, TypeCode
class DGLValue(ctypes.Union): class DGLValue(ctypes.Union):
"""DGLValue in C API""" """DGLValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double), _fields_ = [
("v_handle", ctypes.c_void_p), ("v_int64", ctypes.c_int64),
("v_str", ctypes.c_char_p), ("v_float64", ctypes.c_double),
("v_type", DGLDataType), ("v_handle", ctypes.c_void_p),
("v_ctx", DGLContext)] ("v_str", ctypes.c_char_p),
("v_type", DGLDataType),
("v_ctx", DGLContext),
]
DGLPackedCFunc = ctypes.CFUNCTYPE( DGLPackedCFunc = ctypes.CFUNCTYPE(
@@ -22,12 +27,11 @@ DGLPackedCFunc = ctypes.CFUNCTYPE(
ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int),
ctypes.c_int, ctypes.c_int,
ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_void_p) ctypes.c_void_p,
)
DGLCFuncFinalizer = ctypes.CFUNCTYPE( DGLCFuncFinalizer = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
None,
ctypes.c_void_p)
def _return_handle(x): def _return_handle(x):
@@ -37,6 +41,7 @@ def _return_handle(x):
handle = ctypes.c_void_p(handle) handle = ctypes.c_void_p(handle)
return handle return handle
def _return_bytes(x): def _return_bytes(x):
"""return handle""" """return handle"""
handle = x.v_handle handle = x.v_handle
@@ -47,16 +52,20 @@ def _return_bytes(x):
res = bytearray(size) res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res) rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size): if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed') raise RuntimeError("memmove failed")
return res return res
def _wrap_arg_func(return_f, type_code): def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code) tcode = ctypes.c_int(type_code)
def _wrap_func(x): def _wrap_func(x):
check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode)) check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x) return return_f(x)
return _wrap_func return _wrap_func
RETURN_SWITCH = { RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64, TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64, TypeCode.FLOAT: lambda x: x.v_float64,
@@ -64,7 +73,9 @@ RETURN_SWITCH = {
TypeCode.NULL: lambda x: None, TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str), TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes, TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id), TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
} }
C_TO_PY_ARG_SWITCH = { C_TO_PY_ARG_SWITCH = {
@@ -74,5 +85,7 @@ C_TO_PY_ARG_SWITCH = {
TypeCode.NULL: lambda x: None, TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str), TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes, TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id), TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
} }
@@ -3,22 +3,24 @@
"""ctypes library and helper functions """ """ctypes library and helper functions """
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os
import ctypes import ctypes
import logging import logging
import os
import sys
import numpy as np import numpy as np
from . import libinfo from . import libinfo
#---------------------------- # ----------------------------
# library loading # library loading
#---------------------------- # ----------------------------
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
string_types = (str,) string_types = (str,)
numeric_types = (float, int, np.float32, np.int32) numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3 # this function is needed for python3
# to convert ctypes.char_p .value back to python str # to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8') py_str = lambda x: x.decode("utf-8")
else: else:
string_types = (basestring,) string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32) numeric_types = (float, int, long, np.float32, np.int32)
@@ -27,8 +29,10 @@ else:
class DGLError(Exception): class DGLError(Exception):
"""Error thrown by DGL function""" """Error thrown by DGL function"""
pass # pylint: disable=unnecessary-pass pass # pylint: disable=unnecessary-pass
def _load_lib(): def _load_lib():
"""Load libary by searching possible path.""" """Load libary by searching possible path."""
lib_path = libinfo.find_lib_path() lib_path = libinfo.find_lib_path()
@@ -39,6 +43,7 @@ def _load_lib():
lib.DGLGetLastError.restype = ctypes.c_char_p lib.DGLGetLastError.restype = ctypes.c_char_p
return lib, basename, dirname return lib, basename, dirname
# version number # version number
__version__ = libinfo.__version__ __version__ = libinfo.__version__
# library instance of nnvm # library instance of nnvm
@@ -47,9 +52,9 @@ _LIB, _LIB_NAME, _DIR_NAME = _load_lib()
# The FFI mode of DGL # The FFI mode of DGL
_FFI_MODE = os.environ.get("DGL_FFI", "auto") _FFI_MODE = os.environ.get("DGL_FFI", "auto")
#---------------------------- # ----------------------------
# helper function in ctypes. # helper function in ctypes.
#---------------------------- # ----------------------------
def check_call(ret): def check_call(ret):
"""Check the return value of C API call """Check the return value of C API call
@@ -77,7 +82,7 @@ def c_str(string):
str : c_char_p str : c_char_p
A char pointer that can be passed to C API A char pointer that can be passed to C API
""" """
return ctypes.c_char_p(string.encode('utf-8')) return ctypes.c_char_p(string.encode("utf-8"))
def c_array(ctype, values): def c_array(ctype, values):
@@ -111,10 +116,13 @@ def decorate(func, fwrapped):
The wrapped function The wrapped function
""" """
import decorator import decorator
return decorator.decorate(func, fwrapped) return decorator.decorate(func, fwrapped)
tensor_adapter_loaded = False tensor_adapter_loaded = False
def load_tensor_adapter(backend, version): def load_tensor_adapter(backend, version):
"""Tell DGL to load a tensoradapter library for given backend and version. """Tell DGL to load a tensoradapter library for given backend and version.
@@ -126,17 +134,17 @@ def load_tensor_adapter(backend, version):
The version number of the backend. The version number of the backend.
""" """
global tensor_adapter_loaded global tensor_adapter_loaded
version = version.split('+')[0] version = version.split("+")[0]
if sys.platform.startswith('linux'): if sys.platform.startswith("linux"):
basename = 'libtensoradapter_%s_%s.so' % (backend, version) basename = "libtensoradapter_%s_%s.so" % (backend, version)
elif sys.platform.startswith('darwin'): elif sys.platform.startswith("darwin"):
basename = 'libtensoradapter_%s_%s.dylib' % (backend, version) basename = "libtensoradapter_%s_%s.dylib" % (backend, version)
elif sys.platform.startswith('win'): elif sys.platform.startswith("win"):
basename = 'tensoradapter_%s_%s.dll' % (backend, version) basename = "tensoradapter_%s_%s.dll" % (backend, version)
else: else:
raise NotImplementedError('Unsupported system: %s' % sys.platform) raise NotImplementedError("Unsupported system: %s" % sys.platform)
path = os.path.join(_DIR_NAME, 'tensoradapter', backend, basename) path = os.path.join(_DIR_NAME, "tensoradapter", backend, basename)
tensor_adapter_loaded = (_LIB.DGLLoadTensorAdapter(path.encode('utf-8')) == 0) tensor_adapter_loaded = _LIB.DGLLoadTensorAdapter(path.encode("utf-8")) == 0
if not tensor_adapter_loaded: if not tensor_adapter_loaded:
logger = logging.getLogger("dgl-core") logger = logging.getLogger("dgl-core")
logger.debug("Memory optimization with PyTorch is not enabled.") logger.debug("Memory optimization with PyTorch is not enabled.")
@@ -2,9 +2,10 @@
"""Function namespace.""" """Function namespace."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import ctypes import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE import sys
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str, string_types
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -13,21 +14,31 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_dgl_func from ._cy3.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
else: else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_dgl_func from ._cy2.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_dgl_func from ._ctypes.function import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase): class Function(_FunctionBase):
"""The PackedFunc object. """The PackedFunc object.
@@ -51,11 +62,13 @@ class Function(_FunctionBase):
dgl.register_func: How to register global function. dgl.register_func: How to register global function.
dgl.get_global_func: How to get global function. dgl.get_global_func: How to get global function.
""" """
pass # pylint: disable=unnecessary-pass pass # pylint: disable=unnecessary-pass
class ModuleBase(object): class ModuleBase(object):
"""Base class for module""" """Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"] __slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle): def __init__(self, handle):
@@ -97,13 +110,16 @@ class ModuleBase(object):
The result function. The result function.
""" """
ret_handle = FunctionHandle() ret_handle = FunctionHandle()
check_call(_LIB.DGLModGetFunction( check_call(
self.handle, c_str(name), _LIB.DGLModGetFunction(
ctypes.c_int(query_imports), self.handle,
ctypes.byref(ret_handle))) c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle),
)
)
if not ret_handle.value: if not ret_handle.value:
raise AttributeError( raise AttributeError("Module has no function '%s'" % name)
"Module has no function '%s'" % name)
return Function(ret_handle, False) return Function(ret_handle, False)
def import_module(self, module): def import_module(self, module):
@@ -175,13 +191,16 @@ def register_func(func_name, f=None, override=False):
raise ValueError("expect string function name") raise ValueError("expect string function name")
ioverride = ctypes.c_int(override) ioverride = ctypes.c_int(override)
def register(myf): def register(myf):
"""internal register function""" """internal register function"""
if not isinstance(myf, Function): if not isinstance(myf, Function):
myf = convert_to_dgl_func(myf) myf = convert_to_dgl_func(myf)
check_call(_LIB.DGLFuncRegisterGlobal( check_call(
c_str(func_name), myf.handle, ioverride)) _LIB.DGLFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)
)
return myf return myf
if f: if f:
return register(f) return register(f)
return register return register
@@ -214,7 +233,6 @@ def get_global_func(name, allow_missing=False):
raise ValueError("Cannot find global function %s" % name) raise ValueError("Cannot find global function %s" % name)
def list_global_func_names(): def list_global_func_names():
"""Get list of global functions registered. """Get list of global functions registered.
@@ -226,8 +244,9 @@ def list_global_func_names():
plist = ctypes.POINTER(ctypes.c_char_p)() plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
check_call(_LIB.DGLFuncListGlobalNames(ctypes.byref(size), check_call(
ctypes.byref(plist))) _LIB.DGLFuncListGlobalNames(ctypes.byref(size), ctypes.byref(plist))
)
fnames = [] fnames = []
for i in range(size.value): for i in range(size.value):
fnames.append(py_str(plist[i])) fnames.append(py_str(plist[i]))
@@ -249,8 +268,10 @@ def extract_ext_funcs(finit):
The extracted functions The extracted functions
""" """
fdict = {} fdict = {}
def _list(name, func): def _list(name, func):
fdict[name] = func fdict[name] = func
myf = convert_to_dgl_func(_list) myf = convert_to_dgl_func(_list)
ret = finit(myf.handle) ret = finit(myf.handle)
_ = myf _ = myf
@@ -258,11 +279,13 @@ def extract_ext_funcs(finit):
raise RuntimeError("cannot initialize with %s" % finit) raise RuntimeError("cannot initialize with %s" % finit)
return fdict return fdict
def _get_api(f): def _get_api(f):
flocal = f flocal = f
flocal.is_global = True flocal.is_global = True
return flocal return flocal
def _init_api(namespace, target_module_name=None): def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name """Initialize api for a given module name
@@ -272,8 +295,7 @@ def _init_api(namespace, target_module_name=None):
target_module_name : str target_module_name : str
The target module name if different from namespace The target module name if different from namespace
""" """
target_module_name = ( target_module_name = target_module_name if target_module_name else namespace
target_module_name if target_module_name else namespace)
if namespace.startswith("dgl."): if namespace.startswith("dgl."):
_init_api_prefix(target_module_name, namespace[4:]) _init_api_prefix(target_module_name, namespace[4:])
else: else:
@@ -284,10 +306,10 @@ def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name] module = sys.modules[module_name]
for name in list_global_func_names(): for name in list_global_func_names():
if name.startswith("_") and not name.startswith('_deprecate'): if name.startswith("_") and not name.startswith("_deprecate"):
# internal APIs are ignored # internal APIs are ignored
continue continue
name_split = name.rsplit('.', 1) name_split = name.rsplit(".", 1)
if name_split[0] != prefix: if name_split[0] != prefix:
continue continue
@@ -300,12 +322,13 @@ def _init_api_prefix(module_name, prefix):
f = get_global_func(name) f = get_global_func(name)
ff = _get_api(f) ff = _get_api(f)
ff.__name__ = fname ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname) ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
def _init_internal_api(): def _init_internal_api():
for name in list_global_func_names(): for name in list_global_func_names():
if not name.startswith("_") or name.startswith('_deprecate'): if not name.startswith("_") or name.startswith("_deprecate"):
# normal APIs are ignored # normal APIs are ignored
continue continue
target_module = sys.modules["dgl._api_internal"] target_module = sys.modules["dgl._api_internal"]
@@ -316,7 +339,8 @@ def _init_internal_api():
f = get_global_func(name) f = get_global_func(name)
ff = _get_api(f) ff = _get_api(f)
ff.__name__ = fname ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname) ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
_set_class_function(Function) _set_class_function(Function)
"""Library information.""" """Library information."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os import os
import pathlib import pathlib
import sys
def find_lib_path(name=None, search_path=None, optional=False): def find_lib_path(name=None, search_path=None, optional=False):
@@ -30,13 +31,21 @@ def find_lib_path(name=None, search_path=None, optional=False):
dll_path = [] dll_path = []
if os.environ.get('DGL_LIBRARY_PATH', None): if os.environ.get("DGL_LIBRARY_PATH", None):
dll_path.append(os.environ['DGL_LIBRARY_PATH']) dll_path.append(os.environ["DGL_LIBRARY_PATH"])
if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None): if sys.platform.startswith("linux") and os.environ.get(
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) "LD_LIBRARY_PATH", None
elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None): ):
dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")]) dll_path.extend(
[p.strip() for p in os.environ["LD_LIBRARY_PATH"].split(":")]
)
elif sys.platform.startswith("darwin") and os.environ.get(
"DYLD_LIBRARY_PATH", None
):
dll_path.extend(
[p.strip() for p in os.environ["DYLD_LIBRARY_PATH"].split(":")]
)
# Pip lib directory # Pip lib directory
dll_path.append(os.path.join(ffi_dir, "..")) dll_path.append(os.path.join(ffi_dir, ".."))
@@ -54,17 +63,21 @@ def find_lib_path(name=None, search_path=None, optional=False):
elif isinstance(search_path, str): elif isinstance(search_path, str):
dll_path.append(search_path) dll_path.append(search_path)
else: else:
raise ValueError("type(search_path)={} is invalid".format(type(search_path))) raise ValueError(
dll_path = [str(x.absolute()) if isinstance(x, pathlib.Path) "type(search_path)={} is invalid".format(type(search_path))
else os.path.abspath(x) for x in dll_path] )
dll_path = [
str(x.absolute()) if isinstance(x, pathlib.Path) else os.path.abspath(x)
for x in dll_path
]
if name is None: if name is None:
if sys.platform.startswith('win32'): if sys.platform.startswith("win32"):
name = ['libdgl.dll', 'dgl.dll'] name = ["libdgl.dll", "dgl.dll"]
elif sys.platform.startswith('darwin'): elif sys.platform.startswith("darwin"):
name = 'libdgl.dylib' name = "libdgl.dylib"
else: else:
name = 'libdgl.so' name = "libdgl.so"
if isinstance(name, str): if isinstance(name, str):
name = [name] name = [name]
@@ -76,9 +89,11 @@ def find_lib_path(name=None, search_path=None, optional=False):
lib_found = [p for p in lib_dll_path if os.path.isfile(p)] lib_found = [p for p in lib_dll_path if os.path.isfile(p)]
if not lib_found: if not lib_found:
message = ('Cannot find the files.\n' + message = (
'List of candidates:\n' + "Cannot find the files.\n"
str('\n'.join(lib_dll_path))) + "List of candidates:\n"
+ str("\n".join(lib_dll_path))
)
if not optional: if not optional:
raise RuntimeError(message) raise RuntimeError(message)
return None return None
...
@@ -2,13 +2,20 @@
"""Runtime NDArray api""" """Runtime NDArray api"""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import ctypes import ctypes
import sys
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, dgl_shape_index_t
from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
from .runtime_ctypes import (
DGLArray,
DGLArrayHandle,
DGLContext,
DGLDataType,
TypeCode,
dgl_shape_index_t,
)
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -17,15 +24,31 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
else: else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
"""Construct a DGL context with given device type and id. """Construct a DGL context with given device type and id.
@@ -63,10 +86,9 @@ def context(dev_type, dev_id=0):
def numpyasarray(np_data): def numpyasarray(np_data):
"""Return a DGLArray representation of a numpy array. """Return a DGLArray representation of a numpy array."""
"""
data = np_data data = np_data
assert data.flags['C_CONTIGUOUS'] assert data.flags["C_CONTIGUOUS"]
arr = DGLArray() arr = DGLArray()
shape = c_array(dgl_shape_index_t, data.shape) shape = c_array(dgl_shape_index_t, data.shape)
arr.data = data.ctypes.data_as(ctypes.c_void_p) arr.data = data.ctypes.data_as(ctypes.c_void_p)
@@ -102,14 +124,18 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ndim = ctypes.c_int(len(shape)) ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle() handle = DGLArrayHandle()
dtype = DGLDataType(dtype) dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAlloc( check_call(
shape, ndim, _LIB.DGLArrayAlloc(
ctypes.c_int(dtype.type_code), shape,
ctypes.c_int(dtype.bits), ndim,
ctypes.c_int(dtype.lanes), ctypes.c_int(dtype.type_code),
ctx.device_type, ctypes.c_int(dtype.bits),
ctx.device_id, ctypes.c_int(dtype.lanes),
ctypes.byref(handle))) ctx.device_type,
ctx.device_id,
ctypes.byref(handle),
)
)
return _make_array(handle, False) return _make_array(handle, False)
@@ -135,18 +161,23 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
arr : dgl.nd.NDArray arr : dgl.nd.NDArray
The array dgl supported. The array dgl supported.
""" """
name = ctypes.c_char_p(name.encode('utf-8')) name = ctypes.c_char_p(name.encode("utf-8"))
shape = c_array(dgl_shape_index_t, shape) shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape)) ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle() handle = DGLArrayHandle()
dtype = DGLDataType(dtype) dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAllocSharedMem( check_call(
name, shape, ndim, _LIB.DGLArrayAllocSharedMem(
ctypes.c_int(dtype.type_code), name,
ctypes.c_int(dtype.bits), shape,
ctypes.c_int(dtype.lanes), ndim,
is_create, ctypes.c_int(dtype.type_code),
ctypes.byref(handle))) ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
is_create,
ctypes.byref(handle),
)
)
return _make_array(handle, False) return _make_array(handle, False)
@@ -171,10 +202,14 @@ def from_dlpack(dltensor):
class NDArrayBase(_NDArrayBase): class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
@property @property
def shape(self): def shape(self):
"""Shape of this array""" """Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim)) return tuple(
self.handle.contents.shape[i]
for i in range(self.handle.contents.ndim)
)
@property @property
def dtype(self): def dtype(self):
@@ -219,17 +254,19 @@ class NDArrayBase(_NDArrayBase):
def __setitem__(self, in_slice, value): def __setitem__(self, in_slice, value):
"""Set ndarray value""" """Set ndarray value"""
if (not isinstance(in_slice, slice) or if (
in_slice.start is not None not isinstance(in_slice, slice)
or in_slice.stop is not None): or in_slice.start is not None
raise ValueError('Array only support set from numpy array') or in_slice.stop is not None
):
raise ValueError("Array only support set from numpy array")
if isinstance(value, NDArrayBase): if isinstance(value, NDArrayBase):
if value.handle is not self.handle: if value.handle is not self.handle:
value.copyto(self) value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)): elif isinstance(value, (np.ndarray, np.generic)):
self.copyfrom(value) self.copyfrom(value)
else: else:
raise TypeError('type %s not supported' % str(type(value))) raise TypeError("type %s not supported" % str(type(value)))
def copyfrom(self, source_array): def copyfrom(self, source_array):
"""Perform a synchronized copy from the array. """Perform a synchronized copy from the array.
@@ -252,8 +289,10 @@ class NDArrayBase(_NDArrayBase):
try: try:
source_array = np.asarray(source_array, dtype=self.dtype) source_array = np.asarray(source_array, dtype=self.dtype)
except: except:
raise TypeError('array must be an array_like data,' + raise TypeError(
'type %s is not supported' % str(type(source_array))) "array must be an array_like data,"
+ "type %s is not supported" % str(type(source_array))
)
t = DGLDataType(self.dtype) t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype shape, dtype = self.shape, self.dtype
if t.lanes > 1: if t.lanes > 1:
@@ -262,12 +301,17 @@ class NDArrayBase(_NDArrayBase):
dtype = str(t) dtype = str(t)
if source_array.shape != shape: if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format( raise ValueError(
source_array.shape, shape)) "array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape
)
)
source_array = np.ascontiguousarray(source_array, dtype=dtype) source_array = np.ascontiguousarray(source_array, dtype=dtype)
assert source_array.flags['C_CONTIGUOUS'] assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p) data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) nbytes = ctypes.c_size_t(
source_array.size * source_array.dtype.itemsize
)
check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes)) check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes))
return self return self
@@ -293,7 +337,7 @@ class NDArrayBase(_NDArrayBase):
t.lanes = 1 t.lanes = 1
dtype = str(t) dtype = str(t)
np_arr = np.empty(shape, dtype=dtype) np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS'] assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p) data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes)) check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes))
@@ -310,20 +354,17 @@ class NDArrayBase(_NDArrayBase):
if isinstance(target, DGLContext): if isinstance(target, DGLContext):
target = empty(self.shape, self.dtype, target) target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase): if isinstance(target, NDArrayBase):
check_call(_LIB.DGLArrayCopyFromTo( check_call(_LIB.DGLArrayCopyFromTo(self.handle, target.handle))
self.handle, target.handle))
else: else:
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def pin_memory_(self): def pin_memory_(self):
"""Pin host memory and map into GPU address space (in-place) """Pin host memory and map into GPU address space (in-place)"""
"""
check_call(_LIB.DGLArrayPinData(self.handle)) check_call(_LIB.DGLArrayPinData(self.handle))
def unpin_memory_(self): def unpin_memory_(self):
"""Unpin host memory pinned by pin_memory_() """Unpin host memory pinned by pin_memory_()"""
"""
check_call(_LIB.DGLArrayUnpinData(self.handle)) check_call(_LIB.DGLArrayUnpinData(self.handle))
def record_stream(self, stream): def record_stream(self, stream):
@@ -340,6 +381,7 @@ class NDArrayBase(_NDArrayBase):
""" """
check_call(_LIB.DGLArrayRecordStream(self.handle, stream)) check_call(_LIB.DGLArrayRecordStream(self.handle, stream))
def free_extension_handle(handle, type_code): def free_extension_handle(handle, type_code):
"""Free c++ extension type handle """Free c++ extension type handle
@@ -353,6 +395,7 @@ def free_extension_handle(handle, type_code):
""" """
check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code))) check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None): def register_extension(cls, fcreate=None):
"""Register a extension class to DGL. """Register a extension class to DGL.
@@ -398,6 +441,8 @@ def register_extension(cls, fcreate=None):
return self.handle.value return self.handle.value
""" """
if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN: if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin") raise ValueError(
"Cannot register create when extension tcode is same as buildin"
)
_reg_extension(cls, fcreate) _reg_extension(cls, fcreate)
return cls return cls
@@ -4,23 +4,27 @@ from __future__ import absolute_import
import ctypes import ctypes
import sys import sys
from .. import _api_internal from .. import _api_internal
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
from .object_generic import ObjectGeneric, convert_to_object from .object_generic import ObjectGeneric, convert_to_object
from .base import _LIB, check_call, c_str, py_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" \ # pylint: disable=invalid-name
else ImportError # pylint: disable=invalid-name IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try: try:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _register_object, ObjectBase as _ObjectBase from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else: else:
from ._cy2.core import _register_object, ObjectBase as _ObjectBase from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.object import _register_object, ObjectBase as _ObjectBase from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
def _new_object(cls): def _new_object(cls):
@@ -36,11 +40,15 @@ class ObjectBase(_ObjectBase):
Note that the same handle **CANNOT** be shared across multiple ObjectBase instances. Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.
""" """
def __dir__(self): def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)() plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
check_call(_LIB.DGLObjectListAttrNames( check_call(
self.handle, ctypes.byref(size), ctypes.byref(plist))) _LIB.DGLObjectListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)
)
)
names = [] names = []
for i in range(size.value): for i in range(size.value):
names.append(py_str(plist[i])) names.append(py_str(plist[i]))
@@ -57,7 +65,7 @@ class ObjectBase(_ObjectBase):
def __reduce__(self): def __reduce__(self):
cls = type(self) cls = type(self)
return (_new_object, (cls, ), self.__getstate__()) return (_new_object, (cls,), self.__getstate__())
def __getstate__(self): def __getstate__(self):
# TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized # TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized
@@ -100,7 +108,9 @@ def register_object(type_key=None):
def register(cls): def register(cls):
"""internal register function""" """internal register function"""
tindex = ctypes.c_int() tindex = ctypes.c_int()
ret = _LIB.DGLObjectTypeKey2Index(c_str(object_name), ctypes.byref(tindex)) ret = _LIB.DGLObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tindex)
)
if ret == 0: if ret == 0:
_register_object(tindex.value, cls) _register_object(tindex.value, cls)
return cls return cls
......
@@ -2,23 +2,28 @@
# pylint: disable=unused-import # pylint: disable=unused-import
from __future__ import absolute_import from __future__ import absolute_import
from numbers import Number, Integral from numbers import Integral, Number
from .. import _api_internal from .. import _api_internal
from .base import string_types from .base import string_types
# Object base class # Object base class
_CLASS_OBJECT_BASE = None _CLASS_OBJECT_BASE = None
def _set_class_object_base(cls): def _set_class_object_base(cls):
global _CLASS_OBJECT_BASE global _CLASS_OBJECT_BASE
_CLASS_OBJECT_BASE = cls _CLASS_OBJECT_BASE = cls
class ObjectGeneric(object): class ObjectGeneric(object):
"""Base class for all classes that can be converted to object.""" """Base class for all classes that can be converted to object."""
def asobject(self): def asobject(self):
"""Convert value to object""" """Convert value to object"""
raise NotImplementedError() raise NotImplementedError()
def convert_to_object(value): def convert_to_object(value):
"""Convert a python value to corresponding object type. """Convert a python value to corresponding object type.
@@ -40,9 +45,12 @@ def convert_to_object(value):
if isinstance(value, dict): if isinstance(value, dict):
vlist = [] vlist = []
for item in value.items(): for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECT_BASE) and if not isinstance(item[0], _CLASS_OBJECT_BASE) and not isinstance(
not isinstance(item[0], string_types)): item[0], string_types
raise ValueError("key of map must already been a container type") ):
raise ValueError(
"key of map must already been a container type"
)
vlist.append(item[0]) vlist.append(item[0])
vlist.append(convert_to_object(item[1])) vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist) return _api_internal._Map(*vlist)
......
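The dict branch above flattens key/value pairs into one alternating argument list before handing it to the C-side map constructor. A minimal pure-Python sketch of just that flattening step, for illustration:

def flatten_dict_for_map(value, convert=lambda v: v):
    # Mirror the loop above: keys pass through unchanged, values are
    # converted, and both are appended in alternation.
    vlist = []
    for key, val in value.items():
        vlist.append(key)
        vlist.append(convert(val))
    return vlist

# flatten_dict_for_map({"a": 1, "b": 2}) == ["a", 1, "b", 2]
# The real code then splats this list as _api_internal._Map(*vlist).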
...@@ -4,14 +4,18 @@ from __future__ import absolute_import ...@@ -4,14 +4,18 @@ from __future__ import absolute_import
import ctypes import ctypes
import json import json
import numpy as np import numpy as np
from .base import _LIB, check_call
from .. import _api_internal from .. import _api_internal
from .base import _LIB, check_call
dgl_shape_index_t = ctypes.c_int64 dgl_shape_index_t = ctypes.c_int64
class TypeCode(object): class TypeCode(object):
"""Type code used in API calls""" """Type code used in API calls"""
INT = 0 INT = 0
UINT = 1 UINT = 1
FLOAT = 2 FLOAT = 2
...@@ -28,22 +32,25 @@ class TypeCode(object): ...@@ -28,22 +32,25 @@ class TypeCode(object):
NDARRAY_CONTAINER = 13 NDARRAY_CONTAINER = 13
EXT_BEGIN = 15 EXT_BEGIN = 15
class DGLByteArray(ctypes.Structure): class DGLByteArray(ctypes.Structure):
"""Temp data structure for byte array.""" """Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)] _fields_ = [
("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t),
]
class DGLDataType(ctypes.Structure): class DGLDataType(ctypes.Structure):
"""DGL datatype structure""" """DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8), _fields_ = [
("lanes", ctypes.c_uint16)] ("type_code", ctypes.c_uint8),
CODE2STR = { ("bits", ctypes.c_uint8),
0 : 'int', ("lanes", ctypes.c_uint16),
1 : 'uint', ]
2 : 'float', CODE2STR = {0: "int", 1: "uint", 2: "float", 4: "handle"}
4 : 'handle'
}
_cache = {} _cache = {}
def __new__(cls, type_str): def __new__(cls, type_str):
...@@ -90,50 +97,54 @@ class DGLDataType(ctypes.Structure): ...@@ -90,50 +97,54 @@ class DGLDataType(ctypes.Structure):
return x return x
def __eq__(self, other): def __eq__(self, other):
return (self.bits == other.bits and return (
self.type_code == other.type_code and self.bits == other.bits
self.lanes == other.lanes) and self.type_code == other.type_code
and self.lanes == other.lanes
)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
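The __new__ that parses type strings is collapsed in this hunk; only the CODE2STR table is visible. A hedged sketch of the usual <name><bits>[x<lanes>] convention it implements, purely for illustration (the default bit width below is an assumption, not taken from the source):

def parse_dtype_str(type_str):
    # Invert the CODE2STR table shown above.
    str2code = {"int": 0, "uint": 1, "float": 2, "handle": 4}
    head, _, lanes = type_str.partition("x")
    lanes = int(lanes) if lanes else 1
    for name, code in str2code.items():
        if head.startswith(name):
            bits = int(head[len(name):]) if head[len(name):] else 32  # assumed default
            return code, bits, lanes
    raise ValueError("unknown dtype string: %s" % type_str)

# parse_dtype_str("float32") -> (2, 32, 1)
# parse_dtype_str("int64x2") -> (0, 64, 2)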
RPC_SESS_MASK = 128 RPC_SESS_MASK = 128
class DGLContext(ctypes.Structure): class DGLContext(ctypes.Structure):
"""DGL context strucure.""" """DGL context strucure."""
_fields_ = [("device_type", ctypes.c_int),
("device_id", ctypes.c_int)] _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
MASK2STR = { MASK2STR = {
1 : 'cpu', 1: "cpu",
2 : 'gpu', 2: "gpu",
4 : 'opencl', 4: "opencl",
5 : 'aocl', 5: "aocl",
6 : 'sdaccel', 6: "sdaccel",
7 : 'vulkan', 7: "vulkan",
8 : 'metal', 8: "metal",
9 : 'vpi', 9: "vpi",
10: 'rocm', 10: "rocm",
11: 'opengl', 11: "opengl",
12: 'ext_dev', 12: "ext_dev",
} }
STR2MASK = { STR2MASK = {
'llvm': 1, "llvm": 1,
'stackvm': 1, "stackvm": 1,
'cpu': 1, "cpu": 1,
'gpu': 2, "gpu": 2,
'cuda': 2, "cuda": 2,
'nvptx': 2, "nvptx": 2,
'cl': 4, "cl": 4,
'opencl': 4, "opencl": 4,
'aocl' : 5, "aocl": 5,
'aocl_sw_emu' : 5, "aocl_sw_emu": 5,
'sdaccel': 6, "sdaccel": 6,
'vulkan': 7, "vulkan": 7,
'metal': 8, "metal": 8,
'vpi': 9, "vpi": 9,
'rocm': 10, "rocm": 10,
'opengl': 11, "opengl": 11,
'ext_dev': 12, "ext_dev": 12,
} }
_cache = {} _cache = {}
...@@ -155,26 +166,25 @@ class DGLContext(ctypes.Structure): ...@@ -155,26 +166,25 @@ class DGLContext(ctypes.Structure):
@property @property
def exist(self): def exist(self):
"""Whether this device exist.""" """Whether this device exist."""
return _api_internal._GetDeviceAttr( return (
self.device_type, self.device_id, 0) != 0 _api_internal._GetDeviceAttr(self.device_type, self.device_id, 0)
!= 0
)
@property @property
def max_threads_per_block(self): def max_threads_per_block(self):
"""Maximum number of threads on each block.""" """Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 1)
self.device_type, self.device_id, 1)
@property @property
def warp_size(self): def warp_size(self):
"""Number of threads that executes in concurrent.""" """Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 2)
self.device_type, self.device_id, 2)
@property @property
def max_shared_memory_per_block(self): def max_shared_memory_per_block(self):
"""Total amount of shared memory per block in bytes.""" """Total amount of shared memory per block in bytes."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 3)
self.device_type, self.device_id, 3)
@property @property
def compute_version(self): def compute_version(self):
...@@ -187,26 +197,22 @@ class DGLContext(ctypes.Structure): ...@@ -187,26 +197,22 @@ class DGLContext(ctypes.Structure):
version : str version : str
The version string in `major.minor` format. The version string in `major.minor` format.
""" """
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 4)
self.device_type, self.device_id, 4)
@property @property
def device_name(self): def device_name(self):
"""Return the string name of device.""" """Return the string name of device."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 5)
self.device_type, self.device_id, 5)
@property @property
def max_clock_rate(self): def max_clock_rate(self):
"""Return the max clock frequency of device.""" """Return the max clock frequency of device."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 6)
self.device_type, self.device_id, 6)
@property @property
def multi_processor_count(self): def multi_processor_count(self):
"""Return the number of compute units of device.""" """Return the number of compute units of device."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 7)
self.device_type, self.device_id, 7)
@property @property
def max_thread_dimensions(self): def max_thread_dimensions(self):
...@@ -217,17 +223,20 @@ class DGLContext(ctypes.Structure): ...@@ -217,17 +223,20 @@ class DGLContext(ctypes.Structure):
dims: List of int dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
""" """
return json.loads(_api_internal._GetDeviceAttr( return json.loads(
self.device_type, self.device_id, 8)) _api_internal._GetDeviceAttr(self.device_type, self.device_id, 8)
)
def sync(self): def sync(self):
"""Synchronize until jobs finished at the context.""" """Synchronize until jobs finished at the context."""
check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None)) check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None))
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, DGLContext) and return (
self.device_id == other.device_id and isinstance(other, DGLContext)
self.device_type == other.device_type) and self.device_id == other.device_id
and self.device_type == other.device_type
)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
...@@ -237,9 +246,14 @@ class DGLContext(ctypes.Structure): ...@@ -237,9 +246,14 @@ class DGLContext(ctypes.Structure):
tbl_id = self.device_type / RPC_SESS_MASK - 1 tbl_id = self.device_type / RPC_SESS_MASK - 1
dev_type = self.device_type % RPC_SESS_MASK dev_type = self.device_type % RPC_SESS_MASK
return "remote[%d]:%s(%d)" % ( return "remote[%d]:%s(%d)" % (
tbl_id, DGLContext.MASK2STR[dev_type], self.device_id) tbl_id,
DGLContext.MASK2STR[dev_type],
self.device_id,
)
return "%s(%d)" % ( return "%s(%d)" % (
DGLContext.MASK2STR[self.device_type], self.device_id) DGLContext.MASK2STR[self.device_type],
self.device_id,
)
def __hash__(self): def __hash__(self):
return hash((self.device_type, self.device_id)) return hash((self.device_type, self.device_id))
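Because DGLContext is a plain ctypes.Structure, it can be built positionally from the two fields declared above, and equality and hashing use (device_type, device_id). A short illustrative sketch:

cpu0 = DGLContext(1, 0)   # device_type=1 -> "cpu", device_id=0
gpu1 = DGLContext(2, 1)   # device_type=2 -> "gpu", device_id=1
assert cpu0 != gpu1 and gpu1 == DGLContext(2, 1)
print(DGLContext.MASK2STR[gpu1.device_type], gpu1.device_id)   # gpu 1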
...@@ -247,13 +261,17 @@ class DGLContext(ctypes.Structure): ...@@ -247,13 +261,17 @@ class DGLContext(ctypes.Structure):
class DGLArray(ctypes.Structure): class DGLArray(ctypes.Structure):
"""DGLValue in C API""" """DGLValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("ctx", DGLContext), _fields_ = [
("ndim", ctypes.c_int), ("data", ctypes.c_void_p),
("dtype", DGLDataType), ("ctx", DGLContext),
("shape", ctypes.POINTER(dgl_shape_index_t)), ("ndim", ctypes.c_int),
("strides", ctypes.POINTER(dgl_shape_index_t)), ("dtype", DGLDataType),
("byte_offset", ctypes.c_uint64)] ("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64),
]
DGLArrayHandle = ctypes.POINTER(DGLArray) DGLArrayHandle = ctypes.POINTER(DGLArray)
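DGLArrayHandle is a pointer to the structure above, so fields of an array returned by the C API can be read directly through ctypes. A minimal sketch (a hypothetical helper, not part of the module):

def dgl_array_shape(handle):
    # handle: DGLArrayHandle, i.e. ctypes.POINTER(DGLArray)
    arr = handle.contents
    return tuple(arr.shape[i] for i in range(arr.ndim))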
......
...@@ -5,13 +5,13 @@ For applications, please use PyTorch's stream management, of which DGL is aware. ...@@ -5,13 +5,13 @@ For applications, please use PyTorch's stream management, of which DGL is aware.
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
from .base import _LIB, check_call, _FFI_MODE
from .base import _FFI_MODE, _LIB, check_call
from .runtime_ctypes import DGLStreamHandle from .runtime_ctypes import DGLStreamHandle
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
def to_dgl_stream_handle(cuda_stream): def to_dgl_stream_handle(cuda_stream):
""" Convert torch.cuda.Stream to DGL stream handle """Convert torch.cuda.Stream to DGL stream handle
Parameters Parameters
---------- ----------
...@@ -24,6 +24,7 @@ def to_dgl_stream_handle(cuda_stream): ...@@ -24,6 +24,7 @@ def to_dgl_stream_handle(cuda_stream):
""" """
return ctypes.c_void_p(cuda_stream.cuda_stream) return ctypes.c_void_p(cuda_stream.cuda_stream)
def _dgl_get_stream(ctx): def _dgl_get_stream(ctx):
"""Get the current CUDA stream of the given DGL context. """Get the current CUDA stream of the given DGL context.
...@@ -37,6 +38,9 @@ def _dgl_get_stream(ctx): ...@@ -37,6 +38,9 @@ def _dgl_get_stream(ctx):
DGLStreamHandle of the current CUDA stream. DGLStreamHandle of the current CUDA stream.
""" """
current_cuda_stream = DGLStreamHandle() current_cuda_stream = DGLStreamHandle()
check_call(_LIB.DGLGetStream( check_call(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream))) _LIB.DGLGetStream(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)
)
)
return current_cuda_stream return current_cuda_stream
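The handle DGL expects is just the raw CUDA stream pointer wrapped in ctypes.c_void_p, so the round trip with PyTorch is direct. A hedged usage sketch (assumes a CUDA-enabled PyTorch build):

import torch

s = torch.cuda.Stream()
handle = to_dgl_stream_handle(s)   # ctypes.c_void_p(s.cuda_stream)
# _dgl_get_stream(ctx) goes the other way: it asks the DGL runtime for the
# stream currently bound to the device described by the given DGL context.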
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os
import json
import importlib import importlib
import json
import logging import logging
import os
import sys
from . import backend from . import backend
from .set_default_backend import set_default_backend from .set_default_backend import set_default_backend
...@@ -13,13 +13,18 @@ _enabled_apis = set() ...@@ -13,13 +13,18 @@ _enabled_apis = set()
logger = logging.getLogger("dgl-core") logger = logging.getLogger("dgl-core")
def _gen_missing_api(api, mod_name): def _gen_missing_api(api, mod_name):
def _missing_api(*args, **kwargs): def _missing_api(*args, **kwargs):
raise ImportError('API "%s" is not supported by backend "%s".' raise ImportError(
' You can switch to other backends by setting' 'API "%s" is not supported by backend "%s".'
' the DGLBACKEND environment.' % (api, mod_name)) " You can switch to other backends by setting"
" the DGLBACKEND environment." % (api, mod_name)
)
return _missing_api return _missing_api
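The placeholder returned by _gen_missing_api defers the failure until the API is actually called, so importing a backend that lacks some functions still succeeds. For example:

mean = _gen_missing_api("mean", "tensorflow")
mean([1.0, 2.0], 0)
# ImportError: API "mean" is not supported by backend "tensorflow". You can
# switch to other backends by setting the DGLBACKEND environment.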
def load_backend(mod_name): def load_backend(mod_name):
# Load backend does four things: # Load backend does four things:
# (1) Import backend framework (PyTorch, MXNet, Tensorflow, etc.) # (1) Import backend framework (PyTorch, MXNet, Tensorflow, etc.)
...@@ -28,40 +33,46 @@ def load_backend(mod_name): ...@@ -28,40 +33,46 @@ def load_backend(mod_name):
# (3) Sets up the tensoradapter library path. # (3) Sets up the tensoradapter library path.
# (4) Import the Python wrappers of the backend framework. DGL does this last because # (4) Import the Python wrappers of the backend framework. DGL does this last because
# it already depends on both the backend framework and the DGL C library. # it already depends on both the backend framework and the DGL C library.
if mod_name == 'pytorch': if mod_name == "pytorch":
import torch import torch
mod = torch mod = torch
elif mod_name == 'mxnet': elif mod_name == "mxnet":
import mxnet import mxnet
mod = mxnet mod = mxnet
elif mod_name == 'tensorflow': elif mod_name == "tensorflow":
import tensorflow import tensorflow
mod = tensorflow mod = tensorflow
else: else:
raise NotImplementedError('Unsupported backend: %s' % mod_name) raise NotImplementedError("Unsupported backend: %s" % mod_name)
from .._ffi.base import load_tensor_adapter # imports DGL C library
from .._ffi.base import load_tensor_adapter # imports DGL C library
version = mod.__version__ version = mod.__version__
load_tensor_adapter(mod_name, version) load_tensor_adapter(mod_name, version)
logger.debug('Using backend: %s' % mod_name) logger.debug("Using backend: %s" % mod_name)
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__] thismod = sys.modules[__name__]
for api in backend.__dict__.keys(): for api in backend.__dict__.keys():
if api.startswith('__'): if api.startswith("__"):
# ignore python builtin attributes # ignore python builtin attributes
continue continue
if api == 'data_type_dict': if api == "data_type_dict":
# load data type # load data type
if api not in mod.__dict__: if api not in mod.__dict__:
raise ImportError('API "data_type_dict" is required but missing for' raise ImportError(
' backend "%s".' % (mod_name)) 'API "data_type_dict" is required but missing for'
' backend "%s".' % (mod_name)
)
data_type_dict = mod.__dict__[api]() data_type_dict = mod.__dict__[api]()
for name, dtype in data_type_dict.items(): for name, dtype in data_type_dict.items():
setattr(thismod, name, dtype) setattr(thismod, name, dtype)
# override data type dict function # override data type dict function
setattr(thismod, 'data_type_dict', data_type_dict) setattr(thismod, "data_type_dict", data_type_dict)
# for data types with aliases, treat the first listed type as # for data types with aliases, treat the first listed type as
# the true one # the true one
...@@ -69,11 +80,9 @@ def load_backend(mod_name): ...@@ -69,11 +80,9 @@ def load_backend(mod_name):
for k, v in data_type_dict.items(): for k, v in data_type_dict.items():
if not v in rev_data_type_dict.keys(): if not v in rev_data_type_dict.keys():
rev_data_type_dict[v] = k rev_data_type_dict[v] = k
setattr(thismod, setattr(thismod, "reverse_data_type_dict", rev_data_type_dict)
'reverse_data_type_dict',
rev_data_type_dict)
# log backend name # log backend name
setattr(thismod, 'backend_name', mod_name) setattr(thismod, "backend_name", mod_name)
else: else:
# load functions # load functions
if api in mod.__dict__: if api in mod.__dict__:
...@@ -82,28 +91,32 @@ def load_backend(mod_name): ...@@ -82,28 +91,32 @@ def load_backend(mod_name):
else: else:
setattr(thismod, api, _gen_missing_api(api, mod_name)) setattr(thismod, api, _gen_missing_api(api, mod_name))
def get_preferred_backend(): def get_preferred_backend():
default_dir = None default_dir = None
if "DGLDEFAULTDIR" in os.environ: if "DGLDEFAULTDIR" in os.environ:
default_dir = os.getenv('DGLDEFAULTDIR') default_dir = os.getenv("DGLDEFAULTDIR")
else: else:
default_dir = os.path.join(os.path.expanduser('~'), '.dgl') default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
config_path = os.path.join(default_dir, 'config.json') config_path = os.path.join(default_dir, "config.json")
backend_name = None backend_name = None
if "DGLBACKEND" in os.environ: if "DGLBACKEND" in os.environ:
backend_name = os.getenv('DGLBACKEND') backend_name = os.getenv("DGLBACKEND")
elif os.path.exists(config_path): elif os.path.exists(config_path):
with open(config_path, "r") as config_file: with open(config_path, "r") as config_file:
config_dict = json.load(config_file) config_dict = json.load(config_file)
backend_name = config_dict.get('backend', '').lower() backend_name = config_dict.get("backend", "").lower()
if (backend_name in ['tensorflow', 'mxnet', 'pytorch']): if backend_name in ["tensorflow", "mxnet", "pytorch"]:
return backend_name return backend_name
else: else:
print("DGL backend not selected or invalid. " print(
"Assuming PyTorch for now.", file=sys.stderr) "DGL backend not selected or invalid. "
set_default_backend(default_dir, 'pytorch') "Assuming PyTorch for now.",
return 'pytorch' file=sys.stderr,
)
set_default_backend(default_dir, "pytorch")
return "pytorch"
load_backend(get_preferred_backend()) load_backend(get_preferred_backend())
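get_preferred_backend consults the DGLBACKEND environment variable first, then the JSON config file under DGLDEFAULTDIR (defaulting to ~/.dgl/config.json), and falls back to PyTorch. Two ways to pin the backend, sketched from that logic:

# Option 1: environment variable (checked first)
#   DGLBACKEND=pytorch python train.py
# Option 2: config file, read when the variable is absent
import json, os
cfg_dir = os.environ.get("DGLDEFAULTDIR", os.path.join(os.path.expanduser("~"), ".dgl"))
os.makedirs(cfg_dir, exist_ok=True)
with open(os.path.join(cfg_dir, "config.json"), "w") as f:
    json.dump({"backend": "pytorch"}, f)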
...@@ -124,8 +137,10 @@ def is_enabled(api): ...@@ -124,8 +137,10 @@ def is_enabled(api):
""" """
return api in _enabled_apis return api in _enabled_apis
def to_dgl_nd(data): def to_dgl_nd(data):
return zerocopy_to_dgl_ndarray(data) return zerocopy_to_dgl_ndarray(data)
def from_dgl_nd(data): def from_dgl_nd(data):
return zerocopy_from_dgl_ndarray(data) return zerocopy_from_dgl_ndarray(data)
...@@ -16,6 +16,7 @@ that returns whether the interface is supported by the framework or not. ...@@ -16,6 +16,7 @@ that returns whether the interface is supported by the framework or not.
############################################################################### ###############################################################################
# Tensor, data type and context interfaces # Tensor, data type and context interfaces
def data_type_dict(): def data_type_dict():
"""Returns a dictionary from data type string to the data type. """Returns a dictionary from data type string to the data type.
...@@ -52,10 +53,12 @@ def data_type_dict(): ...@@ -52,10 +53,12 @@ def data_type_dict():
""" """
pass pass
def cpu(): def cpu():
"""Return a context object for CPU device.""" """Return a context object for CPU device."""
pass pass
def tensor(data, dtype=None): def tensor(data, dtype=None):
"""Create a tensor given the data and data type. """Create a tensor given the data and data type.
...@@ -81,6 +84,7 @@ def tensor(data, dtype=None): ...@@ -81,6 +84,7 @@ def tensor(data, dtype=None):
""" """
pass pass
def as_scalar(data): def as_scalar(data):
"""Returns a scalar whose value is copied from this array. """Returns a scalar whose value is copied from this array.
...@@ -96,6 +100,7 @@ def as_scalar(data): ...@@ -96,6 +100,7 @@ def as_scalar(data):
""" """
pass pass
def get_preferred_sparse_format(): def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend. """Get the preferred sparse matrix format supported by the backend.
...@@ -109,6 +114,7 @@ def get_preferred_sparse_format(): ...@@ -109,6 +114,7 @@ def get_preferred_sparse_format():
""" """
pass pass
def sparse_matrix(data, index, shape, force_format=False): def sparse_matrix(data, index, shape, force_format=False):
"""Create a sparse matrix. """Create a sparse matrix.
...@@ -146,6 +152,7 @@ def sparse_matrix(data, index, shape, force_format=False): ...@@ -146,6 +152,7 @@ def sparse_matrix(data, index, shape, force_format=False):
""" """
pass pass
def sparse_matrix_indices(spmat): def sparse_matrix_indices(spmat):
"""Return the indices of the given sparse matrix. """Return the indices of the given sparse matrix.
...@@ -169,10 +176,12 @@ def sparse_matrix_indices(spmat): ...@@ -169,10 +176,12 @@ def sparse_matrix_indices(spmat):
""" """
pass pass
def is_tensor(obj): def is_tensor(obj):
"""Returns true if the given object is a framework-specific tensor.""" """Returns true if the given object is a framework-specific tensor."""
pass pass
def shape(input): def shape(input):
"""Return the shape of the tensor. """Return the shape of the tensor.
...@@ -188,6 +197,7 @@ def shape(input): ...@@ -188,6 +197,7 @@ def shape(input):
""" """
pass pass
def dtype(input): def dtype(input):
"""Return the data type of the tensor. """Return the data type of the tensor.
...@@ -203,6 +213,7 @@ def dtype(input): ...@@ -203,6 +213,7 @@ def dtype(input):
""" """
pass pass
def ndim(input): def ndim(input):
"""Return the number of dimensions of the tensor. """Return the number of dimensions of the tensor.
...@@ -218,6 +229,7 @@ def ndim(input): ...@@ -218,6 +229,7 @@ def ndim(input):
""" """
pass pass
def context(input): def context(input):
"""Return the context/device of the input tensor. """Return the context/device of the input tensor.
...@@ -233,6 +245,7 @@ def context(input): ...@@ -233,6 +245,7 @@ def context(input):
""" """
pass pass
def device_type(ctx): def device_type(ctx):
"""Return a str representing device type. """Return a str representing device type.
...@@ -247,6 +260,7 @@ def device_type(ctx): ...@@ -247,6 +260,7 @@ def device_type(ctx):
""" """
pass pass
def device_id(ctx): def device_id(ctx):
"""Return device index. """Return device index.
...@@ -265,6 +279,7 @@ def device_id(ctx): ...@@ -265,6 +279,7 @@ def device_id(ctx):
""" """
pass pass
def to_backend_ctx(dglctx): def to_backend_ctx(dglctx):
"""Convert a DGL context object to a backend context. """Convert a DGL context object to a backend context.
...@@ -279,6 +294,7 @@ def to_backend_ctx(dglctx): ...@@ -279,6 +294,7 @@ def to_backend_ctx(dglctx):
""" """
pass pass
def astype(input, ty): def astype(input, ty):
"""Convert the input tensor to the given data type. """Convert the input tensor to the given data type.
...@@ -296,6 +312,7 @@ def astype(input, ty): ...@@ -296,6 +312,7 @@ def astype(input, ty):
""" """
pass pass
def asnumpy(input): def asnumpy(input):
"""Convert the input tensor to numpy array. """Convert the input tensor to numpy array.
...@@ -313,6 +330,7 @@ def asnumpy(input): ...@@ -313,6 +330,7 @@ def asnumpy(input):
""" """
pass pass
def copy_to(input, ctx, **kwargs): def copy_to(input, ctx, **kwargs):
"""Copy the given tensor to the context. """Copy the given tensor to the context.
...@@ -330,6 +348,7 @@ def copy_to(input, ctx, **kwargs): ...@@ -330,6 +348,7 @@ def copy_to(input, ctx, **kwargs):
""" """
pass pass
def is_pinned(input): def is_pinned(input):
"""Check whether the tensor is in pinned memory. """Check whether the tensor is in pinned memory.
...@@ -345,12 +364,14 @@ def is_pinned(input): ...@@ -345,12 +364,14 @@ def is_pinned(input):
""" """
pass pass
############################################################################### ###############################################################################
# Tensor functions on feature data # Tensor functions on feature data
# -------------------------------- # --------------------------------
# These functions are performance critical, so it's better to have efficient # These functions are performance critical, so it's better to have efficient
# implementation in each framework. # implementation in each framework.
def sum(input, dim, keepdims=False): def sum(input, dim, keepdims=False):
"""Reduce sum the input tensor along the given dim. """Reduce sum the input tensor along the given dim.
...@@ -370,6 +391,7 @@ def sum(input, dim, keepdims=False): ...@@ -370,6 +391,7 @@ def sum(input, dim, keepdims=False):
""" """
pass pass
def floor_div(in1, in2): def floor_div(in1, in2):
"""Element-wise integer division and rounds each quotient towards zero. """Element-wise integer division and rounds each quotient towards zero.
...@@ -386,6 +408,7 @@ def floor_div(in1, in2): ...@@ -386,6 +408,7 @@ def floor_div(in1, in2):
A framework-specific tensor. A framework-specific tensor.
""" """
def reduce_sum(input): def reduce_sum(input):
"""Returns the sum of all elements in the input tensor. """Returns the sum of all elements in the input tensor.
...@@ -401,6 +424,7 @@ def reduce_sum(input): ...@@ -401,6 +424,7 @@ def reduce_sum(input):
""" """
pass pass
def cumsum(input, dim): def cumsum(input, dim):
"""Return the cumulative sum of the elements along a given axis. """Return the cumulative sum of the elements along a given axis.
...@@ -418,6 +442,7 @@ def cumsum(input, dim): ...@@ -418,6 +442,7 @@ def cumsum(input, dim):
""" """
pass pass
def mean(input, dim): def mean(input, dim):
"""Reduce average the input tensor along the given dim. """Reduce average the input tensor along the given dim.
...@@ -435,6 +460,7 @@ def mean(input, dim): ...@@ -435,6 +460,7 @@ def mean(input, dim):
""" """
pass pass
def reduce_mean(input): def reduce_mean(input):
"""Returns the average of all elements in the input tensor. """Returns the average of all elements in the input tensor.
...@@ -450,6 +476,7 @@ def reduce_mean(input): ...@@ -450,6 +476,7 @@ def reduce_mean(input):
""" """
pass pass
def max(input, dim): def max(input, dim):
"""Reduce max the input tensor along the given dim. """Reduce max the input tensor along the given dim.
...@@ -467,6 +494,7 @@ def max(input, dim): ...@@ -467,6 +494,7 @@ def max(input, dim):
""" """
pass pass
def reduce_max(input): def reduce_max(input):
"""Returns the max of all elements in the input tensor. """Returns the max of all elements in the input tensor.
...@@ -482,6 +510,7 @@ def reduce_max(input): ...@@ -482,6 +510,7 @@ def reduce_max(input):
""" """
pass pass
def min(input, dim): def min(input, dim):
"""Reduce min the input tensor along the given dim. """Reduce min the input tensor along the given dim.
...@@ -499,6 +528,7 @@ def min(input, dim): ...@@ -499,6 +528,7 @@ def min(input, dim):
""" """
pass pass
def reduce_min(input): def reduce_min(input):
"""Returns the min of all elements in the input tensor. """Returns the min of all elements in the input tensor.
...@@ -533,6 +563,7 @@ def argsort(input, dim, descending): ...@@ -533,6 +563,7 @@ def argsort(input, dim, descending):
A framework-specific tensor. A framework-specific tensor.
""" """
def topk(input, k, dim, descending=True): def topk(input, k, dim, descending=True):
"""Return the k largest elements of the given input tensor along the given dimension. """Return the k largest elements of the given input tensor along the given dimension.
...@@ -551,6 +582,7 @@ def topk(input, k, dim, descending=True): ...@@ -551,6 +582,7 @@ def topk(input, k, dim, descending=True):
""" """
pass pass
def argtopk(input, k, dim, descending=True): def argtopk(input, k, dim, descending=True):
"""Return the indices of the k largest elements of the given input tensor """Return the indices of the k largest elements of the given input tensor
along the given dimension. along the given dimension.
...@@ -570,6 +602,7 @@ def argtopk(input, k, dim, descending=True): ...@@ -570,6 +602,7 @@ def argtopk(input, k, dim, descending=True):
""" """
pass pass
def exp(input): def exp(input):
"""Returns a new tensor with the exponential of the elements of the input tensor `input`. """Returns a new tensor with the exponential of the elements of the input tensor `input`.
...@@ -585,6 +618,7 @@ def exp(input): ...@@ -585,6 +618,7 @@ def exp(input):
""" """
pass pass
def inverse(input): def inverse(input):
"""Returns the inverse matrix of a square matrix if it exists. """Returns the inverse matrix of a square matrix if it exists.
...@@ -600,6 +634,7 @@ def inverse(input): ...@@ -600,6 +634,7 @@ def inverse(input):
""" """
pass pass
def sqrt(input): def sqrt(input):
"""Returns a new tensor with the square root of the elements of the input tensor `input`. """Returns a new tensor with the square root of the elements of the input tensor `input`.
...@@ -615,6 +650,7 @@ def sqrt(input): ...@@ -615,6 +650,7 @@ def sqrt(input):
""" """
pass pass
def softmax(input, dim=-1): def softmax(input, dim=-1):
"""Apply the softmax function on given dimension. """Apply the softmax function on given dimension.
...@@ -650,6 +686,7 @@ def cat(seq, dim): ...@@ -650,6 +686,7 @@ def cat(seq, dim):
""" """
pass pass
def stack(seq, dim): def stack(seq, dim):
"""Stack the sequence of tensors along the given dimension. """Stack the sequence of tensors along the given dimension.
...@@ -667,6 +704,7 @@ def stack(seq, dim): ...@@ -667,6 +704,7 @@ def stack(seq, dim):
""" """
pass pass
def split(input, sizes_or_sections, dim): def split(input, sizes_or_sections, dim):
"""Split the input tensor into chunks. """Split the input tensor into chunks.
...@@ -692,6 +730,7 @@ def split(input, sizes_or_sections, dim): ...@@ -692,6 +730,7 @@ def split(input, sizes_or_sections, dim):
""" """
pass pass
def repeat(input, repeats, dim): def repeat(input, repeats, dim):
"""Repeats elements of an array. """Repeats elements of an array.
...@@ -711,6 +750,7 @@ def repeat(input, repeats, dim): ...@@ -711,6 +750,7 @@ def repeat(input, repeats, dim):
""" """
pass pass
def gather_row(data, row_index): def gather_row(data, row_index):
"""Slice out the data given the row index. """Slice out the data given the row index.
...@@ -728,6 +768,7 @@ def gather_row(data, row_index): ...@@ -728,6 +768,7 @@ def gather_row(data, row_index):
""" """
pass pass
def slice_axis(data, axis, begin, end): def slice_axis(data, axis, begin, end):
"""Slice along a given axis. """Slice along a given axis.
Returns an array slice along a given axis starting from :attr:`begin` index to :attr:`end` index. Returns an array slice along a given axis starting from :attr:`begin` index to :attr:`end` index.
...@@ -749,6 +790,7 @@ def slice_axis(data, axis, begin, end): ...@@ -749,6 +790,7 @@ def slice_axis(data, axis, begin, end):
""" """
pass pass
def take(data, indices, dim): def take(data, indices, dim):
"""Takes elements from an input array along the given dim. """Takes elements from an input array along the given dim.
...@@ -763,6 +805,7 @@ def take(data, indices, dim): ...@@ -763,6 +805,7 @@ def take(data, indices, dim):
""" """
pass pass
def narrow_row(x, start, stop): def narrow_row(x, start, stop):
"""Narrow down the tensor along the first dimension. """Narrow down the tensor along the first dimension.
...@@ -786,6 +829,7 @@ def narrow_row(x, start, stop): ...@@ -786,6 +829,7 @@ def narrow_row(x, start, stop):
""" """
pass pass
def scatter_row(data, row_index, value): def scatter_row(data, row_index, value):
"""Write the value into the data tensor using the row index. """Write the value into the data tensor using the row index.
...@@ -807,6 +851,7 @@ def scatter_row(data, row_index, value): ...@@ -807,6 +851,7 @@ def scatter_row(data, row_index, value):
""" """
pass pass
def index_add_inplace(data, row_idx, value): def index_add_inplace(data, row_idx, value):
"""Add the values into the data tensor using the row index inplace. """Add the values into the data tensor using the row index inplace.
...@@ -832,6 +877,7 @@ def index_add_inplace(data, row_idx, value): ...@@ -832,6 +877,7 @@ def index_add_inplace(data, row_idx, value):
""" """
pass pass
def scatter_row_inplace(data, row_index, value): def scatter_row_inplace(data, row_index, value):
"""Write the value into the data tensor using the row index inplace. """Write the value into the data tensor using the row index inplace.
...@@ -848,6 +894,7 @@ def scatter_row_inplace(data, row_index, value): ...@@ -848,6 +894,7 @@ def scatter_row_inplace(data, row_index, value):
""" """
pass pass
def squeeze(input, dim): def squeeze(input, dim):
"""Remove the given dimension of size 1. """Remove the given dimension of size 1.
...@@ -865,6 +912,7 @@ def squeeze(input, dim): ...@@ -865,6 +912,7 @@ def squeeze(input, dim):
""" """
pass pass
def unsqueeze(input, dim): def unsqueeze(input, dim):
"""Add the given dimension of size 1. """Add the given dimension of size 1.
...@@ -882,6 +930,7 @@ def unsqueeze(input, dim): ...@@ -882,6 +930,7 @@ def unsqueeze(input, dim):
""" """
pass pass
def reshape(input, shape): def reshape(input, shape):
"""Reshape the tensor. """Reshape the tensor.
...@@ -899,6 +948,7 @@ def reshape(input, shape): ...@@ -899,6 +948,7 @@ def reshape(input, shape):
""" """
pass pass
def swapaxes(input, axis1, axis2): def swapaxes(input, axis1, axis2):
"""Interchange the two given axes of a tensor. """Interchange the two given axes of a tensor.
...@@ -916,6 +966,7 @@ def swapaxes(input, axis1, axis2): ...@@ -916,6 +966,7 @@ def swapaxes(input, axis1, axis2):
""" """
pass pass
def zeros(shape, dtype, ctx): def zeros(shape, dtype, ctx):
"""Create a zero tensor. """Create a zero tensor.
...@@ -935,6 +986,7 @@ def zeros(shape, dtype, ctx): ...@@ -935,6 +986,7 @@ def zeros(shape, dtype, ctx):
""" """
pass pass
def zeros_like(input): def zeros_like(input):
"""Create a zero tensor with the same shape, dtype and context of the """Create a zero tensor with the same shape, dtype and context of the
given tensor. given tensor.
...@@ -951,6 +1003,7 @@ def zeros_like(input): ...@@ -951,6 +1003,7 @@ def zeros_like(input):
""" """
pass pass
def ones(shape, dtype, ctx): def ones(shape, dtype, ctx):
"""Create a one tensor. """Create a one tensor.
...@@ -970,6 +1023,7 @@ def ones(shape, dtype, ctx): ...@@ -970,6 +1023,7 @@ def ones(shape, dtype, ctx):
""" """
pass pass
def uniform(shape, dtype, ctx, low, high): def uniform(shape, dtype, ctx, low, high):
"""Create a tensor with random value in a uniform """Create a tensor with random value in a uniform
distribution between low (inclusive) and high (exclusive). distribution between low (inclusive) and high (exclusive).
...@@ -990,6 +1044,7 @@ def uniform(shape, dtype, ctx, low, high): ...@@ -990,6 +1044,7 @@ def uniform(shape, dtype, ctx, low, high):
""" """
pass pass
def randint(shape, dtype, ctx, low, high): def randint(shape, dtype, ctx, low, high):
"""Create a tensor with random value in a uniform integer """Create a tensor with random value in a uniform integer
distribution between low (inclusive) and high (exclusive) distribution between low (inclusive) and high (exclusive)
...@@ -1010,6 +1065,7 @@ def randint(shape, dtype, ctx, low, high): ...@@ -1010,6 +1065,7 @@ def randint(shape, dtype, ctx, low, high):
""" """
pass pass
def pad_packed_tensor(input, lengths, value, l_min=None): def pad_packed_tensor(input, lengths, value, l_min=None):
r"""Pads a packed batch of variable length tensors with given value. r"""Pads a packed batch of variable length tensors with given value.
...@@ -1034,6 +1090,7 @@ def pad_packed_tensor(input, lengths, value, l_min=None): ...@@ -1034,6 +1090,7 @@ def pad_packed_tensor(input, lengths, value, l_min=None):
""" """
pass pass
def pack_padded_tensor(input, lengths): def pack_padded_tensor(input, lengths):
r"""Packs a tensor containing padded sequence of variable length. r"""Packs a tensor containing padded sequence of variable length.
...@@ -1054,6 +1111,7 @@ def pack_padded_tensor(input, lengths): ...@@ -1054,6 +1111,7 @@ def pack_padded_tensor(input, lengths):
""" """
pass pass
def boolean_mask(input, mask): def boolean_mask(input, mask):
"""Selects elements in x according to the given mask from the first """Selects elements in x according to the given mask from the first
dimension. dimension.
...@@ -1072,6 +1130,7 @@ def boolean_mask(input, mask): ...@@ -1072,6 +1130,7 @@ def boolean_mask(input, mask):
""" """
pass pass
def equal(x, y): def equal(x, y):
"""Compares whether the elements are equal. """Compares whether the elements are equal.
...@@ -1087,6 +1146,7 @@ def equal(x, y): ...@@ -1087,6 +1146,7 @@ def equal(x, y):
""" """
pass pass
def allclose(x, y, rtol=1e-4, atol=1e-4): def allclose(x, y, rtol=1e-4, atol=1e-4):
"""Compares whether all elements are close. """Compares whether all elements are close.
...@@ -1102,6 +1162,7 @@ def allclose(x, y, rtol=1e-4, atol=1e-4): ...@@ -1102,6 +1162,7 @@ def allclose(x, y, rtol=1e-4, atol=1e-4):
Absolute tolerance Absolute tolerance
""" """
def logical_not(input): def logical_not(input):
"""Perform a logical not operation. Equivalent to np.logical_not """Perform a logical not operation. Equivalent to np.logical_not
...@@ -1117,9 +1178,11 @@ def logical_not(input): ...@@ -1117,9 +1178,11 @@ def logical_not(input):
""" """
pass pass
def logical_and(input1, input2): def logical_and(input1, input2):
pass pass
def clone(input): def clone(input):
"""Return a clone of the input tensor. """Return a clone of the input tensor.
...@@ -1135,6 +1198,7 @@ def clone(input): ...@@ -1135,6 +1198,7 @@ def clone(input):
""" """
pass pass
def clamp(data, min_val, max_val): def clamp(data, min_val, max_val):
"""Clamp all elements in :attr:`input` into the range [min_val, max_val] """Clamp all elements in :attr:`input` into the range [min_val, max_val]
and return a resulting tensor. and return a resulting tensor.
...@@ -1155,6 +1219,7 @@ def clamp(data, min_val, max_val): ...@@ -1155,6 +1219,7 @@ def clamp(data, min_val, max_val):
""" """
pass pass
def replace_inf_with_zero(x): def replace_inf_with_zero(x):
"""Returns a new tensor replacing infinity and negative infinity with zeros. """Returns a new tensor replacing infinity and negative infinity with zeros.
...@@ -1170,6 +1235,7 @@ def replace_inf_with_zero(x): ...@@ -1170,6 +1235,7 @@ def replace_inf_with_zero(x):
""" """
pass pass
def count_nonzero(input): def count_nonzero(input):
"""Return the count of non-zero values in the tensor input. """Return the count of non-zero values in the tensor input.
...@@ -1185,6 +1251,7 @@ def count_nonzero(input): ...@@ -1185,6 +1251,7 @@ def count_nonzero(input):
""" """
pass pass
############################################################################### ###############################################################################
# Tensor functions used *only* on index tensor # Tensor functions used *only* on index tensor
# ---------------- # ----------------
...@@ -1193,6 +1260,7 @@ def count_nonzero(input): ...@@ -1193,6 +1260,7 @@ def count_nonzero(input):
# DGL should contain all the operations on index, so this set of operators # DGL should contain all the operations on index, so this set of operators
# should be gradually removed. # should be gradually removed.
def unique(input, return_inverse=False, return_counts=False): def unique(input, return_inverse=False, return_counts=False):
"""Returns the unique scalar elements in a tensor. """Returns the unique scalar elements in a tensor.
...@@ -1219,6 +1287,7 @@ def unique(input, return_inverse=False, return_counts=False): ...@@ -1219,6 +1287,7 @@ def unique(input, return_inverse=False, return_counts=False):
""" """
pass pass
def full_1d(length, fill_value, dtype, ctx): def full_1d(length, fill_value, dtype, ctx):
"""Create a 1D tensor full of the fill_value. """Create a 1D tensor full of the fill_value.
...@@ -1240,6 +1309,7 @@ def full_1d(length, fill_value, dtype, ctx): ...@@ -1240,6 +1309,7 @@ def full_1d(length, fill_value, dtype, ctx):
""" """
pass pass
def nonzero_1d(input): def nonzero_1d(input):
"""Return the nonzero index of the given 1D input. """Return the nonzero index of the given 1D input.
...@@ -1255,6 +1325,7 @@ def nonzero_1d(input): ...@@ -1255,6 +1325,7 @@ def nonzero_1d(input):
""" """
pass pass
def sort_1d(input): def sort_1d(input):
"""Sort a 1D tensor (in ascending order) and also return the original index. """Sort a 1D tensor (in ascending order) and also return the original index.
...@@ -1272,6 +1343,7 @@ def sort_1d(input): ...@@ -1272,6 +1343,7 @@ def sort_1d(input):
""" """
pass pass
def arange(start, stop, dtype, ctx): def arange(start, stop, dtype, ctx):
"""Create a 1D range int64 tensor. """Create a 1D range int64 tensor.
...@@ -1293,6 +1365,7 @@ def arange(start, stop, dtype, ctx): ...@@ -1293,6 +1365,7 @@ def arange(start, stop, dtype, ctx):
""" """
pass pass
def rand_shuffle(arr): def rand_shuffle(arr):
"""Random shuffle the data in the first dimension of the array. """Random shuffle the data in the first dimension of the array.
...@@ -1310,6 +1383,7 @@ def rand_shuffle(arr): ...@@ -1310,6 +1383,7 @@ def rand_shuffle(arr):
""" """
pass pass
def zerocopy_to_dlpack(input): def zerocopy_to_dlpack(input):
"""Create a dlpack tensor that shares the input memory. """Create a dlpack tensor that shares the input memory.
...@@ -1325,6 +1399,7 @@ def zerocopy_to_dlpack(input): ...@@ -1325,6 +1399,7 @@ def zerocopy_to_dlpack(input):
""" """
pass pass
def zerocopy_from_dlpack(dlpack_tensor): def zerocopy_from_dlpack(dlpack_tensor):
"""Create a tensor that shares the dlpack_tensor. """Create a tensor that shares the dlpack_tensor.
...@@ -1340,6 +1415,7 @@ def zerocopy_from_dlpack(dlpack_tensor): ...@@ -1340,6 +1415,7 @@ def zerocopy_from_dlpack(dlpack_tensor):
""" """
pass pass
def zerocopy_to_numpy(input): def zerocopy_to_numpy(input):
"""Create a numpy ndarray that shares the input memory. """Create a numpy ndarray that shares the input memory.
...@@ -1355,6 +1431,7 @@ def zerocopy_to_numpy(input): ...@@ -1355,6 +1431,7 @@ def zerocopy_to_numpy(input):
""" """
pass pass
def zerocopy_from_numpy(np_array): def zerocopy_from_numpy(np_array):
"""Create a tensor that shares the numpy array. """Create a tensor that shares the numpy array.
...@@ -1370,6 +1447,7 @@ def zerocopy_from_numpy(np_array): ...@@ -1370,6 +1447,7 @@ def zerocopy_from_numpy(np_array):
""" """
pass pass
def zerocopy_to_dgl_ndarray(input): def zerocopy_to_dgl_ndarray(input):
"""Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray """Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray
...@@ -1383,6 +1461,7 @@ def zerocopy_to_dgl_ndarray(input): ...@@ -1383,6 +1461,7 @@ def zerocopy_to_dgl_ndarray(input):
""" """
pass pass
def zerocopy_to_dgl_ndarray_for_write(input): def zerocopy_to_dgl_ndarray_for_write(input):
"""Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray """Zerocopy a framework-specific Tensor to dgl.ndarray.NDArray
that is ready for write (required in MXNet). that is ready for write (required in MXNet).
...@@ -1412,7 +1491,6 @@ def zerocopy_from_dgl_ndarray(input): ...@@ -1412,7 +1491,6 @@ def zerocopy_from_dgl_ndarray(input):
pass pass
############################################################################### ###############################################################################
# Custom Operators for graph level computations. # Custom Operators for graph level computations.
...@@ -1420,8 +1498,20 @@ def zerocopy_from_dgl_ndarray(input): ...@@ -1420,8 +1498,20 @@ def zerocopy_from_dgl_ndarray(input):
# kernels (see kernel.py), and plug into tensor framework using custom op # kernels (see kernel.py), and plug into tensor framework using custom op
# extensions. # extensions.
def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
out_size, lhs_map, rhs_map, out_map): def binary_reduce(
reducer,
binary_op,
graph,
lhs,
rhs,
lhs_data,
rhs_data,
out_size,
lhs_map,
rhs_map,
out_map,
):
"""Perform binary operation between given data and reduce based on graph """Perform binary operation between given data and reduce based on graph
structure. structure.
...@@ -1458,6 +1548,7 @@ def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data, ...@@ -1458,6 +1548,7 @@ def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
""" """
pass pass
def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map): def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
"""Copy target data and perform reduce based on graph structure. """Copy target data and perform reduce based on graph structure.
...@@ -1486,8 +1577,9 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map): ...@@ -1486,8 +1577,9 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map, out_map):
""" """
pass pass
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface. r"""Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel. It fuses two steps into one kernel.
(1) Computes messages by :attr:`op` source node and edge features. (1) Computes messages by :attr:`op` source node and edge features.
(2) Aggregate the messages by :attr:`reduce_op` as the features on destination nodes. (2) Aggregate the messages by :attr:`reduce_op` as the features on destination nodes.
...@@ -1523,8 +1615,9 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): ...@@ -1523,8 +1615,9 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
""" """
pass pass
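The two numbered steps above describe a fused SpMM-style message-passing kernel: a message per edge from op(u_feat, e_feat), then a per-destination reduction with reduce_op. Purely as a reference for those semantics (not how any backend implements it), a NumPy sketch that assumes the graph is given as an explicit (src, dst, eid) edge list rather than a gidx:

import numpy as np

def gspmm_reference(edges, num_dst, op, reduce_op, u_feat, e_feat):
    msgs = {}   # dst -> list of per-edge messages
    for src, dst, eid in edges:
        lhs = None if u_feat is None else u_feat[src]
        rhs = None if e_feat is None else e_feat[eid]
        if op == "copy_lhs":
            m = lhs
        elif op == "copy_rhs":
            m = rhs
        elif op == "add":
            m = lhs + rhs
        elif op == "sub":
            m = lhs - rhs
        elif op == "mul":
            m = lhs * rhs
        elif op == "div":
            m = lhs / rhs
        else:
            raise ValueError("unsupported op: %s" % op)
        msgs.setdefault(dst, []).append(m)
    reduce_fn = {"sum": np.sum, "max": np.max, "min": np.min, "mean": np.mean}[reduce_op]
    feat_shape = np.shape(next(iter(msgs.values()))[0])
    out = np.zeros((num_dst,) + feat_shape)
    for dst, mlist in msgs.items():
        out[dst] = reduce_fn(np.stack(mlist), axis=0)
    return out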
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple): def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
r""" Generalized Sparse Matrix Multiplication interface on heterogenenous graph. r"""Generalized Sparse Matrix Multiplication interface on heterogenenous graph.
All the relation types of the heterogeneous graph will be processed together. All the relation types of the heterogeneous graph will be processed together.
It fuses two steps into one kernel. It fuses two steps into one kernel.
(1) Computes messages by :attr:`op` source node and edge features. (1) Computes messages by :attr:`op` source node and edge features.
...@@ -1564,8 +1657,9 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple): ...@@ -1564,8 +1657,9 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
""" """
pass pass
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
r"""Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by :attr:`op` lhs features and rhs features. It computes edge features by :attr:`op` lhs features and rhs features.
.. math:: .. math::
...@@ -1599,8 +1693,11 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): ...@@ -1599,8 +1693,11 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
""" """
pass pass
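gsddmm is the edge-wise counterpart: each edge combines an lhs operand and an rhs operand, each taken from the source node ("u"), destination node ("v"), or the edge itself ("e"), and writes the result onto the edge. A loop-based sketch under the same explicit edge-list assumption:

def gsddmm_reference(edges, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
    def pick(data, target, src, dst, eid):
        return data[{"u": src, "v": dst, "e": eid}[target]]
    out = {}   # eid -> edge feature
    for src, dst, eid in edges:
        lhs = pick(lhs_data, lhs_target, src, dst, eid)
        rhs = pick(rhs_data, rhs_target, src, dst, eid)
        if op == "add":
            out[eid] = lhs + rhs
        elif op == "sub":
            out[eid] = lhs - rhs
        elif op == "mul":
            out[eid] = lhs * rhs
        elif op == "div":
            out[eid] = lhs / rhs
        elif op == "dot":
            out[eid] = (lhs * rhs).sum(-1, keepdims=True)
        elif op == "copy_lhs":
            out[eid] = lhs
        elif op == "copy_rhs":
            out[eid] = rhs
    return out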
def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface on def gsddmm_hetero(
g, op, lhs_len, lhs_target="u", rhs_target="v", *lhs_and_rhs_tuple
):
r"""Generalized Sampled-Dense-Dense Matrix Multiplication interface on
heterogeneous graph. All the relation types of the heterogeneous graph heterogeneous graph. All the relation types of the heterogeneous graph
will be processed together. will be processed together.
It computes edge features by :attr:`op` lhs features and rhs features. It computes edge features by :attr:`op` lhs features and rhs features.
...@@ -1639,6 +1736,7 @@ def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_t ...@@ -1639,6 +1736,7 @@ def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_t
""" """
pass pass
def edge_softmax(gidx, logits, eids, norm_by): def edge_softmax(gidx, logits, eids, norm_by):
r"""Compute edge softmax. r"""Compute edge softmax.
...@@ -1676,6 +1774,7 @@ def edge_softmax(gidx, logits, eids, norm_by): ...@@ -1676,6 +1774,7 @@ def edge_softmax(gidx, logits, eids, norm_by):
""" """
pass pass
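Most of the docstring is collapsed in this hunk, but the operation normalizes edge logits with a softmax over the edges that share an endpoint (norm_by selects which endpoint). A hedged NumPy sketch of the norm_by="dst" case, again assuming an explicit edge list:

import numpy as np

def edge_softmax_reference(edges, logits):
    groups = {}   # dst -> edge ids of its incoming edges
    for src, dst, eid in edges:
        groups.setdefault(dst, []).append(eid)
    out = np.empty_like(logits, dtype=float)
    for eids in groups.values():
        scores = logits[eids] - logits[eids].max(axis=0)   # numerical stability
        exps = np.exp(scores)
        out[eids] = exps / exps.sum(axis=0)
    return out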
def edge_softmax_hetero(gidx, eids, norm_by, *logits): def edge_softmax_hetero(gidx, eids, norm_by, *logits):
r"""Compute edge softmax. r"""Compute edge softmax.
...@@ -1713,6 +1812,7 @@ def edge_softmax_hetero(gidx, eids, norm_by, *logits): ...@@ -1713,6 +1812,7 @@ def edge_softmax_hetero(gidx, eids, norm_by, *logits):
""" """
pass pass
def segment_reduce(op, x, offsets): def segment_reduce(op, x, offsets):
"""Segment reduction operator. """Segment reduction operator.
...@@ -1741,6 +1841,7 @@ def segment_reduce(op, x, offsets): ...@@ -1741,6 +1841,7 @@ def segment_reduce(op, x, offsets):
""" """
pass pass
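segment_reduce reduces consecutive rows of x segment by segment. Assuming offsets is the usual prefix-sum array, with segment i spanning rows offsets[i]:offsets[i+1], a short NumPy reference:

import numpy as np

def segment_reduce_reference(op, x, offsets):
    fn = {"sum": np.sum, "max": np.max, "min": np.min, "mean": np.mean}[op]
    return np.stack([fn(x[offsets[i]:offsets[i + 1]], axis=0)
                     for i in range(len(offsets) - 1)])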
def scatter_add(x, idx, m): def scatter_add(x, idx, m):
"""Scatter add (on first dimension) operator. """Scatter add (on first dimension) operator.
...@@ -1763,6 +1864,7 @@ def scatter_add(x, idx, m): ...@@ -1763,6 +1864,7 @@ def scatter_add(x, idx, m):
""" """
pass pass
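scatter_add routes rows of x into an output with m rows along the first dimension, accumulating rows that share the same index. A NumPy sketch of that contract:

import numpy as np

def scatter_add_reference(x, idx, m):
    out = np.zeros((m,) + x.shape[1:], dtype=x.dtype)
    np.add.at(out, idx, x)   # unlike out[idx] += x, duplicate indices accumulate
    return out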
def csrmm(A, A_weights, B, B_weights, num_vtypes): def csrmm(A, A_weights, B, B_weights, num_vtypes):
"""Compute weighted adjacency matrix multiplication. """Compute weighted adjacency matrix multiplication.
...@@ -1795,6 +1897,7 @@ def csrmm(A, A_weights, B, B_weights, num_vtypes): ...@@ -1795,6 +1897,7 @@ def csrmm(A, A_weights, B, B_weights, num_vtypes):
""" """
pass pass
def csrsum(gidxs, weights): def csrsum(gidxs, weights):
"""Compute weighted adjacency matrix summation. """Compute weighted adjacency matrix summation.
...@@ -1821,6 +1924,7 @@ def csrsum(gidxs, weights): ...@@ -1821,6 +1924,7 @@ def csrsum(gidxs, weights):
""" """
pass pass
def csrmask(A, A_weights, B): def csrmask(A, A_weights, B):
"""Retrieve the values in the weighted adjacency matrix of graph :attr:`A` at the """Retrieve the values in the weighted adjacency matrix of graph :attr:`A` at the
non-zero positions of graph :attr:`B`'s adjacency matrix. non-zero positions of graph :attr:`B`'s adjacency matrix.
...@@ -1848,8 +1952,9 @@ def csrmask(A, A_weights, B): ...@@ -1848,8 +1952,9 @@ def csrmask(A, A_weights, B):
""" """
pass pass
def gather_mm(A, B, idx_a, idx_b): def gather_mm(A, B, idx_a, idx_b):
r""" Dense Matrix Multiplication interface. It multiplies 2D dense tensor A r"""Dense Matrix Multiplication interface. It multiplies 2D dense tensor A
and 3D dense tensor B according to their relation types. A is unsorted and and 3D dense tensor B according to their relation types. A is unsorted and
the relation type is fetched from idx_b. the relation type is fetched from idx_b.
...@@ -1871,8 +1976,9 @@ def gather_mm(A, B, idx_a, idx_b): ...@@ -1871,8 +1976,9 @@ def gather_mm(A, B, idx_a, idx_b):
""" """
pass pass
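Per the description above, each output row multiplies one row of the 2D tensor A with one relation-specific matrix from the 3D tensor B. A hedged sketch assuming out[i] = A[idx_a[i]] @ B[idx_b[i]], with a missing index array meaning the identity pairing:

import numpy as np

def gather_mm_reference(A, B, idx_a=None, idx_b=None):
    n = len(idx_a) if idx_a is not None else (len(idx_b) if idx_b is not None else A.shape[0])
    rows = []
    for i in range(n):
        a = A[idx_a[i]] if idx_a is not None else A[i]
        b = B[idx_b[i]] if idx_b is not None else B[i]
        rows.append(a @ b)
    return np.stack(rows)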
def segment_mm(A, B, seglen_A): def segment_mm(A, B, seglen_A):
r""" Dense Matrix Multiplication interface. It multiplies dense tensor A r"""Dense Matrix Multiplication interface. It multiplies dense tensor A
and dense tensor B according to relation types. A is sorted and concatenated and dense tensor B according to relation types. A is sorted and concatenated
according to relation types. according to relation types.
...@@ -1900,6 +2006,7 @@ def segment_mm(A, B, seglen_A): ...@@ -1900,6 +2006,7 @@ def segment_mm(A, B, seglen_A):
# These are not related to tensors. Some of them are temporary workarounds that # These are not related to tensors. Some of them are temporary workarounds that
# should be included in DGL in the future. # should be included in DGL in the future.
def sync(): def sync():
"""Synchronize computation. """Synchronize computation.
...@@ -1909,33 +2016,35 @@ def sync(): ...@@ -1909,33 +2016,35 @@ def sync():
""" """
pass pass
def attach_grad(tensor): def attach_grad(tensor):
""" Attach gradients to the input tensor """Attach gradients to the input tensor"""
"""
pass pass
def backward(x, head_gradient=None): def backward(x, head_gradient=None):
"""Invoke backward computation with an optional head gradient. """Invoke backward computation with an optional head gradient."""
"""
pass pass
def grad(x): def grad(x):
"""Fetches the gradient from the tensor after backward computation. """Fetches the gradient from the tensor after backward computation."""
"""
pass pass
def is_no_grad(x): def is_no_grad(x):
""" Test if the input tensor has gradient """Test if the input tensor has gradient"""
"""
pass pass
def is_recording(): def is_recording():
""" Test if the execution is recording gradients. """Test if the execution is recording gradients."""
"""
pass pass
class record_grad(object): class record_grad(object):
"""Context manager that records the gradients""" """Context manager that records the gradients"""
def __init__(self): def __init__(self):
pass pass
...@@ -1948,6 +2057,7 @@ class record_grad(object): ...@@ -1948,6 +2057,7 @@ class record_grad(object):
class no_grad(object): class no_grad(object):
"""Context manager that explicitly disables gradient computation""" """Context manager that explicitly disables gradient computation"""
def __init__(self): def __init__(self):
pass pass
...@@ -1957,8 +2067,10 @@ class no_grad(object): ...@@ -1957,8 +2067,10 @@ class no_grad(object):
def __exit__(self, exc_type, exc_value, exc_traceback): def __exit__(self, exc_type, exc_value, exc_traceback):
pass pass
class NodeEmbedding(object): class NodeEmbedding(object):
"""Sparse node embeddings""" """Sparse node embeddings"""
def __init__(self): def __init__(self):
pass pass
......
from .tensor import *
from .sparse import * from .sparse import *
from .tensor import *
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
from mxnet import nd from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
from ...heterograph_index import create_unitgraph_from_csr
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add', from ...base import ALL, dgl_warning, is_all
'csrmm', 'csrsum', 'csrmask'] from ...heterograph_index import create_unitgraph_from_csr
from ...sparse import (
_bwd_segment_cmp,
_csrmask,
_csrmm,
_csrsum,
_gsddmm,
_gspmm,
_scatter_add,
_segment_reduce,
)
from .tensor import (
asnumpy,
context,
copy_to,
to_backend_ctx,
zerocopy_from_numpy,
)
__all__ = [
"gspmm",
"gsddmm",
"edge_softmax",
"segment_reduce",
"scatter_add",
"csrmm",
"csrsum",
"csrmask",
]
def _scatter_nd(index, src, n_rows): def _scatter_nd(index, src, n_rows):
...@@ -26,7 +49,10 @@ def _scatter_nd(index, src, n_rows): ...@@ -26,7 +49,10 @@ def _scatter_nd(index, src, n_rows):
di = shp[i] di = shp[i]
offset_i = np.arange(di, dtype=index.dtype) offset_i = np.arange(di, dtype=index.dtype)
offsets.append( offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i))) (stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di stride *= di
if ndim > 1: if ndim > 1:
new_idx = index * stride + sum(offsets) new_idx = index * stride + sum(offsets)
...@@ -52,7 +78,10 @@ def _gather_nd(index, src): ...@@ -52,7 +78,10 @@ def _gather_nd(index, src):
di = shp[i] di = shp[i]
offset_i = nd.arange(di, dtype=index.dtype) offset_i = nd.arange(di, dtype=index.dtype)
offsets.append( offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i))) (stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di stride *= di
if ndim > 1: if ndim > 1:
new_idx = index * stride + copy_to(sum(offsets), ctx) new_idx = index * stride + copy_to(sum(offsets), ctx)
...@@ -107,11 +136,11 @@ def _need_reduce_last_dim(ufeat, efeat): ...@@ -107,11 +136,11 @@ def _need_reduce_last_dim(ufeat, efeat):
def _muldiv(op, x): def _muldiv(op, x):
return 1. / x if op == 'div' else x return 1.0 / x if op == "div" else x
def _addsub(op, x): def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == "sub" else x
def _expand(x, shape): def _expand(x, shape):
...@@ -134,45 +163,48 @@ class GSpMM(mx.autograd.Function): ...@@ -134,45 +163,48 @@ class GSpMM(mx.autograd.Function):
ctx = context(dZ) ctx = context(dZ)
X, Y, argX, argY = self.saved_tensors X, Y, argX, argY = self.saved_tensors
gidx, op, reduce_op = self.gidx, self.op, self.reduce_op gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
if op != 'copy_rhs': if op != "copy_rhs":
g_rev = gidx.reverse() g_rev = gidx.reverse()
if reduce_op == 'sum': if reduce_op == "sum":
if op in ['mul', 'div']: if op in ["mul", "div"]:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0] dX = _gspmm(g_rev, "mul", "sum", dZ, _muldiv(op, Y))[0]
elif op in ['add', 'sub']: elif op in ["add", "sub"]:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0] dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, Y)[0]
elif op == 'copy_lhs': elif op == "copy_lhs":
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0] dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, None)[0]
else: else:
if op in ['mul', 'div']: if op in ["mul", "div"]:
dX = _scatter_nd( dX = _scatter_nd(
argX, argX,
_muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:]))) * dZ, _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))
X.shape[0]) * dZ,
elif op in ['add', 'sub', 'copy_lhs']: X.shape[0],
)
elif op in ["add", "sub", "copy_lhs"]:
dX = _scatter_nd(argX, dZ, X.shape[0]) dX = _scatter_nd(argX, dZ, X.shape[0])
dX = _reduce_grad(dX, X.shape) dX = _reduce_grad(dX, X.shape)
else: else:
dX = nd.zeros_like(X) dX = nd.zeros_like(X)
if op != 'copy_lhs': if op != "copy_lhs":
if reduce_op == 'sum': if reduce_op == "sum":
if op == 'mul' and _need_reduce_last_dim(X, Y): if op == "mul" and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, 'dot', X, dZ) dY = _gsddmm(gidx, "dot", X, dZ)
elif op in ['mul', 'div']: elif op in ["mul", "div"]:
dY = _gsddmm(gidx, 'mul', X, dZ) dY = _gsddmm(gidx, "mul", X, dZ)
if op == 'div': if op == "div":
dY = -dY / (Y ** 2) dY = -dY / (Y**2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ["add", "sub", "copy_rhs"]:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ)) dY = _gsddmm(gidx, "copy_rhs", X, _addsub(op, dZ))
else: else:
if op in ['mul', 'div']: if op in ["mul", "div"]:
dY = _scatter_nd(
    argY,
    _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
    Y.shape[0],
)
if op == "div":
dY = -dY / (Y**2)
elif op in ["add", "sub", "copy_rhs"]:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0]) dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else: else:
...@@ -207,7 +239,9 @@ class GSDDMM(mx.autograd.Function): ...@@ -207,7 +239,9 @@ class GSDDMM(mx.autograd.Function):
self.rhs_target = rhs_target self.rhs_target = rhs_target
def forward(self, X, Y): def forward(self, X, Y):
out = _gsddmm(
    self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target
)
self.save_for_backward(X, Y) self.save_for_backward(X, Y)
return out return out
...@@ -216,47 +250,55 @@ class GSDDMM(mx.autograd.Function): ...@@ -216,47 +250,55 @@ class GSDDMM(mx.autograd.Function):
X, Y = self.saved_tensors X, Y = self.saved_tensors
gidx, op = self.gidx, self.op gidx, op = self.gidx, self.op
lhs_target, rhs_target = self.lhs_target, self.rhs_target lhs_target, rhs_target = self.lhs_target, self.rhs_target
if op != 'copy_rhs': if op != "copy_rhs":
if lhs_target in ['u', 'v']: if lhs_target in ["u", "v"]:
_gidx = gidx if self.lhs_target == 'v' else gidx.reverse() _gidx = gidx if self.lhs_target == "v" else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']: if op in ["add", "sub", "copy_lhs"]:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0]
else: # mul, div, dot else: # mul, div, dot
if rhs_target == lhs_target: if rhs_target == lhs_target:
dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[
    0
] * _muldiv(op, Y)
elif self.rhs_target == "e":
dX = _gspmm(
    _gidx, "copy_rhs", "sum", None, dZ * _muldiv(op, Y)
)[0]
else:  # rhs_target = !lhs_target
dX = _gspmm(_gidx, "mul", "sum", _muldiv(op, Y), dZ)[0]
else:  # lhs_target == 'e'
if op in ["add", "sub", "copy_lhs"]:
dX = dZ
else:  # mul, div, dot
dX = _gsddmm(
    gidx, "mul", dZ, _muldiv(op, Y), "e", rhs_target
)
dX = _reduce_grad(dX, X.shape) dX = _reduce_grad(dX, X.shape)
else: else:
dX = nd.zeros_like(X) dX = nd.zeros_like(X)
if op != 'copy_lhs': if op != "copy_lhs":
if self.rhs_target in ['u', 'v']: if self.rhs_target in ["u", "v"]:
_gidx = gidx if rhs_target == 'v' else gidx.reverse() _gidx = gidx if rhs_target == "v" else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']: if op in ["add", "sub", "copy_rhs"]:
dY = _gspmm(
    _gidx, "copy_rhs", "sum", None, _addsub(op, dZ)
)[0]
else: # mul, div, dot else: # mul, div, dot
if lhs_target == rhs_target: if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0] * X
elif self.lhs_target == 'e': elif self.lhs_target == "e":
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0] dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ * X)[0]
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0] dY = _gspmm(_gidx, "mul", "sum", X, dZ)[0]
if op == 'div': if op == "div":
dY = -dY / (Y ** 2) dY = -dY / (Y**2)
else: else:
if op in ['add', 'sub', 'copy_rhs']: if op in ["add", "sub", "copy_rhs"]:
dY = _addsub(op, dZ) dY = _addsub(op, dZ)
else: # mul, div, dot else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target) dY = _gsddmm(gidx, "mul", dZ, X, "e", lhs_target)
if op == 'div': if op == "div":
dY = -dY / (Y ** 2) dY = -dY / (Y**2)
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else: else:
dY = nd.zeros_like(Y) dY = nd.zeros_like(Y)
...@@ -264,7 +306,7 @@ class GSDDMM(mx.autograd.Function): ...@@ -264,7 +306,7 @@ class GSDDMM(mx.autograd.Function):
return dX, dY return dX, dY
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
func = GSDDMM(gidx, op, lhs_target, rhs_target) func = GSDDMM(gidx, op, lhs_target, rhs_target)
ctx = to_backend_ctx(gidx.ctx) ctx = to_backend_ctx(gidx.ctx)
if lhs_data is None: if lhs_data is None:
...@@ -279,7 +321,7 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -279,7 +321,7 @@ class EdgeSoftmax(mx.autograd.Function):
super(EdgeSoftmax, self).__init__() super(EdgeSoftmax, self).__init__()
if not is_all(eids): if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src': if norm_by == "src":
gidx = gidx.reverse() gidx = gidx.reverse()
self.gidx = gidx self.gidx = gidx
...@@ -298,10 +340,10 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -298,10 +340,10 @@ class EdgeSoftmax(mx.autograd.Function):
return out.data return out.data
""" """
gidx = self.gidx gidx = self.gidx
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0] score_max = _gspmm(gidx, "copy_rhs", "max", None, score)[0]
score = mx.nd.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v')) score = mx.nd.exp(_gsddmm(gidx, "sub", score, score_max, "e", "v"))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0] score_sum = _gspmm(gidx, "copy_rhs", "sum", None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v') out = _gsddmm(gidx, "div", score, score_sum, "e", "v")
self.save_for_backward(out) self.save_for_backward(out)
return out return out
...@@ -319,16 +361,16 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -319,16 +361,16 @@ class EdgeSoftmax(mx.autograd.Function):
sds_sum = sds.dst_sum() # type dgl.NData sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions grad_score = sds - sds * sds_sum # multiple expressions
""" """
out, = self.saved_tensors (out,) = self.saved_tensors
gidx = self.gidx gidx = self.gidx
sds = out * grad_out sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds) accum = gspmm(gidx, "copy_rhs", "sum", None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v') grad_score = sds - gsddmm(gidx, "mul", out, accum, "e", "v")
self.save_tensors = None self.save_tensors = None
return grad_score return grad_score
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'): def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
softmax_op = EdgeSoftmax(gidx, eids, norm_by) softmax_op = EdgeSoftmax(gidx, eids, norm_by)
return softmax_op(logits) return softmax_op(logits)
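The docstrings above describe edge softmax as max-shift, exp, segment-sum and divide over the edges of each destination node, with gradient grad_score = sds - out * dst_sum(sds). The following dense NumPy reference for per-edge scalar scores is a minimal sketch matching those formulas (edge_softmax_ref / edge_softmax_grad_ref and seg_ids are hypothetical names, not part of this commit; seg_ids holds the destination node of each edge):

import numpy as np

def edge_softmax_ref(score, seg_ids, num_nodes):
    """Softmax over the edges that share a destination node (dense reference)."""
    out = np.empty_like(score)
    for v in range(num_nodes):
        mask = seg_ids == v
        s = np.exp(score[mask] - score[mask].max())  # subtract max for stability
        out[mask] = s / s.sum()
    return out

def edge_softmax_grad_ref(out, grad_out, seg_ids, num_nodes):
    """grad_score = sds - out * dst_sum(sds), as in the backward docstring."""
    sds = out * grad_out
    grad = np.empty_like(out)
    for v in range(num_nodes):
        mask = seg_ids == v
        grad[mask] = sds[mask] - out[mask] * sds[mask].sum()
    return grad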
...@@ -345,10 +387,10 @@ class SegmentReduce(mx.autograd.Function): ...@@ -345,10 +387,10 @@ class SegmentReduce(mx.autograd.Function):
return y return y
def backward(self, dy): def backward(self, dy):
arg, = self.saved_tensors (arg,) = self.saved_tensors
offsets = self.offsets offsets = self.offsets
m = offsets[-1].asscalar() m = offsets[-1].asscalar()
if self.op == 'sum': if self.op == "sum":
offsets_np = asnumpy(offsets[1:]) offsets_np = asnumpy(offsets[1:])
indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype) indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np)) np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
...@@ -374,7 +416,7 @@ class ScatterAdd(mx.autograd.Function): ...@@ -374,7 +416,7 @@ class ScatterAdd(mx.autograd.Function):
def forward(self, x): def forward(self, x):
y = _scatter_add(x, self.idx, self.m) y = _scatter_add(x, self.idx, self.m)
return y return y
def backward(self, dy): def backward(self, dy):
return dy[self.idx] return dy[self.idx]
...@@ -392,36 +434,66 @@ class CSRMM(mx.autograd.Function): ...@@ -392,36 +434,66 @@ class CSRMM(mx.autograd.Function):
self.num_vtypes = num_vtypes self.num_vtypes = num_vtypes
def forward(self, A_weights, B_weights): def forward(self, A_weights, B_weights):
gidxC, C_weights = _csrmm(
    self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes
)
(
    nrows,
    ncols,
    C_indptr,
    C_indices,
    C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC. # as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC self.backward_cache = gidxC
self.save_for_backward(A_weights, B_weights) self.save_for_backward(A_weights, B_weights)
nrows = nd.array([nrows], dtype='int64') nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype='int64') ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(
    self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful. # Only the last argument is meaningful.
gidxC = self.backward_cache gidxC = self.backward_cache
A_weights, B_weights = self.saved_tensors A_weights, B_weights = self.saved_tensors
dgidxA, dA_weights = _csrmm(
    gidxC,
    dC_weights,
    self.gidxB.reverse(),
    B_weights,
    self.gidxA.number_of_ntypes(),
)
dgidxB, dB_weights = _csrmm(
    self.gidxA.reverse(),
    A_weights,
    gidxC,
    dC_weights,
    self.gidxB.number_of_ntypes(),
)
dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA) dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA)
dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB) dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB)
return dA_weights, dB_weights return dA_weights, dB_weights
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes): def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
op = CSRMM(gidxA, gidxB, num_vtypes) op = CSRMM(gidxA, gidxB, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(
    A_weights, B_weights
)
gidxC = create_unitgraph_from_csr(
    num_vtypes,
    nrows.asscalar(),
    ncols.asscalar(),
    C_indptr,
    C_indices,
    C_eids,
    ["coo", "csr", "csc"],
)
return gidxC, C_weights return gidxC, C_weights
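For intuition, the forward/backward structure of CSRMM above (C = A @ B, dA = dC @ B^T masked to A's sparsity, dB = A^T @ dC masked to B's sparsity) can be mimicked with SciPy sparse matrices. This is only an analogue under assumed toy shapes, since _csrmm/_csrmask operate on DGL graph indices rather than SciPy objects:

import numpy as np
import scipy.sparse as sp

A = sp.random(4, 5, density=0.4, format="csr")
B = sp.random(5, 3, density=0.4, format="csr")
C = A @ B                                  # forward: sparse-sparse matmul

dC = C.copy()
dC.data[:] = 1.0                           # upstream gradient on C's nonzeros
dA_full = dC @ B.T                         # gradient w.r.t. A before masking
dB_full = A.T @ dC                         # gradient w.r.t. B before masking
# the csrmask step keeps only the entries on each input's sparsity pattern
dA = np.asarray(dA_full[A.nonzero()]).ravel()
dB = np.asarray(dB_full[B.nonzero()]).ravel()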
class CSRSum(mx.autograd.Function): class CSRSum(mx.autograd.Function):
def __init__(self, gidxs): def __init__(self, gidxs):
super().__init__() super().__init__()
...@@ -429,29 +501,44 @@ class CSRSum(mx.autograd.Function): ...@@ -429,29 +501,44 @@ class CSRSum(mx.autograd.Function):
def forward(self, *weights): def forward(self, *weights):
gidxC, C_weights = _csrsum(self.gidxs, weights) gidxC, C_weights = _csrsum(self.gidxs, weights)
(
    nrows,
    ncols,
    C_indptr,
    C_indices,
    C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC. # as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC self.backward_cache = gidxC
nrows = nd.array([nrows], dtype='int64') nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype='int64') ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(
    self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful. # Only the last argument is meaningful.
gidxC = self.backward_cache gidxC = self.backward_cache
return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs) return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs)
def csrsum(gidxs, weights): def csrsum(gidxs, weights):
op = CSRSum(gidxs) op = CSRSum(gidxs)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights) nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights)
num_vtypes = gidxs[0].number_of_ntypes() num_vtypes = gidxs[0].number_of_ntypes()
gidxC = create_unitgraph_from_csr(
    num_vtypes,
    nrows.asscalar(),
    ncols.asscalar(),
    C_indptr,
    C_indices,
    C_eids,
    ["coo", "csr", "csc"],
)
return gidxC, C_weights return gidxC, C_weights
class CSRMask(mx.autograd.Function): class CSRMask(mx.autograd.Function):
def __init__(self, gidxA, gidxB): def __init__(self, gidxA, gidxB):
super().__init__() super().__init__()
...@@ -464,6 +551,7 @@ class CSRMask(mx.autograd.Function): ...@@ -464,6 +551,7 @@ class CSRMask(mx.autograd.Function):
def backward(self, dB_weights): def backward(self, dB_weights):
return _csrmask(self.gidxB, dB_weights, self.gidxA) return _csrmask(self.gidxB, dB_weights, self.gidxA)
def csrmask(gidxA, A_weights, gidxB): def csrmask(gidxA, A_weights, gidxB):
op = CSRMask(gidxA, gidxB) op = CSRMask(gidxA, gidxB)
return op(A_weights) return op(A_weights)
"""Sparse optimizer is not supported for mxnet""" """Sparse optimizer is not supported for mxnet"""
from __future__ import absolute_import

import builtins
import numbers
import os
from distutils.version import LooseVersion

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np

from ... import ndarray as dglnd
from ..._deprecate import kernel as K
from ...function.base import TargetCode
...@@ -17,26 +18,31 @@ if LooseVersion(mx.__version__) < LooseVersion("1.6.0"): ...@@ -17,26 +18,31 @@ if LooseVersion(mx.__version__) < LooseVersion("1.6.0"):
# After MXNet 1.5, empty tensors aren't supported by default.
# After we turn on the numpy compatible flag, MXNet supports empty NDArray.
mx.set_np_shape(bool(os.environ.get("DGL_MXNET_SET_NP_SHAPE", True)))
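For illustration only (the flag and environment variable are the ones referenced in the line above), a script could pick this behaviour explicitly before using the backend; note that bool() of any non-empty string is True, so in practice the variable only disables the flag when set to an empty string:

import os

# hypothetical usage sketch: force-enable numpy-compatible (zero-size) shapes
os.environ["DGL_MXNET_SET_NP_SHAPE"] = "1"

import mxnet as mx

mx.set_np_shape(bool(os.environ.get("DGL_MXNET_SET_NP_SHAPE", True)))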
def data_type_dict(): def data_type_dict():
return {
    "float16": np.float16,
    "float32": np.float32,
    "float64": np.float64,
    "uint8": np.uint8,
    "int8": np.int8,
    "int16": np.int16,
    "int32": np.int32,
    "int64": np.int64,
    "bool": np.bool,
}  # mxnet does not support bool
def cpu(): def cpu():
return mx.cpu() return mx.cpu()
def tensor(data, dtype=None): def tensor(data, dtype=None):
if dtype == np.bool: if dtype == np.bool:
# mxnet doesn't support bool # mxnet doesn't support bool
dtype = np.int32 dtype = np.int32
if isinstance(data, nd.NDArray): if isinstance(data, nd.NDArray):
if dtype is None or data.dtype == dtype: if dtype is None or data.dtype == dtype:
return data return data
...@@ -51,9 +57,14 @@ def tensor(data, dtype=None): ...@@ -51,9 +57,14 @@ def tensor(data, dtype=None):
elif len(data) == 0: elif len(data) == 0:
dtype = np.int64 dtype = np.int64
else: else:
dtype = np.int64 if isinstance(data[0], numbers.Integral) else np.float32 dtype = (
np.int64
if isinstance(data[0], numbers.Integral)
else np.float32
)
return nd.array(data, dtype=dtype) return nd.array(data, dtype=dtype)
def as_scalar(data): def as_scalar(data):
if data.size != 1: if data.size != 1:
raise ValueError("The current array is not a scalar") raise ValueError("The current array is not a scalar")
...@@ -61,6 +72,7 @@ def as_scalar(data): ...@@ -61,6 +72,7 @@ def as_scalar(data):
data = data.expand_dims(axis=0) data = data.expand_dims(axis=0)
return data.asscalar() return data.asscalar()
def get_preferred_sparse_format(): def get_preferred_sparse_format():
"""Get the preferred sparse matrix format supported by the backend. """Get the preferred sparse matrix format supported by the backend.
...@@ -69,60 +81,79 @@ def get_preferred_sparse_format(): ...@@ -69,60 +81,79 @@ def get_preferred_sparse_format():
""" """
return "csr" return "csr"
def sparse_matrix(data, index, shape, force_format=False): def sparse_matrix(data, index, shape, force_format=False):
fmt = index[0] fmt = index[0]
if fmt == 'coo': if fmt == "coo":
if force_format: if force_format:
raise TypeError(
    "MXNet backend only supports CSR format,"
    " but COO format is forced."
)
coord = index[1] coord = index[1]
# generate convert idx # generate convert idx
# FIXME: cannot use int64 # FIXME: cannot use int64
tmp_data = nd.arange(
    len(coord[0]), dtype=data.dtype, ctx=coord[0].context
)
tmp_spmat = nd.sparse.csr_matrix(
    (tmp_data, (coord[0], coord[1])), tuple(shape), ctx=data.context
)
convert_idx = nd.cast(tmp_spmat.data, dtype="int64")
# shuffle the data
data = data[convert_idx]
spmat = nd.sparse.csr_matrix(
    (data, tmp_spmat.indices, tmp_spmat.indptr),
    tuple(shape),
    ctx=data.context,
)
return spmat, convert_idx return spmat, convert_idx
elif fmt == 'csr': elif fmt == "csr":
indices = index[1] indices = index[1]
indptr = index[2] indptr = index[2]
spmat = nd.sparse.csr_matrix(
    (data, indices, indptr), tuple(shape), ctx=data.context
)
# No conversion is required. # No conversion is required.
return spmat, None return spmat, None
else: else:
raise TypeError('Invalid format: %s.' % fmt) raise TypeError("Invalid format: %s." % fmt)
def sparse_matrix_indices(spmat): def sparse_matrix_indices(spmat):
return ('csr', spmat.indices, spmat.indptr) return ("csr", spmat.indices, spmat.indptr)
def is_tensor(obj): def is_tensor(obj):
return isinstance(obj, nd.NDArray) return isinstance(obj, nd.NDArray)
def shape(input): def shape(input):
# NOTE: the input cannot be a symbol # NOTE: the input cannot be a symbol
return input.shape return input.shape
def dtype(input): def dtype(input):
# NOTE: the input cannot be a symbol # NOTE: the input cannot be a symbol
return input.dtype return input.dtype
def ndim(input): def ndim(input):
return input.ndim return input.ndim
def context(input): def context(input):
return input.context return input.context
def device_type(ctx): def device_type(ctx):
return ctx.device_type return ctx.device_type
def device_id(ctx): def device_id(ctx):
return ctx.device_id return ctx.device_id
def to_backend_ctx(dglctx): def to_backend_ctx(dglctx):
dev_type = dglctx.device_type dev_type = dglctx.device_type
if dev_type == 1: if dev_type == 1:
...@@ -130,84 +161,110 @@ def to_backend_ctx(dglctx): ...@@ -130,84 +161,110 @@ def to_backend_ctx(dglctx):
elif dev_type == 2: elif dev_type == 2:
return mx.gpu(dglctx.device_id) return mx.gpu(dglctx.device_id)
else: else:
raise ValueError('Unsupported DGL device context:', dglctx) raise ValueError("Unsupported DGL device context:", dglctx)
def astype(input, ty): def astype(input, ty):
if ty == np.bool: if ty == np.bool:
ty = np.int32 ty = np.int32
return input.astype(ty) return input.astype(ty)
def asnumpy(input): def asnumpy(input):
return input.asnumpy() return input.asnumpy()
def copy_to(input, ctx, **kwargs): def copy_to(input, ctx, **kwargs):
return input.as_in_context(ctx) return input.as_in_context(ctx)
def is_pinned(input): def is_pinned(input):
return input.context == mx.cpu_pinned() return input.context == mx.cpu_pinned()
def sum(input, dim, keepdims=False): def sum(input, dim, keepdims=False):
if len(input) == 0: if len(input) == 0:
return nd.array([0.], dtype=input.dtype, ctx=input.context) return nd.array([0.0], dtype=input.dtype, ctx=input.context)
return nd.sum(input, axis=dim, keepdims=keepdims) return nd.sum(input, axis=dim, keepdims=keepdims)
def floor_div(in1, in2): def floor_div(in1, in2):
return in1 / in2 return in1 / in2
def reduce_sum(input): def reduce_sum(input):
return input.sum() return input.sum()
def cumsum(input, dim): def cumsum(input, dim):
return nd.cumsum(input, axis=dim) return nd.cumsum(input, axis=dim)
def mean(input, dim): def mean(input, dim):
return nd.mean(input, axis=dim) return nd.mean(input, axis=dim)
def reduce_mean(input): def reduce_mean(input):
return input.mean() return input.mean()
def max(input, dim): def max(input, dim):
return nd.max(input, axis=dim) return nd.max(input, axis=dim)
def reduce_max(input): def reduce_max(input):
return input.max() return input.max()
def min(input, dim): def min(input, dim):
return nd.min(input, axis=dim) return nd.min(input, axis=dim)
def reduce_min(input): def reduce_min(input):
return input.min() return input.min()
def topk(input, k, dim, descending=True): def topk(input, k, dim, descending=True):
return nd.topk(input, axis=dim, k=k, ret_typ='value', is_ascend=not descending) return nd.topk(
input, axis=dim, k=k, ret_typ="value", is_ascend=not descending
)
def argtopk(input, k, dim, descending=True): def argtopk(input, k, dim, descending=True):
idx = nd.argsort(input, dim, is_ascend=not descending) idx = nd.argsort(input, dim, is_ascend=not descending)
return nd.slice_axis(input, dim, 0, k) return nd.slice_axis(input, dim, 0, k)
def argsort(input, dim, descending): def argsort(input, dim, descending):
idx = nd.argsort(input, dim, is_ascend=not descending) idx = nd.argsort(input, dim, is_ascend=not descending)
idx = nd.cast(idx, dtype='int64') idx = nd.cast(idx, dtype="int64")
return idx return idx
def exp(input): def exp(input):
return nd.exp(input) return nd.exp(input)
def inverse(input): def inverse(input):
return nd.linalg_inverse(input) return nd.linalg_inverse(input)
def sqrt(input): def sqrt(input):
return nd.sqrt(input) return nd.sqrt(input)
def softmax(input, dim=-1): def softmax(input, dim=-1):
return nd.softmax(input, axis=dim) return nd.softmax(input, axis=dim)
def cat(seq, dim): def cat(seq, dim):
return nd.concat(*seq, dim=dim) return nd.concat(*seq, dim=dim)
def stack(seq, dim): def stack(seq, dim):
return nd.stack(*seq, axis=dim) return nd.stack(*seq, axis=dim)
def split(x, sizes_or_sections, dim): def split(x, sizes_or_sections, dim):
if isinstance(sizes_or_sections, list) and len(sizes_or_sections) == 1: if isinstance(sizes_or_sections, list) and len(sizes_or_sections) == 1:
assert len(x) == sizes_or_sections[0] assert len(x) == sizes_or_sections[0]
...@@ -217,13 +274,18 @@ def split(x, sizes_or_sections, dim): ...@@ -217,13 +274,18 @@ def split(x, sizes_or_sections, dim):
sizes_or_sections1 = tuple(np.cumsum(sizes_or_sections)[:-1]) sizes_or_sections1 = tuple(np.cumsum(sizes_or_sections)[:-1])
return nd.split_v2(x, sizes_or_sections1, axis=dim) return nd.split_v2(x, sizes_or_sections1, axis=dim)
def repeat(input, repeats, dim): def repeat(input, repeats, dim):
if isinstance(repeats, nd.NDArray): if isinstance(repeats, nd.NDArray):
return nd.array(np.repeat(input.asnumpy(), repeats.asnumpy(), axis=dim), return nd.array(
ctx=input.context, dtype=input.dtype) np.repeat(input.asnumpy(), repeats.asnumpy(), axis=dim),
ctx=input.context,
dtype=input.dtype,
)
else: else:
return nd.repeat(input, repeats, axis=dim) return nd.repeat(input, repeats, axis=dim)
def gather_row(data, row_index): def gather_row(data, row_index):
# MXNet workaround for empty row index # MXNet workaround for empty row index
if len(row_index) == 0: if len(row_index) == 0:
...@@ -235,7 +297,10 @@ def gather_row(data, row_index): ...@@ -235,7 +297,10 @@ def gather_row(data, row_index):
if isinstance(row_index, nd.NDArray): if isinstance(row_index, nd.NDArray):
return nd.take(data, row_index) return nd.take(data, row_index)
else: else:
return data[row_index,] return data[
row_index,
]
def slice_axis(data, axis, begin, end): def slice_axis(data, axis, begin, end):
dim = data.shape[axis] dim = data.shape[axis]
...@@ -245,49 +310,64 @@ def slice_axis(data, axis, begin, end): ...@@ -245,49 +310,64 @@ def slice_axis(data, axis, begin, end):
end += dim end += dim
return nd.slice_axis(data, axis, begin, end) return nd.slice_axis(data, axis, begin, end)
def take(data, indices, dim): def take(data, indices, dim):
return nd.take(data, indices, dim) return nd.take(data, indices, dim)
def narrow_row(data, start, stop): def narrow_row(data, start, stop):
return data[start:stop] return data[start:stop]
def index_add_inplace(data, row_idx, value): def index_add_inplace(data, row_idx, value):
raise NotImplementedError("MXNet doesn't support inplace index_add") raise NotImplementedError("MXNet doesn't support inplace index_add")
def scatter_row(data, row_index, value): def scatter_row(data, row_index, value):
return mx.nd.contrib.index_copy(data, row_index, value) return mx.nd.contrib.index_copy(data, row_index, value)
def scatter_row_inplace(data, row_index, value): def scatter_row_inplace(data, row_index, value):
data[row_index] = value data[row_index] = value
def squeeze(input, dim): def squeeze(input, dim):
return nd.squeeze(input, axis=dim) return nd.squeeze(input, axis=dim)
def unsqueeze(input, dim): def unsqueeze(input, dim):
return nd.expand_dims(input, axis=dim) return nd.expand_dims(input, axis=dim)
def reshape(input, shape): def reshape(input, shape):
# NOTE: the input cannot be a symbol # NOTE: the input cannot be a symbol
return nd.reshape(input ,shape) return nd.reshape(input, shape)
def swapaxes(input, axis1, axis2): def swapaxes(input, axis1, axis2):
return nd.swapaxes(input, axis1, axis2) return nd.swapaxes(input, axis1, axis2)
def zeros(shape, dtype, ctx): def zeros(shape, dtype, ctx):
return nd.zeros(shape, dtype=dtype, ctx=ctx) return nd.zeros(shape, dtype=dtype, ctx=ctx)
def zeros_like(input): def zeros_like(input):
return nd.zeros_like(input) return nd.zeros_like(input)
def ones(shape, dtype, ctx): def ones(shape, dtype, ctx):
return nd.ones(shape, dtype=dtype, ctx=ctx) return nd.ones(shape, dtype=dtype, ctx=ctx)
def uniform(shape, dtype, ctx, low, high): def uniform(shape, dtype, ctx, low, high):
return nd.random.uniform(low, high, ctx=ctx, dtype=dtype, shape=shape) return nd.random.uniform(low, high, ctx=ctx, dtype=dtype, shape=shape)
def randint(shape, dtype, ctx, low, high): def randint(shape, dtype, ctx, low, high):
return nd.random.randint(low, high, ctx=ctx, dtype=dtype, shape=shape) return nd.random.randint(low, high, ctx=ctx, dtype=dtype, shape=shape)
def pad_packed_tensor(input, lengths, value, l_min=None): def pad_packed_tensor(input, lengths, value, l_min=None):
old_shape = input.shape old_shape = input.shape
if isinstance(lengths, nd.NDArray): if isinstance(lengths, nd.NDArray):
...@@ -300,12 +380,17 @@ def pad_packed_tensor(input, lengths, value, l_min=None): ...@@ -300,12 +380,17 @@ def pad_packed_tensor(input, lengths, value, l_min=None):
batch_size = len(lengths) batch_size = len(lengths)
ctx = input.context ctx = input.context
dtype = input.dtype dtype = input.dtype
x = nd.full(
    (batch_size * max_len, *old_shape[1:]), value, ctx=ctx, dtype=dtype
)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = nd.array(index, ctx=ctx)
return scatter_row(x, index, input).reshape(
    batch_size, max_len, *old_shape[1:]
)
def pack_padded_tensor(input, lengths): def pack_padded_tensor(input, lengths):
batch_size, max_len = input.shape[:2] batch_size, max_len = input.shape[:2]
...@@ -316,46 +401,60 @@ def pack_padded_tensor(input, lengths): ...@@ -316,46 +401,60 @@ def pack_padded_tensor(input, lengths):
index = nd.array(index, ctx=ctx) index = nd.array(index, ctx=ctx)
return gather_row(input.reshape(batch_size * max_len, -1), index) return gather_row(input.reshape(batch_size * max_len, -1), index)
def boolean_mask(input, mask): def boolean_mask(input, mask):
return mx.contrib.nd.boolean_mask(input, mask) return mx.contrib.nd.boolean_mask(input, mask)
def equal(x, y): def equal(x, y):
return x == y return x == y
def allclose(x, y, rtol=1e-4, atol=1e-4): def allclose(x, y, rtol=1e-4, atol=1e-4):
return np.allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol) return np.allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol)
def logical_not(input): def logical_not(input):
return nd.logical_not(input) return nd.logical_not(input)
def logical_and(input1, input2): def logical_and(input1, input2):
return nd.logical_and(input1, input2) return nd.logical_and(input1, input2)
def clone(input): def clone(input):
return input.copy() return input.copy()
def clamp(data, min_val, max_val): def clamp(data, min_val, max_val):
return nd.clip(data, min_val, max_val) return nd.clip(data, min_val, max_val)
def replace_inf_with_zero(x): def replace_inf_with_zero(x):
return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x) return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)
def count_nonzero(input): def count_nonzero(input):
# TODO: fallback to numpy is unfortunate # TODO: fallback to numpy is unfortunate
tmp = input.asnumpy() tmp = input.asnumpy()
return np.count_nonzero(tmp) return np.count_nonzero(tmp)
def unique(input, return_inverse=False, return_counts=False): def unique(input, return_inverse=False, return_counts=False):
# TODO: fallback to numpy is unfortunate # TODO: fallback to numpy is unfortunate
tmp = input.asnumpy() tmp = input.asnumpy()
if return_inverse and return_counts: if return_inverse and return_counts:
tmp, inv, count = np.unique(tmp, return_inverse=True, return_counts=True) tmp, inv, count = np.unique(
tmp, return_inverse=True, return_counts=True
)
tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype) tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)
inv = nd.array(inv, ctx=input.context) inv = nd.array(inv, ctx=input.context)
count = nd.array(count, ctx=input.context) count = nd.array(count, ctx=input.context)
return tmp, inv, count return tmp, inv, count
elif return_inverse or return_counts: elif return_inverse or return_counts:
tmp, tmp2 = np.unique(tmp, return_inverse=return_inverse, return_counts=return_counts) tmp, tmp2 = np.unique(
tmp, return_inverse=return_inverse, return_counts=return_counts
)
tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype) tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)
tmp2 = nd.array(tmp2, ctx=input.context) tmp2 = nd.array(tmp2, ctx=input.context)
return tmp, tmp2 return tmp, tmp2
...@@ -363,9 +462,11 @@ def unique(input, return_inverse=False, return_counts=False): ...@@ -363,9 +462,11 @@ def unique(input, return_inverse=False, return_counts=False):
tmp = np.unique(tmp) tmp = np.unique(tmp)
return nd.array(tmp, ctx=input.context, dtype=input.dtype) return nd.array(tmp, ctx=input.context, dtype=input.dtype)
def full_1d(length, fill_value, dtype, ctx): def full_1d(length, fill_value, dtype, ctx):
return nd.full((length,), fill_value, dtype=dtype, ctx=ctx) return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)
def nonzero_1d(input): def nonzero_1d(input):
# TODO: fallback to numpy is unfortunate # TODO: fallback to numpy is unfortunate
tmp = input.asnumpy() tmp = input.asnumpy()
...@@ -373,50 +474,70 @@ def nonzero_1d(input): ...@@ -373,50 +474,70 @@ def nonzero_1d(input):
r = nd.array(tmp, ctx=input.context, dtype=tmp.dtype) r = nd.array(tmp, ctx=input.context, dtype=tmp.dtype)
return r return r
def sort_1d(input): def sort_1d(input):
# TODO: this isn't an ideal implementation. # TODO: this isn't an ideal implementation.
val = nd.sort(input, axis=None, is_ascend=True) val = nd.sort(input, axis=None, is_ascend=True)
idx = nd.argsort(input, is_ascend=True) idx = nd.argsort(input, is_ascend=True)
idx = nd.cast(idx, dtype='int64') idx = nd.cast(idx, dtype="int64")
return val, idx return val, idx
def arange(start, stop, dtype=np.int64, ctx=None): def arange(start, stop, dtype=np.int64, ctx=None):
if start >= stop: if start >= stop:
return nd.array([], dtype=dtype, ctx=ctx) return nd.array([], dtype=dtype, ctx=ctx)
else: else:
return nd.arange(start, stop, dtype=dtype, ctx=ctx) return nd.arange(start, stop, dtype=dtype, ctx=ctx)
def rand_shuffle(arr): def rand_shuffle(arr):
return mx.nd.random.shuffle(arr) return mx.nd.random.shuffle(arr)
def zerocopy_to_dlpack(arr): def zerocopy_to_dlpack(arr):
return arr.to_dlpack_for_read() return arr.to_dlpack_for_read()
def zerocopy_from_dlpack(dlpack_arr): def zerocopy_from_dlpack(dlpack_arr):
return nd.from_dlpack(dlpack_arr) return nd.from_dlpack(dlpack_arr)
def zerocopy_to_numpy(arr): def zerocopy_to_numpy(arr):
# NOTE: not zerocopy # NOTE: not zerocopy
return arr.asnumpy() return arr.asnumpy()
def zerocopy_from_numpy(np_data): def zerocopy_from_numpy(np_data):
np_data = np.asarray(np_data, order='C') np_data = np.asarray(np_data, order="C")
return mx.nd.from_numpy(np_data, zero_copy=True) return mx.nd.from_numpy(np_data, zero_copy=True)
def zerocopy_to_dgl_ndarray(arr): def zerocopy_to_dgl_ndarray(arr):
arr.to_dlpack_for_read() arr.to_dlpack_for_read()
return dglnd.from_dlpack(arr.to_dlpack_for_read()) return dglnd.from_dlpack(arr.to_dlpack_for_read())
def zerocopy_to_dgl_ndarray_for_write(arr): def zerocopy_to_dgl_ndarray_for_write(arr):
return dglnd.from_dlpack(arr.to_dlpack_for_write()) return dglnd.from_dlpack(arr.to_dlpack_for_write())
def zerocopy_from_dgl_ndarray(arr): def zerocopy_from_dgl_ndarray(arr):
return nd.from_dlpack(arr.to_dlpack()) return nd.from_dlpack(arr.to_dlpack())
class BinaryReduce(mx.autograd.Function): class BinaryReduce(mx.autograd.Function):
def __init__(
    self,
    reducer,
    binary_op,
    graph,
    lhs,
    rhs,
    out_size,
    lhs_map,
    rhs_map,
    out_map,
):
super(BinaryReduce, self).__init__() super(BinaryReduce, self).__init__()
self.reducer = reducer self.reducer = reducer
self.binary_op = binary_op self.binary_op = binary_op
...@@ -431,23 +552,37 @@ class BinaryReduce(mx.autograd.Function): ...@@ -431,23 +552,37 @@ class BinaryReduce(mx.autograd.Function):
def forward(self, lhs_data, rhs_data): def forward(self, lhs_data, rhs_data):
lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data) lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data) rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
feat_shape = K.infer_binary_feature_shape(self.binary_op, lhs_data_nd, rhs_data_nd) feat_shape = K.infer_binary_feature_shape(
self.binary_op, lhs_data_nd, rhs_data_nd
)
out_shape = feat_shape out_shape = feat_shape
if self.binary_op == 'dot': if self.binary_op == "dot":
out_shape = feat_shape[:-1] out_shape = feat_shape[:-1]
out_data = nd.empty(
    (self.out_size,) + out_shape,
    ctx=lhs_data.context,
    dtype=lhs_data.dtype,
)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.binary_op_reduce(
    self.reducer if self.reducer != "mean" else "sum",
    self.binary_op,
    self.graph,
    self.lhs,
    self.rhs,
    lhs_data_nd,
    rhs_data_nd,
    out_data_nd,
    self.lhs_map[0],
    self.rhs_map[0],
    self.out_map[0],
)
# normalize if mean reducer # normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have better solution in the future. # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
if self.reducer == 'mean': if self.reducer == "mean":
degs = nd.empty((out_data.shape[0],), degs = nd.empty(
ctx=out_data.context, dtype=out_data.dtype) (out_data.shape[0],), ctx=out_data.context, dtype=out_data.dtype
)
degs_nd = zerocopy_to_dgl_ndarray(degs) degs_nd = zerocopy_to_dgl_ndarray(degs)
if self.lhs != TargetCode.DST: if self.lhs != TargetCode.DST:
target = self.lhs target = self.lhs
...@@ -460,49 +595,100 @@ class BinaryReduce(mx.autograd.Function): ...@@ -460,49 +595,100 @@ class BinaryReduce(mx.autograd.Function):
in_ones = nd.ones((n,), ctx=lhs_data.context, dtype=lhs_data.dtype) in_ones = nd.ones((n,), ctx=lhs_data.context, dtype=lhs_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones) in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce( K.copy_reduce(
    "sum",
    self.graph,
    target,
    in_ones_nd,
    degs_nd,
    in_map,
    self.out_map[0],
)
# reshape
degs = degs.reshape(
    (out_data.shape[0],) + (1,) * (out_data.ndim - 1)
).clip(1, float("inf"))
out_data = out_data / degs out_data = out_data / degs
else: else:
degs = None degs = None
self.save_for_backward(lhs_data_nd, rhs_data_nd, out_data_nd, self.save_for_backward(
feat_shape, degs) lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs
)
return out_data return out_data
def backward(self, grad_out): def backward(self, grad_out):
(
    lhs_data_nd,
    rhs_data_nd,
    out_data_nd,
    feat_shape,
    degs,
) = self.saved_tensors
if self.reducer == "mean":
grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
grad_lhs = nd.empty(
    (lhs_data_nd.shape[0],) + feat_shape,
    ctx=grad_out.context,
    dtype=grad_out.dtype,
)
K.backward_lhs_binary_op_reduce(
    self.reducer if self.reducer != "mean" else "sum",
    self.binary_op,
    self.graph,
    self.lhs,
    self.rhs,
    lhs_data_nd,
    rhs_data_nd,
    out_data_nd,
    grad_out_nd,
    zerocopy_to_dgl_ndarray_for_write(grad_lhs),
    self.lhs_map[1],
    self.rhs_map[1],
    self.out_map[1],
)
grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape) grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)
grad_rhs = nd.empty(
    (rhs_data_nd.shape[0],) + feat_shape,
    ctx=grad_out.context,
    dtype=grad_out.dtype,
)
K.backward_rhs_binary_op_reduce(
    self.reducer if self.reducer != "mean" else "sum",
    self.binary_op,
    self.graph,
    self.lhs,
    self.rhs,
    lhs_data_nd,
    rhs_data_nd,
    out_data_nd,
    grad_out_nd,
    zerocopy_to_dgl_ndarray_for_write(grad_rhs),
    self.lhs_map[1],
    self.rhs_map[1],
    self.out_map[1],
)
grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape) grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
# clear saved tensors explicitly # clear saved tensors explicitly
self.saved_tensors = None self.saved_tensors = None
return grad_lhs, grad_rhs return grad_lhs, grad_rhs
def binary_reduce(
    reducer,
    binary_op,
    graph,
    lhs,
    rhs,
    lhs_data,
    rhs_data,
    out_size,
    lhs_map=(None, None),
    rhs_map=(None, None),
    out_map=(None, None),
):
func = BinaryReduce(
    reducer, binary_op, graph, lhs, rhs, out_size, lhs_map, rhs_map, out_map
)
return func(lhs_data, rhs_data) return func(lhs_data, rhs_data)
...@@ -518,28 +704,46 @@ class CopyReduce(mx.autograd.Function): ...@@ -518,28 +704,46 @@ class CopyReduce(mx.autograd.Function):
def forward(self, in_data): def forward(self, in_data):
feat_shape = in_data.shape[1:] feat_shape = in_data.shape[1:]
out_data = nd.empty((self.out_size,) + feat_shape, out_data = nd.empty(
ctx=in_data.context, dtype=in_data.dtype) (self.out_size,) + feat_shape,
ctx=in_data.context,
dtype=in_data.dtype,
)
in_data_nd = zerocopy_to_dgl_ndarray(in_data) in_data_nd = zerocopy_to_dgl_ndarray(in_data)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data) out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.copy_reduce( K.copy_reduce(
self.reducer if self.reducer != 'mean' else 'sum', self.reducer if self.reducer != "mean" else "sum",
self.graph, self.target, in_data_nd, out_data_nd, self.graph,
self.in_map[0], self.out_map[0]) self.target,
in_data_nd,
out_data_nd,
self.in_map[0],
self.out_map[0],
)
# normalize if mean reducer # normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have better solution in the future. # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
if self.reducer == 'mean': if self.reducer == "mean":
in_ones = nd.ones((in_data.shape[0],), in_ones = nd.ones(
ctx=in_data.context, dtype=in_data.dtype) (in_data.shape[0],), ctx=in_data.context, dtype=in_data.dtype
degs = nd.empty((out_data.shape[0],), )
ctx=out_data.context, dtype=out_data.dtype) degs = nd.empty(
(out_data.shape[0],), ctx=out_data.context, dtype=out_data.dtype
)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones) in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
degs_nd = zerocopy_to_dgl_ndarray(degs) degs_nd = zerocopy_to_dgl_ndarray(degs)
K.copy_reduce( K.copy_reduce(
'sum', self.graph, self.target, in_ones_nd, degs_nd, "sum",
self.in_map[0], self.out_map[0]) self.graph,
self.target,
in_ones_nd,
degs_nd,
self.in_map[0],
self.out_map[0],
)
# reshape # reshape
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf')) degs = degs.reshape(
(out_data.shape[0],) + (1,) * (out_data.ndim - 1)
).clip(1, float("inf"))
out_data = out_data / degs out_data = out_data / degs
else: else:
degs = None degs = None
...@@ -548,23 +752,37 @@ class CopyReduce(mx.autograd.Function): ...@@ -548,23 +752,37 @@ class CopyReduce(mx.autograd.Function):
def backward(self, grad_out): def backward(self, grad_out):
in_data_nd, out_data_nd, degs = self.saved_tensors in_data_nd, out_data_nd, degs = self.saved_tensors
grad_in = nd.empty(in_data_nd.shape, ctx=grad_out.context, grad_in = nd.empty(
dtype=grad_out.dtype) in_data_nd.shape, ctx=grad_out.context, dtype=grad_out.dtype
if self.reducer == 'mean': )
if self.reducer == "mean":
grad_out = grad_out / degs grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out) grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
K.backward_copy_reduce( K.backward_copy_reduce(
self.reducer if self.reducer != 'mean' else 'sum', self.reducer if self.reducer != "mean" else "sum",
self.graph, self.target, in_data_nd, out_data_nd, self.graph,
grad_out_nd, zerocopy_to_dgl_ndarray_for_write(grad_in), self.target,
self.in_map[1], self.out_map[1]) in_data_nd,
out_data_nd,
grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_in),
self.in_map[1],
self.out_map[1],
)
# clear saved tensors explicitly # clear saved tensors explicitly
self.saved_tensors = None self.saved_tensors = None
return grad_in return grad_in
def copy_reduce(
    reducer,
    graph,
    target,
    in_data,
    out_size,
    in_map=(None, None),
    out_map=(None, None),
):
func = CopyReduce(reducer, graph, target, out_size, in_map, out_map) func = CopyReduce(reducer, graph, target, out_size, in_map, out_map)
return func(in_data) return func(in_data)
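The "normalize if mean reducer" hack used in both BinaryReduce and CopyReduce above boils down to a sum reduction followed by division with the clipped destination degree. A minimal NumPy sketch of that idea (mean_reduce_ref is a hypothetical reference, not part of this commit):

import numpy as np

def mean_reduce_ref(edge_feat, dst, num_nodes):
    """Sum incoming edge features per destination node, then divide by in-degree."""
    out = np.zeros((num_nodes,) + edge_feat.shape[1:], dtype=edge_feat.dtype)
    np.add.at(out, dst, edge_feat)                      # 'sum' reduction
    degs = np.bincount(dst, minlength=num_nodes).astype(edge_feat.dtype)
    degs = degs.reshape((num_nodes,) + (1,) * (edge_feat.ndim - 1)).clip(1, None)
    return out / degs                                   # degree-0 nodes stay zero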
...@@ -600,6 +818,7 @@ def _reduce_grad(grad, shape): ...@@ -600,6 +818,7 @@ def _reduce_grad(grad, shape):
grad = grad.sum(axis=tuple(reduce_idx), keepdims=True) grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)
return grad.reshape(shape) return grad.reshape(shape)
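_reduce_grad above undoes broadcasting: feature axes of the gradient that were expanded from size 1 are summed back down before reshaping to the original input shape. A small NumPy sketch of the same idea (reduce_grad_ref is a hypothetical reference, not part of this commit):

import numpy as np

def reduce_grad_ref(grad, in_shape):
    """Sum `grad` over axes that were broadcast so it matches `in_shape`."""
    grad_shape = grad.shape[1:]          # feature dims of the gradient
    feat_shape = in_shape[1:]            # feature dims of the original input
    # pad the input feature shape with leading 1s so the ranks match
    feat_shape = (1,) * (len(grad_shape) - len(feat_shape)) + feat_shape
    axes = tuple(
        i + 1 for i, (g, f) in enumerate(zip(grad_shape, feat_shape)) if g != f
    )
    if axes:
        grad = grad.sum(axis=axes, keepdims=True)
    return grad.reshape(in_shape)

# a (4, 3) gradient produced by broadcasting a (4, 1) input reduces back to (4, 1)
g = np.ones((4, 3))
assert reduce_grad_ref(g, (4, 1)).shape == (4, 1)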
def sync(): def sync():
"""Synchronize computation. """Synchronize computation.
...@@ -609,24 +828,31 @@ def sync(): ...@@ -609,24 +828,31 @@ def sync():
""" """
mx.nd.waitall() mx.nd.waitall()
def attach_grad(tensor): def attach_grad(tensor):
tensor.attach_grad() tensor.attach_grad()
return tensor return tensor
def backward(x, head_gradient=None): def backward(x, head_gradient=None):
x.backward(head_gradient) x.backward(head_gradient)
def grad(x): def grad(x):
return x.grad return x.grad
def is_no_grad(x): def is_no_grad(x):
return (x != 0).sum() == 0 return (x != 0).sum() == 0
def is_recording(): def is_recording():
return mx.autograd.is_recording() return mx.autograd.is_recording()
record_grad = mx.autograd.record record_grad = mx.autograd.record
class no_grad(object): class no_grad(object):
def __init__(self): def __init__(self):
pass pass
......
from .sparse import *
from .tensor import *
import torch as th
from torch.cuda.amp import custom_bwd, custom_fwd

from ...base import ALL, is_all
from ...heterograph_index import create_unitgraph_from_csr
from ...sparse import (
    _bwd_segment_cmp,
    _csrmask,
    _csrmm,
    _csrsum,
    _edge_softmax_backward,
    _edge_softmax_forward,
    _gather_mm,
    _gather_mm_scatter,
    _gsddmm,
    _gsddmm_hetero,
    _gspmm,
    _gspmm_hetero,
    _scatter_add,
    _segment_mm,
    _segment_mm_backward_B,
    _segment_reduce,
    _update_grad_minmax_hetero,
)

__all__ = [
    "gspmm",
    "gsddmm",
    "gspmm_hetero",
    "gsddmm_hetero",
    "edge_softmax",
    "edge_softmax_hetero",
    "segment_reduce",
    "scatter_add",
    "csrmm",
    "csrsum",
    "csrmask",
    "gather_mm",
    "segment_mm",
]
def _reduce_grad(grad, shape): def _reduce_grad(grad, shape):
...@@ -38,7 +65,9 @@ def _reduce_grad(grad, shape): ...@@ -38,7 +65,9 @@ def _reduce_grad(grad, shape):
num_to_squeeze = len(grad_shape) - len(in_shape) num_to_squeeze = len(grad_shape) - len(in_shape)
# pad inshape # pad inshape
in_shape = (1,) * num_to_squeeze + in_shape in_shape = (1,) * num_to_squeeze + in_shape
reduce_idx = th.nonzero(th.tensor(grad_shape) - th.tensor(in_shape), as_tuple=False) reduce_idx = th.nonzero(
th.tensor(grad_shape) - th.tensor(in_shape), as_tuple=False
)
reduce_idx += 1 # skip batch dim reduce_idx += 1 # skip batch dim
if len(reduce_idx) > 0: if len(reduce_idx) > 0:
grad = grad.sum(dim=tuple(reduce_idx), keepdim=True) grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)
...@@ -62,23 +91,23 @@ def _expand(x, shape): ...@@ -62,23 +91,23 @@ def _expand(x, shape):
def spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y): def spmm_cache_X(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache X in SpMM forward stage.""" """Rules to identify whether to cache X in SpMM forward stage."""
if binary_op != 'copy_lhs' and req_grad_Y: if binary_op != "copy_lhs" and req_grad_Y:
if reduce_op == 'sum': if reduce_op == "sum":
return True return True
else: else:
if binary_op == 'mul': if binary_op == "mul":
return True return True
return False return False
def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y): def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache Y in SpMM forward stage.""" """Rules to identify whether to cache Y in SpMM forward stage."""
if binary_op != 'copy_rhs' and req_grad_X: if binary_op != "copy_rhs" and req_grad_X:
if reduce_op == 'sum': if reduce_op == "sum":
if binary_op in ['mul', 'add']: if binary_op in ["mul", "add"]:
return True return True
else: else:
if binary_op == 'mul': if binary_op == "mul":
return True return True
return False return False
...@@ -86,7 +115,7 @@ def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y): ...@@ -86,7 +115,7 @@ def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y): def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argX in SpMM forward stage.""" """Rules to identify whether to cache argX in SpMM forward stage."""
if req_grad_X or req_grad_Y: if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']: if reduce_op in ["min", "max"]:
return True return True
return False return False
...@@ -94,7 +123,7 @@ def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y): ...@@ -94,7 +123,7 @@ def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y): def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argY in SpMM forward stage.""" """Rules to identify whether to cache argY in SpMM forward stage."""
if req_grad_X or req_grad_Y: if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']: if reduce_op in ["min", "max"]:
return True return True
return False return False
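As a worked illustration of the four caching rules above (the return values follow directly from the code; the gspmm call itself is only an assumed example):

# Suppose Z = gspmm(g, "mul", "max", X, Y) with X and Y both requiring gradients:
#   spmm_cache_X("mul", "max", True, True)     -> True   # X is needed to build dY
#   spmm_cache_Y("mul", "max", True, True)     -> True   # Y is needed to build dX
#   spmm_cache_argX("mul", "max", True, True)  -> True   # argmax indices over X are saved
#   spmm_cache_argY("mul", "max", True, True)  -> True   # argmax indices over Y are saved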
...@@ -109,7 +138,16 @@ class GSpMM(th.autograd.Function): ...@@ -109,7 +138,16 @@ class GSpMM(th.autograd.Function):
Y_shape = Y.shape if Y is not None else None Y_shape = Y.shape if Y is not None else None
dtype = X.dtype if X is not None else Y.dtype dtype = X.dtype if X is not None else Y.dtype
device = X.device if X is not None else Y.device device = X.device if X is not None else Y.device
ctx.backward_cache = (
    gidx,
    op,
    reduce_op,
    X_shape,
    Y_shape,
    dtype,
    device,
    reduce_last,
)
req_grad_X = X.requires_grad if X is not None else False req_grad_X = X.requires_grad if X is not None else False
req_grad_Y = Y.requires_grad if Y is not None else False req_grad_Y = Y.requires_grad if Y is not None else False
if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y): if not spmm_cache_X(op, reduce_op, req_grad_X, req_grad_Y):
...@@ -126,45 +164,54 @@ class GSpMM(th.autograd.Function): ...@@ -126,45 +164,54 @@ class GSpMM(th.autograd.Function):
@staticmethod @staticmethod
@custom_bwd @custom_bwd
def backward(ctx, dZ): def backward(ctx, dZ):
(
    gidx,
    op,
    reduce_op,
    X_shape,
    Y_shape,
    dtype,
    device,
    reduce_last,
) = ctx.backward_cache
X, Y, argX, argY = ctx.saved_tensors X, Y, argX, argY = ctx.saved_tensors
if op != 'copy_rhs' and ctx.needs_input_grad[3]: if op != "copy_rhs" and ctx.needs_input_grad[3]:
g_rev = gidx.reverse() g_rev = gidx.reverse()
if reduce_op == 'sum': if reduce_op == "sum":
if op == 'mul': if op == "mul":
dX = gspmm(g_rev, 'mul', 'sum', dZ, Y) dX = gspmm(g_rev, "mul", "sum", dZ, Y)
elif op == 'add': elif op == "add":
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y) dX = gspmm(g_rev, "copy_lhs", "sum", dZ, Y)
elif op == 'copy_lhs': elif op == "copy_lhs":
dX = gspmm(g_rev, 'copy_lhs', 'sum', dZ, None) dX = gspmm(g_rev, "copy_lhs", "sum", dZ, None)
else: # max/min else: # max/min
dX = th.zeros(
    (X_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
)
if op == "mul":
grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ
dX.scatter_add_(0, argX.long(), grad) dX.scatter_add_(0, argX.long(), grad)
elif op in ['add', 'copy_lhs']: elif op in ["add", "copy_lhs"]:
dX.scatter_add_(0, argX.long(), dZ) dX.scatter_add_(0, argX.long(), dZ)
dX = _reduce_grad(dX, X_shape) dX = _reduce_grad(dX, X_shape)
else:  # X has no gradient
dX = None dX = None
if op != 'copy_lhs' and ctx.needs_input_grad[4]: if op != "copy_lhs" and ctx.needs_input_grad[4]:
if reduce_op == 'sum': if reduce_op == "sum":
if op == 'mul' and reduce_last: if op == "mul" and reduce_last:
dY = gsddmm(gidx, 'dot', X, dZ) dY = gsddmm(gidx, "dot", X, dZ)
elif op == 'mul': elif op == "mul":
dY = gsddmm(gidx, 'mul', X, dZ) dY = gsddmm(gidx, "mul", X, dZ)
elif op in ['add', 'copy_rhs']: elif op in ["add", "copy_rhs"]:
dY = gsddmm(gidx, 'copy_rhs', X, dZ) dY = gsddmm(gidx, "copy_rhs", X, dZ)
else: # max/min else: # max/min
dY = th.zeros(
    (Y_shape[0],) + dZ.shape[1:], dtype=dtype, device=device
)
if op == "mul":
grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ
dY.scatter_add_(0, argY.long(), grad) dY.scatter_add_(0, argY.long(), grad)
elif op in ['add', 'copy_rhs']: elif op in ["add", "copy_rhs"]:
dY.scatter_add_(0, argY.long(), dZ) dY.scatter_add_(0, argY.long(), dZ)
dY = _reduce_grad(dY, Y_shape) dY = _reduce_grad(dY, Y_shape)
else: # Y has no gradient else: # Y has no gradient
...@@ -175,94 +222,178 @@ class GSpMM(th.autograd.Function): ...@@ -175,94 +222,178 @@ class GSpMM(th.autograd.Function):
class GSpMM_hetero(th.autograd.Function): class GSpMM_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16) @custom_fwd(cast_inputs=th.float16)
def forward(
    ctx, gidx, op, reduce_op, X_len, *feats
):  # feats = lhs_data + rhs_data
out, (argX, argY, argX_ntype, argY_etype) = _gspmm_hetero(
    gidx, op, reduce_op, X_len, feats
)
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
# TODO (Israt): check target to decide src_id/dst_id? # TODO (Israt): check target to decide src_id/dst_id?
src_id, dst_id = gidx.metagraph.find_edge(0) src_id, dst_id = gidx.metagraph.find_edge(0)
reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id]) reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
X_shape = tuple([X[i].shape if X[i] is not None else None X_shape = tuple(
for i in range(X_len)]) [X[i].shape if X[i] is not None else None for i in range(X_len)]
Y_shape = tuple([Y[i].shape if Y[i] is not None else None )
for i in range(len(Y))]) Y_shape = tuple(
[Y[i].shape if Y[i] is not None else None for i in range(len(Y))]
)
dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype
device = X[src_id].device if X[src_id] is not None else Y[dst_id].device device = X[src_id].device if X[src_id] is not None else Y[dst_id].device
ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len ctx.backward_cache = (
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False gidx,
for i in range(X_len)]) op,
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False reduce_op,
for i in range(len(Y))]) X_shape,
Y_shape,
dtype,
device,
reduce_last,
X_len,
)
req_grad_X = tuple(
[
X[i].requires_grad if X[i] is not None else False
for i in range(X_len)
]
)
req_grad_Y = tuple(
[
Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))
]
)
# checking the first relation to decide for all the relations # checking the first relation to decide for all the relations
if not spmm_cache_argX(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]): if not spmm_cache_argX(
op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]
):
argX = tuple([None] * len(X)) argX = tuple([None] * len(X))
if not spmm_cache_argY(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]): if not spmm_cache_argY(
op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]
):
argY = tuple([None] * len(X)) argY = tuple([None] * len(X))
ctx.save_for_backward(*feats, *argX, *argX_ntype, *argY, *argY_etype ) ctx.save_for_backward(*feats, *argX, *argX_ntype, *argY, *argY_etype)
return out return out
@staticmethod @staticmethod
@custom_bwd @custom_bwd
def backward(ctx, *dZ): def backward(ctx, *dZ):
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache (
gidx,
op,
reduce_op,
X_shape,
Y_shape,
dtype,
device,
reduce_last,
X_len,
) = ctx.backward_cache
num_ntypes = gidx.number_of_ntypes() num_ntypes = gidx.number_of_ntypes()
feats = ctx.saved_tensors[:-(4 * num_ntypes)] feats = ctx.saved_tensors[: -(4 * num_ntypes)]
argX = ctx.saved_tensors[-(4 * num_ntypes):-(3 * num_ntypes)] argX = ctx.saved_tensors[-(4 * num_ntypes) : -(3 * num_ntypes)]
argX_ntype = ctx.saved_tensors[-(3 * num_ntypes):-(2 * num_ntypes)] argX_ntype = ctx.saved_tensors[-(3 * num_ntypes) : -(2 * num_ntypes)]
argY = ctx.saved_tensors[-(2 * num_ntypes):- num_ntypes] argY = ctx.saved_tensors[-(2 * num_ntypes) : -num_ntypes]
argY_etype = ctx.saved_tensors[-num_ntypes:] argY_etype = ctx.saved_tensors[-num_ntypes:]
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]): if op != "copy_rhs" and any([x is not None for x in X]):
g_rev = gidx.reverse() g_rev = gidx.reverse()
if reduce_op == 'sum': if reduce_op == "sum":
if op == 'mul': if op == "mul":
dX = gspmm_hetero(g_rev, 'mul', 'sum', len(X), *tuple(dZ + Y)) dX = gspmm_hetero(
elif op == 'add': g_rev, "mul", "sum", len(X), *tuple(dZ + Y)
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + Y)) )
elif op == 'copy_lhs': elif op == "add":
dX = gspmm_hetero(
g_rev, "copy_lhs", "sum", len(X), *tuple(dZ + Y)
)
elif op == "copy_lhs":
tpl_None = tuple([None] * len(Y)) tpl_None = tuple([None] * len(Y))
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + tpl_None)) dX = gspmm_hetero(
g_rev, "copy_lhs", "sum", len(X), *tuple(dZ + tpl_None)
)
else: # max/min else: # max/min
# Assuming that the features are of the same dimension (enforced by the forward function) # Assuming that the features are of the same dimension (enforced by the forward function)
src_id, dst_id = gidx.metagraph.find_edge(0) src_id, dst_id = gidx.metagraph.find_edge(0)
dX = tuple([th.zeros((X_shape[i][0],) + dZ[dst_id].shape[1:], dtype=dtype, device=device) dX = tuple(
if X[i] is not None else None for i in range(len(X))]) [
if op == 'mul': th.zeros(
grad = _expand(Y, dZ.shape[1:]).gather( (X_shape[i][0],) + dZ[dst_id].shape[1:],
0, argY.long()) * dZ dtype=dtype,
device=device,
)
if X[i] is not None
else None
for i in range(len(X))
]
)
if op == "mul":
grad = _expand(Y, dZ.shape[1:]).gather(0, argY.long()) * dZ
dX.scatter_add_(0, argX.long(), grad) dX.scatter_add_(0, argX.long(), grad)
elif op in ['add', 'copy_lhs']: elif op in ["add", "copy_lhs"]:
dX = _update_grad_minmax_hetero(g_rev, op, dZ, argX, argX_ntype, dX) dX = _update_grad_minmax_hetero(
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None g_rev, op, dZ, argX, argX_ntype, dX
for i in range(len(X))]) )
dX = tuple(
[
_reduce_grad(dX[i], X_shape[i])
if X[i] is not None
else None
for i in range(len(X))
]
)
else: # X has no gradient else: # X has no gradient
dX = tuple([None] * len(X)) dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]): if op != "copy_lhs" and any([y is not None for y in Y]):
# TODO(Israt): implement other combinations of reduce functions # TODO(Israt): implement other combinations of reduce functions
if reduce_op == 'sum': if reduce_op == "sum":
tpl_dZ = tuple([dZ[i] if dZ[i] is not None else None tpl_dZ = tuple(
for i in range(len(dZ))]) [
dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))
]
)
tpl_X_dZ = tuple(X + tpl_dZ) tpl_X_dZ = tuple(X + tpl_dZ)
if op == 'mul' and reduce_last: if op == "mul" and reduce_last:
dY = gsddmm_hetero(gidx, 'dot', X_len, 'u', 'v', *tpl_X_dZ) dY = gsddmm_hetero(gidx, "dot", X_len, "u", "v", *tpl_X_dZ)
elif op == 'mul': elif op == "mul":
dY = gsddmm_hetero(gidx, 'mul', X_len, 'u', 'v', *tpl_X_dZ) dY = gsddmm_hetero(gidx, "mul", X_len, "u", "v", *tpl_X_dZ)
elif op in ['add', 'copy_rhs']: elif op in ["add", "copy_rhs"]:
dY = gsddmm_hetero(gidx, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ) dY = gsddmm_hetero(
gidx, "copy_rhs", X_len, "u", "v", *tpl_X_dZ
)
else: # max/min else: # max/min
src_id, dst_id = gidx.metagraph.find_edge(0) src_id, dst_id = gidx.metagraph.find_edge(0)
dY = tuple([th.zeros((Y_shape[i][0],) + dZ[dst_id].shape[1:], dtype=dtype, device=device) dY = tuple(
if Y[i] is not None else None for i in range(len(Y))]) [
if op == 'mul': th.zeros(
grad = _expand(X, dZ.shape[1:]).gather( (Y_shape[i][0],) + dZ[dst_id].shape[1:],
0, argX.long()) * dZ dtype=dtype,
device=device,
)
if Y[i] is not None
else None
for i in range(len(Y))
]
)
if op == "mul":
grad = _expand(X, dZ.shape[1:]).gather(0, argX.long()) * dZ
dY.scatter_add_(0, argY.long(), grad) dY.scatter_add_(0, argY.long(), grad)
elif op in ['add', 'copy_rhs']: elif op in ["add", "copy_rhs"]:
dY = _update_grad_minmax_hetero(gidx.reverse(), op, dZ, argY, argY_etype, dY) dY = _update_grad_minmax_hetero(
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if dY[i] is not None else None gidx.reverse(), op, dZ, argY, argY_etype, dY
for i in range(len(dY))]) )
dY = tuple(
[
_reduce_grad(dY[i], Y_shape[i])
if dY[i] is not None
else None
for i in range(len(dY))
]
)
else: # Y has no gradient else: # Y has no gradient
dY = tuple([None] * len(Y)) dY = tuple([None] * len(Y))
return (None, None, None, None) + dX + dY return (None, None, None, None) + dX + dY
@@ -270,14 +401,14 @@ class GSpMM_hetero(th.autograd.Function):
def sddmm_cache_X(op, req_grad_X, req_grad_Y):
    """Rules to identify whether to cache X in SDDMM forward stage."""
    if op in ["mul", "dot"] and req_grad_Y:
        return True
    return False


def sddmm_cache_Y(op, req_grad_X, req_grad_Y):
    """Rules to identify whether to cache Y in SDDMM forward stage."""
    if op in ["mul", "dot"] and req_grad_X:
        return True
    return False
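These two rules are just the product rule: for `mul`/`dot`, the gradient with respect to one operand needs the *other* operand, so a tensor is cached exactly when the opposite side requires a gradient. A minimal, DGL-free illustration:

```python
import torch as th

x = th.randn(3, requires_grad=True)
y = th.randn(3)  # y needs no gradient, yet it must be kept for x's backward
z = (x * y).sum()
z.backward()
assert th.allclose(x.grad, y)  # d(x*y)/dx = y, so Y is cached iff X needs a gradient
```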
@@ -304,43 +435,43 @@ class GSDDMM(th.autograd.Function):
    def backward(ctx, dZ):
        gidx, op, lhs_target, rhs_target, X_shape, Y_shape = ctx.backward_cache
        X, Y = ctx.saved_tensors
        if op != "copy_rhs" and ctx.needs_input_grad[2]:
            if lhs_target in ["u", "v"]:
                _gidx = gidx if lhs_target == "v" else gidx.reverse()
                if op in ["add", "copy_lhs"]:
                    dX = gspmm(_gidx, "copy_rhs", "sum", None, dZ)
                else:  # mul, dot
                    if rhs_target == lhs_target:
                        dX = gspmm(_gidx, "copy_rhs", "sum", None, dZ) * Y
                    elif rhs_target == "e":
                        dX = gspmm(_gidx, "copy_rhs", "sum", None, dZ * Y)
                    else:  # rhs_target = !lhs_target
                        dX = gspmm(_gidx, "mul", "sum", Y, dZ)
            else:  # lhs_target == 'e'
                if op in ["add", "copy_lhs"]:
                    dX = dZ
                else:  # mul, dot
                    dX = gsddmm(gidx, "mul", dZ, Y, "e", rhs_target)
            dX = _reduce_grad(dX, X_shape)
        else:
            dX = None
        if op != "copy_lhs" and ctx.needs_input_grad[3]:
            if rhs_target in ["u", "v"]:
                _gidx = gidx if rhs_target == "v" else gidx.reverse()
                if op in ["add", "copy_rhs"]:
                    dY = gspmm(_gidx, "copy_rhs", "sum", None, dZ)
                else:  # mul, dot
                    if lhs_target == rhs_target:
                        dY = gspmm(_gidx, "copy_rhs", "sum", None, dZ) * X
                    elif lhs_target == "e":
                        dY = gspmm(_gidx, "copy_rhs", "sum", None, dZ * X)
                    else:  # rhs_target = !lhs_target
                        dY = gspmm(_gidx, "mul", "sum", X, dZ)
            else:
                if op in ["add", "copy_rhs"]:
                    dY = dZ
                else:  # mul, dot
                    dY = gsddmm(gidx, "mul", dZ, X, "e", lhs_target)
            dY = _reduce_grad(dY, Y_shape)
        else:
            dY = None
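One detail worth calling out: both branches end in `_reduce_grad(d·, ·_shape)`. The helper is assumed here to collapse dimensions that were broadcast in the forward pass so the gradient matches the saved operand shape again; a rough sketch of the idea (shapes are illustrative):

```python
import torch as th

dX_full = th.randn(5, 4, 8)              # gradient computed in the broadcast shape
X_shape = (5, 1, 8)                      # original operand was broadcast along dim 1
dX = dX_full.sum(dim=1, keepdim=True)    # what a _reduce_grad-style helper has to do
assert dX.shape == X_shape
```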
@@ -350,18 +481,38 @@ class GSDDMM(th.autograd.Function):
class GSDDMM_hetero(th.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
    def forward(
        ctx, gidx, op, X_len, lhs_target, rhs_target, *feats
    ):  # feats = X+Y
        out = _gsddmm_hetero(gidx, op, X_len, lhs_target, rhs_target, feats)
        X, Y = feats[:X_len], feats[X_len:]
        X_shape = tuple(
            [X[i].shape if X[i] is not None else None for i in range(len(X))]
        )
        Y_shape = tuple(
            [Y[i].shape if Y[i] is not None else None for i in range(len(Y))]
        )
        ctx.backward_cache = (
            gidx,
            op,
            lhs_target,
            rhs_target,
            X_shape,
            Y_shape,
            X_len,
        )
        req_grad_X = tuple(
            [
                X[i].requires_grad if X[i] is not None else False
                for i in range(len(X))
            ]
        )
        req_grad_Y = tuple(
            [
                Y[i].requires_grad if Y[i] is not None else False
                for i in range(len(Y))
            ]
        )
        ctx.save_for_backward(*feats)
        return out
@@ -369,58 +520,140 @@ class GSDDMM_hetero(th.autograd.Function):
@custom_bwd @custom_bwd
# TODO(Israt): Implement the complete backward operator # TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ): def backward(ctx, *dZ):
gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache (
gidx,
op,
lhs_target,
rhs_target,
X_shape,
Y_shape,
X_len,
) = ctx.backward_cache
feats = ctx.saved_tensors feats = ctx.saved_tensors
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]): if op != "copy_rhs" and any([x is not None for x in X]):
if lhs_target in ['u', 'v']: if lhs_target in ["u", "v"]:
_gidx = gidx if lhs_target == 'v' else gidx.reverse() _gidx = gidx if lhs_target == "v" else gidx.reverse()
tpl_of_None = tuple([None] * len(X)) tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_lhs']: if op in ["add", "copy_lhs"]:
dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) dX = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
else: # mul, dot else: # mul, dot
if rhs_target == lhs_target: if rhs_target == lhs_target:
dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * Y dX = (
elif rhs_target == 'e': gspmm_hetero(
dZ_mul_Y = tuple([dZ[i] * Y[i] if dZ[i] is not None else None _gidx,
for i in range(len(Y))]) "copy_rhs",
dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_Y))) "sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
* Y
)
elif rhs_target == "e":
dZ_mul_Y = tuple(
[
dZ[i] * Y[i] if dZ[i] is not None else None
for i in range(len(Y))
]
)
dX = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ_mul_Y))
)
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dX = gspmm_hetero(_gidx, 'mul', 'sum', len(X), *tuple(Y + dZ)) dX = gspmm_hetero(
_gidx, "mul", "sum", len(X), *tuple(Y + dZ)
)
else: # lhs_target == 'e' else: # lhs_target == 'e'
if op in ['add', 'copy_lhs']: if op in ["add", "copy_lhs"]:
dX = dZ dX = dZ
else: # mul, dot else: # mul, dot
num_etype = gidx.number_of_etypes() num_etype = gidx.number_of_etypes()
dX = gsddmm_hetero(gidx, 'mul', num_etype, 'e', rhs_target, *tuple(dZ + Y)) dX = gsddmm_hetero(
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None gidx, "mul", num_etype, "e", rhs_target, *tuple(dZ + Y)
for i in range(len(X))]) )
dX = tuple(
[
_reduce_grad(dX[i], X_shape[i])
if X[i] is not None
else None
for i in range(len(X))
]
)
else: else:
dX = tuple([None] * len(X)) dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]): if op != "copy_lhs" and any([y is not None for y in Y]):
if rhs_target in ['u', 'v']: if rhs_target in ["u", "v"]:
_gidx = gidx if rhs_target == 'v' else gidx.reverse() _gidx = gidx if rhs_target == "v" else gidx.reverse()
tpl_of_None = tuple([None] * len(X)) tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_rhs']: if op in ["add", "copy_rhs"]:
dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) dY = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
else: # mul, dot else: # mul, dot
if lhs_target == rhs_target: if lhs_target == rhs_target:
dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * X dY = (
elif lhs_target == 'e': gspmm_hetero(
dZ_mul_X = tuple([dZ[i] * X[i] if dZ[i] is not None else None _gidx,
for i in range(len(X))]) "copy_rhs",
dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_X))) "sum",
len(X),
*(tuple(tpl_of_None + dZ))
)
* X
)
elif lhs_target == "e":
dZ_mul_X = tuple(
[
dZ[i] * X[i] if dZ[i] is not None else None
for i in range(len(X))
]
)
dY = gspmm_hetero(
_gidx,
"copy_rhs",
"sum",
len(X),
*(tuple(tpl_of_None + dZ_mul_X))
)
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dY = gspmm_hetero(_gidx, 'mul', 'sum', len(X), *tuple(X + dZ)) dY = gspmm_hetero(
_gidx, "mul", "sum", len(X), *tuple(X + dZ)
)
else: else:
if op in ['add', 'copy_rhs']: if op in ["add", "copy_rhs"]:
dY = tuple([dZ[i] if dZ[i] is not None else None dY = tuple(
for i in range(len(dZ))]) [
dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))
]
)
else: # mul, dot else: # mul, dot
num_etype = gidx.number_of_etypes() num_etype = gidx.number_of_etypes()
dY = gsddmm_hetero(gidx, 'mul', num_etype, 'e', lhs_target, *tuple(dZ + X)) dY = gsddmm_hetero(
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None gidx, "mul", num_etype, "e", lhs_target, *tuple(dZ + X)
for i in range(len(Y))]) )
dY = tuple(
[
_reduce_grad(dY[i], Y_shape[i])
if Y[i] is not None
else None
for i in range(len(Y))
]
)
else: else:
dY = tuple([None] * len(Y)) dY = tuple([None] * len(Y))
return (None, None, None, None, None) + dX + dY return (None, None, None, None, None) + dX + dY
@@ -447,17 +680,17 @@ class EdgeSoftmax(th.autograd.Function):
        # a local variable
        if not is_all(eids):
            gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == "src":
            gidx = gidx.reverse()
        # Note: Now _edge_softmax_forward op only supports CPU
        # TODO(Zhejiang): We will support GPU in the future
        if score.is_cuda:
            score_max = _gspmm(gidx, "copy_rhs", "max", None, score)[0]
            score = th.exp(_gsddmm(gidx, "sub", score, score_max, "e", "v"))
            score_sum = _gspmm(gidx, "copy_rhs", "sum", None, score)[0]
            out = _gsddmm(gidx, "div", score, score_sum, "e", "v")
        else:
            out = _edge_softmax_forward(gidx, score, "copy_rhs")
        ctx.backward_cache = gidx
        ctx.save_for_backward(out)
        return out
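The CUDA branch above is the standard numerically stable softmax grouped by destination node: subtract the per-destination max, exponentiate, divide by the per-destination sum. A rough pure-PyTorch equivalent for a homogeneous graph (function and argument names are assumed for illustration; `scatter_reduce` needs PyTorch >= 1.12):

```python
import torch as th

def edge_softmax_reference(dst, scores, num_nodes):
    """Softmax of 1-D edge scores, normalized over edges sharing a destination."""
    smax = th.full((num_nodes,), float("-inf")).scatter_reduce(
        0, dst, scores, reduce="amax", include_self=True
    )
    escore = th.exp(scores - smax[dst])        # stable exponentials
    ssum = th.zeros(num_nodes).scatter_add(0, dst, escore)
    return escore / ssum[dst]                  # normalize per destination node
```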
@@ -480,14 +713,14 @@ class EdgeSoftmax(th.autograd.Function):
            return grad_score.data
        """
        gidx = ctx.backward_cache
        (out,) = ctx.saved_tensors
        sds = out * grad_out
        # Note: Now _edge_softmax_backward op only supports CPU
        # TODO(Zhejiang): We will support GPU in the future
        if out.is_cuda:
            accum = gspmm(gidx, "copy_rhs", "sum", None, sds)
            grad_score = sds - gsddmm(gidx, "mul", out, accum, "e", "v")
        else:
            grad_score = _edge_softmax_backward(gidx, out, sds)
        return None, grad_score, None, None
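The `sds - gsddmm(...)` line is the usual softmax Jacobian-vector product written per destination node: with outputs $y_e$ and incoming gradient $g_e$,

$$\frac{\partial L}{\partial z_e} = y_e\Big(g_e - \sum_{e' \in \mathcal{E}(v)} y_{e'}\,g_{e'}\Big),$$

where $\mathcal{E}(v)$ is the set of edges sharing destination $v$; `sds` holds $y_e g_e$ and `accum` the per-node sum, computed with a single SpMM.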
@@ -514,18 +747,28 @@ class EdgeSoftmax_hetero(th.autograd.Function):
        # a local variable
        if not is_all(eids):
            gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == "src":
            gidx = gidx.reverse()
        u_len = gidx.number_of_ntypes()
        e_len = gidx.number_of_etypes()
        lhs = [None] * u_len
        feats = tuple(lhs + list(score))
        score_max = _gspmm_hetero(gidx, "copy_rhs", "max", u_len, feats)[0]
        out_tmp = _gsddmm_hetero(
            gidx, "sub", e_len, "e", "v", tuple(list(score) + list(score_max))
        )
        score = tuple(
            [
                th.exp(out_tmp[i]) if out_tmp[i] is not None else None
                for i in range(len(out_tmp))
            ]
        )
        score_sum = _gspmm_hetero(
            gidx, "copy_rhs", "sum", u_len, tuple(lhs + list(score))
        )[0]
        out = _gsddmm_hetero(
            gidx, "div", e_len, "e", "v", tuple(list(score) + list(score_sum))
        )
        ctx.backward_cache = gidx
        ctx.save_for_backward(*out)
        return out
@@ -552,12 +795,14 @@ class EdgeSoftmax_hetero(th.autograd.Function):
        e_len = gidx.number_of_etypes()
        lhs = [None] * u_len
        out = ctx.saved_tensors
        sds = tuple([out[i] * grad_out[i] for i in range(len(out))])
        accum = _gspmm_hetero(
            gidx, "copy_rhs", "sum", u_len, tuple(lhs + list(sds))
        )[0]
        out_sddmm = _gsddmm_hetero(
            gidx, "mul", e_len, "e", "v", tuple(list(out) + list(accum))
        )
        grad_score = tuple([sds[i] - out_sddmm[i] for i in range(len(sds))])
        return (None, None, None) + grad_score
@@ -576,12 +821,13 @@ class SegmentReduce(th.autograd.Function):
        op = ctx.backward_cache
        arg, offsets = ctx.saved_tensors
        m = offsets[-1].item()
        if op == "sum":
            offsets = offsets[1:]
            # To address the issue of trailing zeros, related issue:
            # https://github.com/dmlc/dgl/pull/2610
            indices = th.zeros(
                (m + 1,), device=offsets.device, dtype=offsets.dtype
            )
            indices.scatter_add_(0, offsets, th.ones_like(offsets))
            indices = th.cumsum(indices, -1)[:-1]
            dx = dy[indices]
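The scatter/cumsum trick above turns segment offsets into a per-row segment id while staying correct when trailing segments are empty (the issue the linked PR addresses). A small self-contained illustration:

```python
import torch as th

offsets = th.tensor([0, 2, 2, 5])      # three segments; the middle one is empty
ends = offsets[1:]                      # segment end positions, as in the code above
m = offsets[-1].item()                  # total number of rows
indices = th.zeros((m + 1,), dtype=ends.dtype)
indices.scatter_add_(0, ends, th.ones_like(ends))
indices = th.cumsum(indices, -1)[:-1]   # tensor([0, 0, 2, 2, 2]): segment id per row
```

`dx = dy[indices]` then hands every row the gradient of its own segment, which is the correct backward rule for a `sum` segment reduce.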
@@ -608,23 +854,50 @@ class ScatterAdd(th.autograd.Function):
class CSRMM(th.autograd.Function):
    @staticmethod
    def forward(ctx, gidxA, A_weights, gidxB, B_weights, num_vtypes):
        gidxC, C_weights = _csrmm(
            gidxA, A_weights, gidxB, B_weights, num_vtypes
        )
        (
            nrows,
            ncols,
            C_indptr,
            C_indices,
            C_eids,
        ) = gidxC.adjacency_matrix_tensors(0, False, "csr")
        # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
        # as the underlying tensors of the created graph gidxC.
        ctx.backward_cache = gidxA, gidxB, gidxC
        ctx.save_for_backward(A_weights, B_weights)
        return (
            th.tensor(nrows),
            th.tensor(ncols),
            C_indptr,
            C_indices,
            C_eids,
            C_weights,
        )

    @staticmethod
    def backward(
        ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
    ):
        # Only the last argument is meaningful.
        gidxA, gidxB, gidxC = ctx.backward_cache
        A_weights, B_weights = ctx.saved_tensors
        dgidxA, dA_weights = csrmm(
            gidxC,
            dC_weights,
            gidxB.reverse(),
            B_weights,
            gidxA.number_of_ntypes(),
        )
        dgidxB, dB_weights = csrmm(
            gidxA.reverse(),
            A_weights,
            gidxC,
            dC_weights,
            gidxB.number_of_ntypes(),
        )
        dA_weights = csrmask(dgidxA, dA_weights, gidxA)
        dB_weights = csrmask(dgidxB, dB_weights, gidxB)
        return None, dA_weights, None, dB_weights, None
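For reference, the backward above follows the standard rule for a sparse-sparse product restricted to fixed sparsity patterns: with $C = AB$,

$$\partial A = \operatorname{mask}_A\!\big(\partial C\, B^{\top}\big), \qquad \partial B = \operatorname{mask}_B\!\big(A^{\top}\,\partial C\big),$$

where each mask keeps only the entries present in the corresponding operand's pattern; the `gidx.reverse()` calls play the role of the transposes and `csrmask` applies the masks.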
@@ -635,18 +908,34 @@ class CSRSum(th.autograd.Function):
    def forward(ctx, gidxs, *weights):
        # PyTorch tensors must be explicit arguments of the forward function
        gidxC, C_weights = _csrsum(gidxs, weights)
        (
            nrows,
            ncols,
            C_indptr,
            C_indices,
            C_eids,
        ) = gidxC.adjacency_matrix_tensors(0, False, "csr")
        # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
        # as the underlying tensors of the created graph gidxC.
        ctx.backward_cache = gidxs, gidxC
        return (
            th.tensor(nrows),
            th.tensor(ncols),
            C_indptr,
            C_indices,
            C_eids,
            C_weights,
        )

    @staticmethod
    def backward(
        ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
    ):
        # Only the last argument is meaningful.
        gidxs, gidxC = ctx.backward_cache
        return (None,) + tuple(
            csrmask(gidxC, dC_weights, gidx) for gidx in gidxs
        )


class CSRMask(th.autograd.Function):
@@ -692,7 +981,9 @@ class GATHERMM(th.autograd.Function):
    @custom_fwd(cast_inputs=th.float16)
    def forward(ctx, A, B, idx_a, idx_b):
        if B.dim() != 3:
            raise ValueError(
                "Expected dimension of B is 3. Got " + str(B.dim())
            )
        N = len(idx_b) if idx_a is None else len(idx_a)
        C = th.zeros((N, B.shape[2]), device=A.device, dtype=A.dtype)
        C = _gather_mm(A, B, C, idx_a, idx_b)
@@ -706,103 +997,158 @@ class GATHERMM(th.autograd.Function):
        if ctx.needs_input_grad[0]:
            # Compute A_grad = Out_grad * B^T
            A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
            A_grad = _gather_mm_scatter(
                dZ, B.transpose(1, 2), A_grad, idx_b=idx_b, idx_c=idx_a
            )
        if ctx.needs_input_grad[1]:
            # Compute B_grad = A^T * Out_grad
            B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
            B_grad = _gather_mm_scatter(A, dZ, B_grad, idx_a=idx_a, idx_c=idx_b)
        return A_grad, B_grad, None, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
    if op == "sub":
        op = "add"
        rhs_data = -rhs_data
    if op == "div":
        op = "mul"
        rhs_data = 1.0 / rhs_data
    return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)


def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
    if op == "sub":
        op = "add"
        rhs_data = -rhs_data
    if op == "div":
        op = "mul"
        rhs_data = 1.0 / rhs_data
    return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
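Note that `sub` and `div` are folded into `add`/`mul` on a negated or reciprocal right-hand side, so only two binary kernels need autograd rules. A hedged usage sketch of how these wrappers are normally reached from user code (the graph, the shapes, and the `dgl.ops` entry point are illustrative, not taken from this diff):

```python
import dgl
import torch as th

g = dgl.graph(([0, 1, 2], [1, 2, 3]))   # 4 nodes, 3 edges
x = th.randn(4, 8, requires_grad=True)  # node features
e = th.randn(3, 8, requires_grad=True)  # edge features
h = dgl.ops.u_mul_e_sum(g, x, e)        # SpMM: message x_u * w_e, reduced by sum
h.sum().backward()                      # backward flows through the GSpMM autograd Function above
```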
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
    lhs_tuple, rhs_tuple = (
        lhs_and_rhs_tuple[:lhs_len],
        lhs_and_rhs_tuple[lhs_len:],
    )
    if op == "sub":
        op = "add"
        rhs_tuple = tuple(
            [
                -rhs_tuple[i] if rhs_tuple[i] is not None else None
                for i in range(len(rhs_tuple))
            ]
        )
    if op == "div":
        op = "mul"
        rhs_tuple = tuple(
            [
                (1.0 / rhs_tuple[i]) if rhs_tuple[i] is not None else None
                for i in range(len(rhs_tuple))
            ]
        )
    if op in ["add", "mul"]:
        lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
    return GSpMM_hetero.apply(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
def gsddmm_hetero(
    g, op, lhs_len, lhs_target="u", rhs_target="v", *lhs_and_rhs_tuple
):
    lhs_tuple, rhs_tuple = (
        lhs_and_rhs_tuple[:lhs_len],
        lhs_and_rhs_tuple[lhs_len:],
    )
    if op == "sub":
        op = "add"
        rhs_tuple = tuple(
            [
                -rhs_tuple[i] if rhs_tuple[i] is not None else None
                for i in range(len(rhs_tuple))
            ]
        )
    if op == "div":
        op = "mul"
        rhs_tuple = tuple(
            [
                (1.0 / rhs_tuple[i]) if rhs_tuple[i] is not None else None
                for i in range(len(rhs_tuple))
            ]
        )
    if op in ["add", "mul"]:
        lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
    return GSDDMM_hetero.apply(
        g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple
    )
def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
    return EdgeSoftmax.apply(gidx, logits, eids, norm_by)


def edge_softmax_hetero(gidx, eids=ALL, norm_by="dst", *logits):
    return EdgeSoftmax_hetero.apply(gidx, eids, norm_by, *logits)


def segment_reduce(op, x, offsets):
    return SegmentReduce.apply(op, x, offsets)


def scatter_add(x, idx, m):
    return ScatterAdd.apply(x, idx, m)
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRMM.apply(
        gidxA, A_weights, gidxB, B_weights, num_vtypes
    )
    gidxC = create_unitgraph_from_csr(
        num_vtypes,
        nrows.item(),
        ncols.item(),
        C_indptr,
        C_indices,
        C_eids,
        ["coo", "csr", "csc"],
    )
    return gidxC, C_weights


def csrsum(gidxs, weights):
    nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRSum.apply(
        gidxs, *weights
    )
    gidxC = create_unitgraph_from_csr(
        gidxs[0].number_of_ntypes(),
        nrows.item(),
        ncols.item(),
        C_indptr,
        C_indices,
        C_eids,
        ["coo", "csr", "csc"],
    )
    return gidxC, C_weights


def csrmask(gidxA, A_weights, gidxB):
    return CSRMask.apply(gidxA, A_weights, gidxB)
def segment_mm(A, B, seglen_A):
    if A.device.type == "cpu":
        C = []
        off = 0
        for i in range(B.shape[0]):
            C.append(A[off : off + seglen_A[i]] @ B[i])
            off += seglen_A[i]
        return th.cat(C)
    else:
        return SEGMENTMM.apply(A, B, seglen_A)


def gather_mm(A, B, idx_A=None, idx_B=None):
    if A.device.type == "cpu":
        A = A[idx_A] if idx_A is not None else A
        B = B[idx_B] if idx_B is not None else B
        return th.bmm(A.unsqueeze(1), B).squeeze(1)
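A small shape-level illustration of the CPU fallbacks above (sizes are arbitrary): `segment_mm` multiplies consecutive row blocks of `A` by successive matrices in `B`, and `gather_mm` expresses the same computation with an explicit per-row index instead of contiguous segments.

```python
import torch as th

A = th.randn(5, 4)                 # 5 rows of input features
B = th.randn(2, 4, 3)              # one 4x3 weight matrix per segment/relation
seglen = th.tensor([2, 3])         # rows 0-1 use B[0], rows 2-4 use B[1]
C = th.cat([A[0:2] @ B[0], A[2:5] @ B[1]])

idx_B = th.tensor([0, 0, 1, 1, 1])  # the same assignment, expressed per row
D = th.bmm(A.unsqueeze(1), B[idx_B]).squeeze(1)
assert th.allclose(C, D)            # gather_mm and segment_mm agree here
```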