Unverified Commit a208e886 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
...@@ -5,33 +5,41 @@ from __future__ import absolute_import ...@@ -5,33 +5,41 @@ from __future__ import absolute_import
import ctypes import ctypes
import traceback import traceback
from numbers import Number, Integral from numbers import Integral, Number
from ..base import _LIB, check_call from ..base import _LIB, c_str, check_call, string_types
from ..base import c_str, string_types from ..object_generic import ObjectGeneric, convert_to_object
from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType
from ..runtime_ctypes import DGLDataType, DGLByteArray, DGLContext
from . import ndarray as _nd from . import ndarray as _nd
from . import object as _object
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import DGLValue, TypeCode
from .types import DGLPackedCFunc, DGLCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .object import ObjectBase from .object import ObjectBase
from . import object as _object from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLCFuncFinalizer,
DGLPackedCFunc,
DGLValue,
TypeCode,
_wrap_arg_func,
)
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p
DGLRetValueHandle = ctypes.c_void_p DGLRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle): def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed.""" """callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object) pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj) ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive # Global callback that is always alive
DGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource) DGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ)) ctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ))
def convert_to_dgl_func(pyfunc): def convert_to_dgl_func(pyfunc):
"""Convert a python function to DGL function """Convert a python function to DGL function
...@@ -46,10 +54,15 @@ def convert_to_dgl_func(pyfunc): ...@@ -46,10 +54,15 @@ def convert_to_dgl_func(pyfunc):
The converted dgl function. The converted dgl function.
""" """
local_pyfunc = pyfunc local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _): def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """ """ctypes function"""
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args num_args = (
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)) num_args.value if isinstance(num_args, ctypes.c_int) else num_args
)
pyargs = (
C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)
)
# pylint: disable=broad-except # pylint: disable=broad-except
try: try:
rv = local_pyfunc(*pyargs) rv = local_pyfunc(*pyargs)
...@@ -60,12 +73,16 @@ def convert_to_dgl_func(pyfunc): ...@@ -60,12 +73,16 @@ def convert_to_dgl_func(pyfunc):
if rv is not None: if rv is not None:
if isinstance(rv, tuple): if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value") raise ValueError(
"PackedFunction can only support one return value"
)
temp_args = [] temp_args = []
values, tcodes, _ = _make_dgl_args((rv,), temp_args) values, tcodes, _ = _make_dgl_args((rv,), temp_args)
if not isinstance(ret, DGLRetValueHandle): if not isinstance(ret, DGLRetValueHandle):
ret = DGLRetValueHandle(ret) ret = DGLRetValueHandle(ret)
check_call(_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))) check_call(
_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))
)
_ = temp_args _ = temp_args
_ = rv _ = rv
return 0 return 0
...@@ -76,8 +93,11 @@ def convert_to_dgl_func(pyfunc): ...@@ -76,8 +93,11 @@ def convert_to_dgl_func(pyfunc):
# DGL_FREE_PYOBJ will be called after it is no longer needed. # DGL_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f) pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj) ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.DGLFuncCreateFromCFunc( check_call(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle))) _LIB.DGLFuncCreateFromCFunc(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)
)
)
return _CLASS_FUNCTION(handle, False) return _CLASS_FUNCTION(handle, False)
...@@ -104,8 +124,11 @@ def _make_dgl_args(args, temp_args): ...@@ -104,8 +124,11 @@ def _make_dgl_args(args, temp_args):
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_CONTAINER type_codes[i] = (
if not arg.is_view else TypeCode.ARRAY_HANDLE) TypeCode.NDARRAY_CONTAINER
if not arg.is_view
else TypeCode.ARRAY_HANDLE
)
elif isinstance(arg, _nd._DGL_COMPATS): elif isinstance(arg, _nd._DGL_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._dgl_handle) values[i].v_handle = ctypes.c_void_p(arg._dgl_handle)
type_codes[i] = arg.__class__._dgl_tcode type_codes[i] = arg.__class__._dgl_tcode
...@@ -125,7 +148,8 @@ def _make_dgl_args(args, temp_args): ...@@ -125,7 +148,8 @@ def _make_dgl_args(args, temp_args):
arr = DGLByteArray() arr = DGLByteArray()
arr.data = ctypes.cast( arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg), (ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte)) ctypes.POINTER(ctypes.c_byte),
)
arr.size = len(arg) arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr) temp_args.append(arr)
...@@ -134,7 +158,7 @@ def _make_dgl_args(args, temp_args): ...@@ -134,7 +158,7 @@ def _make_dgl_args(args, temp_args):
values[i].v_str = c_str(arg) values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
# NOTE(minjie): module is not used in DGL # NOTE(minjie): module is not used in DGL
#elif isinstance(arg, _CLASS_MODULE): # elif isinstance(arg, _CLASS_MODULE):
# values[i].v_handle = arg.handle # values[i].v_handle = arg.handle
# type_codes[i] = TypeCode.MODULE_HANDLE # type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase): elif isinstance(arg, FunctionBase):
...@@ -155,6 +179,7 @@ def _make_dgl_args(args, temp_args): ...@@ -155,6 +179,7 @@ def _make_dgl_args(args, temp_args):
class FunctionBase(object): class FunctionBase(object):
"""Function base.""" """Function base."""
__slots__ = ["handle", "is_global"] __slots__ = ["handle", "is_global"]
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle, is_global): def __init__(self, handle, is_global):
...@@ -185,9 +210,16 @@ class FunctionBase(object): ...@@ -185,9 +210,16 @@ class FunctionBase(object):
values, tcodes, num_args = _make_dgl_args(args, temp_args) values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue() ret_val = DGLValue()
ret_tcode = ctypes.c_int() ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall( check_call(
self.handle, values, tcodes, ctypes.c_int(num_args), _LIB.DGLFuncCall(
ctypes.byref(ret_val), ctypes.byref(ret_tcode))) self.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args _ = temp_args
_ = args _ = args
return RETURN_SWITCH[ret_tcode.value](ret_val) return RETURN_SWITCH[ret_tcode.value](ret_val)
...@@ -199,9 +231,16 @@ def __init_handle_by_constructor__(fconstructor, args): ...@@ -199,9 +231,16 @@ def __init_handle_by_constructor__(fconstructor, args):
values, tcodes, num_args = _make_dgl_args(args, temp_args) values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue() ret_val = DGLValue()
ret_tcode = ctypes.c_int() ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall( check_call(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args), _LIB.DGLFuncCall(
ctypes.byref(ret_val), ctypes.byref(ret_tcode))) fconstructor.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args _ = temp_args
_ = args _ = args
assert ret_tcode.value == TypeCode.OBJECT_HANDLE assert ret_tcode.value == TypeCode.OBJECT_HANDLE
...@@ -216,6 +255,7 @@ def _return_module(x): ...@@ -216,6 +255,7 @@ def _return_module(x):
handle = ModuleHandle(handle) handle = ModuleHandle(handle)
return _CLASS_MODULE(handle) return _CLASS_MODULE(handle)
def _handle_return_func(x): def _handle_return_func(x):
"""Return function""" """Return function"""
handle = x.v_handle handle = x.v_handle
...@@ -228,22 +268,32 @@ def _handle_return_func(x): ...@@ -228,22 +268,32 @@ def _handle_return_func(x):
_object.__init_by_constructor__ = __init_handle_by_constructor__ _object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE) _handle_return_func, TypeCode.FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE) _return_module, TypeCode.MODULE_HANDLE
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True) )
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(
x.v_handle, True
)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
global _CLASS_MODULE global _CLASS_MODULE
_CLASS_MODULE = module_class _CLASS_MODULE = module_class
def _set_class_function(func_class): def _set_class_function(func_class):
global _CLASS_FUNCTION global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class _CLASS_FUNCTION = func_class
...@@ -3,18 +3,23 @@ ...@@ -3,18 +3,23 @@
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import DGLArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
from ..base import _LIB, c_str, check_call
from ..runtime_ctypes import DGLArrayHandle
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
_return_handle,
_wrap_arg_func,
)
DGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) DGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor') _c_str_dltensor = c_str("dltensor")
_c_str_used_dltensor = c_str('used_dltensor') _c_str_used_dltensor = c_str("used_dltensor")
# used for PyCapsule manipulation # used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'): if hasattr(ctypes, "pythonapi"):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
...@@ -31,9 +36,13 @@ def _from_dlpack(dltensor): ...@@ -31,9 +36,13 @@ def _from_dlpack(dltensor):
handle = DGLArrayHandle() handle = DGLArrayHandle()
check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle))) check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, DGLPyCapsuleDestructor(0)) ctypes.pythonapi.PyCapsule_SetDestructor(
dltensor, DGLPyCapsuleDestructor(0)
)
return _make_array(handle, False) return _make_array(handle, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once") raise ValueError(
"Expect a dltensor field, PyCapsule can only be consumed once"
)
def _dlpack_deleter(pycapsule): def _dlpack_deleter(pycapsule):
...@@ -45,13 +54,17 @@ def _dlpack_deleter(pycapsule): ...@@ -45,13 +54,17 @@ def _dlpack_deleter(pycapsule):
# work out always. # work out always.
ptr = ctypes.cast(ptr, ctypes.c_void_p) ptr = ctypes.cast(ptr, ctypes.c_void_p)
_LIB.DGLDLManagedTensorCallDeleter(ptr) _LIB.DGLDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, DGLPyCapsuleDestructor(0)) ctypes.pythonapi.PyCapsule_SetDestructor(
pycapsule, DGLPyCapsuleDestructor(0)
)
_c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter) _c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object): class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"] __slots__ = ["handle", "is_view"]
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle, is_view=False): def __init__(self, handle, is_view=False):
...@@ -89,27 +102,36 @@ class NDArrayBase(object): ...@@ -89,27 +102,36 @@ class NDArrayBase(object):
dlpack : DLPack tensor view of the array data dlpack : DLPack tensor view of the array data
""" """
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_call(_LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment)) check_call(
return ctypes.pythonapi.PyCapsule_New(ptr, _c_str_dltensor, _c_dlpack_deleter) _LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment)
)
return ctypes.pythonapi.PyCapsule_New(
ptr, _c_str_dltensor, _c_dlpack_deleter
)
def _make_array(handle, is_view): def _make_array(handle, is_view):
handle = ctypes.cast(handle, DGLArrayHandle) handle = ctypes.cast(handle, DGLArrayHandle)
return _CLASS_NDARRAY(handle, is_view) return _CLASS_NDARRAY(handle, is_view)
_DGL_COMPATS = () _DGL_COMPATS = ()
def _reg_extension(cls, fcreate): def _reg_extension(cls, fcreate):
global _DGL_COMPATS global _DGL_COMPATS
_DGL_COMPATS += (cls,) _DGL_COMPATS += (cls,)
if fcreate: if fcreate:
fret = lambda x: fcreate(_return_handle(x)) fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._dgl_tcode] = fret RETURN_SWITCH[cls._dgl_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(fret, cls._dgl_tcode) C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(
fret, cls._dgl_tcode
)
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
def _set_class_ndarray(cls): def _set_class_ndarray(cls):
global _CLASS_NDARRAY global _CLASS_NDARRAY
_CLASS_NDARRAY = cls _CLASS_NDARRAY = cls
...@@ -2,10 +2,16 @@ ...@@ -2,10 +2,16 @@
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str
from ..base import _LIB, c_str, check_call
from ..object_generic import _set_class_object_base from ..object_generic import _set_class_object_base
from .types import DGLValue, TypeCode from .types import (
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLValue,
TypeCode,
_wrap_arg_func,
)
ObjectHandle = ctypes.c_void_p ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None __init_by_constructor__ = None
...@@ -13,10 +19,12 @@ __init_by_constructor__ = None ...@@ -13,10 +19,12 @@ __init_by_constructor__ = None
"""Maps object type to its constructor""" """Maps object type to its constructor"""
OBJECT_TYPE = {} OBJECT_TYPE = {}
def _register_object(index, cls): def _register_object(index, cls):
"""register object class in python""" """register object class in python"""
OBJECT_TYPE[index] = cls OBJECT_TYPE[index] = cls
def _return_object(x): def _return_object(x):
"""Construct a object object from the given DGLValue object""" """Construct a object object from the given DGLValue object"""
handle = x.v_handle handle = x.v_handle
...@@ -34,32 +42,41 @@ def _return_object(x): ...@@ -34,32 +42,41 @@ def _return_object(x):
RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE) _return_object, TypeCode.OBJECT_HANDLE
)
class ObjectBase(object): class ObjectBase(object):
"""Object base class""" """Object base class"""
__slots__ = ["handle"] __slots__ = ["handle"]
# pylint: disable=no-member # pylint: disable=no-member
def __del__(self): def __del__(self):
if _LIB is not None and hasattr(self, 'handle'): if _LIB is not None and hasattr(self, "handle"):
check_call(_LIB.DGLObjectFree(self.handle)) check_call(_LIB.DGLObjectFree(self.handle))
def __getattr__(self, name): def __getattr__(self, name):
if name == 'handle': if name == "handle":
raise AttributeError("'handle' is a reserved attribute name that should not be used") raise AttributeError(
"'handle' is a reserved attribute name that should not be used"
)
ret_val = DGLValue() ret_val = DGLValue()
ret_type_code = ctypes.c_int() ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int() ret_success = ctypes.c_int()
check_call(_LIB.DGLObjectGetAttr( check_call(
self.handle, c_str(name), _LIB.DGLObjectGetAttr(
ctypes.byref(ret_val), self.handle,
ctypes.byref(ret_type_code), c_str(name),
ctypes.byref(ret_success))) ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success),
)
)
if not ret_success.value: if not ret_success.value:
raise AttributeError( raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name)) "'%s' object has no attribute '%s'" % (str(type(self)), name)
)
return RETURN_SWITCH[ret_type_code.value](ret_val) return RETURN_SWITCH[ret_type_code.value](ret_val)
def __init_handle_by_constructor__(self, fconstructor, *args): def __init_handle_by_constructor__(self, fconstructor, *args):
...@@ -81,9 +98,12 @@ class ObjectBase(object): ...@@ -81,9 +98,12 @@ class ObjectBase(object):
""" """
# assign handle first to avoid error raising # assign handle first to avoid error raising
self.handle = None self.handle = None
handle = __init_by_constructor__(fconstructor, args) # pylint: disable=not-callable handle = __init_by_constructor__(
fconstructor, args
) # pylint: disable=not-callable
if not isinstance(handle, ObjectHandle): if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle) handle = ObjectHandle(handle)
self.handle = handle self.handle = handle
_set_class_object_base(ObjectBase) _set_class_object_base(ObjectBase)
...@@ -3,17 +3,22 @@ ...@@ -3,17 +3,22 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import ctypes import ctypes
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLDataType, DGLContext from ..base import _LIB, check_call, py_str
from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType, TypeCode
class DGLValue(ctypes.Union): class DGLValue(ctypes.Union):
"""DGLValue in C API""" """DGLValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double), _fields_ = [
("v_handle", ctypes.c_void_p), ("v_int64", ctypes.c_int64),
("v_str", ctypes.c_char_p), ("v_float64", ctypes.c_double),
("v_type", DGLDataType), ("v_handle", ctypes.c_void_p),
("v_ctx", DGLContext)] ("v_str", ctypes.c_char_p),
("v_type", DGLDataType),
("v_ctx", DGLContext),
]
DGLPackedCFunc = ctypes.CFUNCTYPE( DGLPackedCFunc = ctypes.CFUNCTYPE(
...@@ -22,12 +27,11 @@ DGLPackedCFunc = ctypes.CFUNCTYPE( ...@@ -22,12 +27,11 @@ DGLPackedCFunc = ctypes.CFUNCTYPE(
ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int),
ctypes.c_int, ctypes.c_int,
ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_void_p) ctypes.c_void_p,
)
DGLCFuncFinalizer = ctypes.CFUNCTYPE( DGLCFuncFinalizer = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
None,
ctypes.c_void_p)
def _return_handle(x): def _return_handle(x):
...@@ -37,6 +41,7 @@ def _return_handle(x): ...@@ -37,6 +41,7 @@ def _return_handle(x):
handle = ctypes.c_void_p(handle) handle = ctypes.c_void_p(handle)
return handle return handle
def _return_bytes(x): def _return_bytes(x):
"""return handle""" """return handle"""
handle = x.v_handle handle = x.v_handle
...@@ -47,16 +52,20 @@ def _return_bytes(x): ...@@ -47,16 +52,20 @@ def _return_bytes(x):
res = bytearray(size) res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res) rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size): if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed') raise RuntimeError("memmove failed")
return res return res
def _wrap_arg_func(return_f, type_code): def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code) tcode = ctypes.c_int(type_code)
def _wrap_func(x): def _wrap_func(x):
check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode)) check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x) return return_f(x)
return _wrap_func return _wrap_func
RETURN_SWITCH = { RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64, TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64, TypeCode.FLOAT: lambda x: x.v_float64,
...@@ -64,7 +73,9 @@ RETURN_SWITCH = { ...@@ -64,7 +73,9 @@ RETURN_SWITCH = {
TypeCode.NULL: lambda x: None, TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str), TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes, TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id), TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
} }
C_TO_PY_ARG_SWITCH = { C_TO_PY_ARG_SWITCH = {
...@@ -74,5 +85,7 @@ C_TO_PY_ARG_SWITCH = { ...@@ -74,5 +85,7 @@ C_TO_PY_ARG_SWITCH = {
TypeCode.NULL: lambda x: None, TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str), TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes, TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id), TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
} }
...@@ -3,22 +3,24 @@ ...@@ -3,22 +3,24 @@
"""ctypes library and helper functions """ """ctypes library and helper functions """
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os
import ctypes import ctypes
import logging import logging
import os
import sys
import numpy as np import numpy as np
from . import libinfo from . import libinfo
#---------------------------- # ----------------------------
# library loading # library loading
#---------------------------- # ----------------------------
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
string_types = (str,) string_types = (str,)
numeric_types = (float, int, np.float32, np.int32) numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3 # this function is needed for python3
# to convert ctypes.char_p .value back to python str # to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8') py_str = lambda x: x.decode("utf-8")
else: else:
string_types = (basestring,) string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32) numeric_types = (float, int, long, np.float32, np.int32)
...@@ -27,8 +29,10 @@ else: ...@@ -27,8 +29,10 @@ else:
class DGLError(Exception): class DGLError(Exception):
"""Error thrown by DGL function""" """Error thrown by DGL function"""
pass # pylint: disable=unnecessary-pass pass # pylint: disable=unnecessary-pass
def _load_lib(): def _load_lib():
"""Load libary by searching possible path.""" """Load libary by searching possible path."""
lib_path = libinfo.find_lib_path() lib_path = libinfo.find_lib_path()
...@@ -39,6 +43,7 @@ def _load_lib(): ...@@ -39,6 +43,7 @@ def _load_lib():
lib.DGLGetLastError.restype = ctypes.c_char_p lib.DGLGetLastError.restype = ctypes.c_char_p
return lib, basename, dirname return lib, basename, dirname
# version number # version number
__version__ = libinfo.__version__ __version__ = libinfo.__version__
# library instance of nnvm # library instance of nnvm
...@@ -47,9 +52,9 @@ _LIB, _LIB_NAME, _DIR_NAME = _load_lib() ...@@ -47,9 +52,9 @@ _LIB, _LIB_NAME, _DIR_NAME = _load_lib()
# The FFI mode of DGL # The FFI mode of DGL
_FFI_MODE = os.environ.get("DGL_FFI", "auto") _FFI_MODE = os.environ.get("DGL_FFI", "auto")
#---------------------------- # ----------------------------
# helper function in ctypes. # helper function in ctypes.
#---------------------------- # ----------------------------
def check_call(ret): def check_call(ret):
"""Check the return value of C API call """Check the return value of C API call
...@@ -77,7 +82,7 @@ def c_str(string): ...@@ -77,7 +82,7 @@ def c_str(string):
str : c_char_p str : c_char_p
A char pointer that can be passed to C API A char pointer that can be passed to C API
""" """
return ctypes.c_char_p(string.encode('utf-8')) return ctypes.c_char_p(string.encode("utf-8"))
def c_array(ctype, values): def c_array(ctype, values):
...@@ -111,10 +116,13 @@ def decorate(func, fwrapped): ...@@ -111,10 +116,13 @@ def decorate(func, fwrapped):
The wrapped function The wrapped function
""" """
import decorator import decorator
return decorator.decorate(func, fwrapped) return decorator.decorate(func, fwrapped)
tensor_adapter_loaded = False tensor_adapter_loaded = False
def load_tensor_adapter(backend, version): def load_tensor_adapter(backend, version):
"""Tell DGL to load a tensoradapter library for given backend and version. """Tell DGL to load a tensoradapter library for given backend and version.
...@@ -126,17 +134,17 @@ def load_tensor_adapter(backend, version): ...@@ -126,17 +134,17 @@ def load_tensor_adapter(backend, version):
The version number of the backend. The version number of the backend.
""" """
global tensor_adapter_loaded global tensor_adapter_loaded
version = version.split('+')[0] version = version.split("+")[0]
if sys.platform.startswith('linux'): if sys.platform.startswith("linux"):
basename = 'libtensoradapter_%s_%s.so' % (backend, version) basename = "libtensoradapter_%s_%s.so" % (backend, version)
elif sys.platform.startswith('darwin'): elif sys.platform.startswith("darwin"):
basename = 'libtensoradapter_%s_%s.dylib' % (backend, version) basename = "libtensoradapter_%s_%s.dylib" % (backend, version)
elif sys.platform.startswith('win'): elif sys.platform.startswith("win"):
basename = 'tensoradapter_%s_%s.dll' % (backend, version) basename = "tensoradapter_%s_%s.dll" % (backend, version)
else: else:
raise NotImplementedError('Unsupported system: %s' % sys.platform) raise NotImplementedError("Unsupported system: %s" % sys.platform)
path = os.path.join(_DIR_NAME, 'tensoradapter', backend, basename) path = os.path.join(_DIR_NAME, "tensoradapter", backend, basename)
tensor_adapter_loaded = (_LIB.DGLLoadTensorAdapter(path.encode('utf-8')) == 0) tensor_adapter_loaded = _LIB.DGLLoadTensorAdapter(path.encode("utf-8")) == 0
if not tensor_adapter_loaded: if not tensor_adapter_loaded:
logger = logging.getLogger("dgl-core") logger = logging.getLogger("dgl-core")
logger.debug("Memory optimization with PyTorch is not enabled.") logger.debug("Memory optimization with PyTorch is not enabled.")
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
"""Function namespace.""" """Function namespace."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import ctypes import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE import sys
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str, string_types
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -13,21 +14,31 @@ try: ...@@ -13,21 +14,31 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_dgl_func from ._cy3.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
else: else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_dgl_func from ._cy2.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_dgl_func from ._ctypes.function import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase): class Function(_FunctionBase):
"""The PackedFunc object. """The PackedFunc object.
...@@ -51,11 +62,13 @@ class Function(_FunctionBase): ...@@ -51,11 +62,13 @@ class Function(_FunctionBase):
dgl.register_func: How to register global function. dgl.register_func: How to register global function.
dgl.get_global_func: How to get global function. dgl.get_global_func: How to get global function.
""" """
pass # pylint: disable=unnecessary-pass pass # pylint: disable=unnecessary-pass
class ModuleBase(object): class ModuleBase(object):
"""Base class for module""" """Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"] __slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle): def __init__(self, handle):
...@@ -97,13 +110,16 @@ class ModuleBase(object): ...@@ -97,13 +110,16 @@ class ModuleBase(object):
The result function. The result function.
""" """
ret_handle = FunctionHandle() ret_handle = FunctionHandle()
check_call(_LIB.DGLModGetFunction( check_call(
self.handle, c_str(name), _LIB.DGLModGetFunction(
ctypes.c_int(query_imports), self.handle,
ctypes.byref(ret_handle))) c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle),
)
)
if not ret_handle.value: if not ret_handle.value:
raise AttributeError( raise AttributeError("Module has no function '%s'" % name)
"Module has no function '%s'" % name)
return Function(ret_handle, False) return Function(ret_handle, False)
def import_module(self, module): def import_module(self, module):
...@@ -175,13 +191,16 @@ def register_func(func_name, f=None, override=False): ...@@ -175,13 +191,16 @@ def register_func(func_name, f=None, override=False):
raise ValueError("expect string function name") raise ValueError("expect string function name")
ioverride = ctypes.c_int(override) ioverride = ctypes.c_int(override)
def register(myf): def register(myf):
"""internal register function""" """internal register function"""
if not isinstance(myf, Function): if not isinstance(myf, Function):
myf = convert_to_dgl_func(myf) myf = convert_to_dgl_func(myf)
check_call(_LIB.DGLFuncRegisterGlobal( check_call(
c_str(func_name), myf.handle, ioverride)) _LIB.DGLFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)
)
return myf return myf
if f: if f:
return register(f) return register(f)
return register return register
...@@ -214,7 +233,6 @@ def get_global_func(name, allow_missing=False): ...@@ -214,7 +233,6 @@ def get_global_func(name, allow_missing=False):
raise ValueError("Cannot find global function %s" % name) raise ValueError("Cannot find global function %s" % name)
def list_global_func_names(): def list_global_func_names():
"""Get list of global functions registered. """Get list of global functions registered.
...@@ -226,8 +244,9 @@ def list_global_func_names(): ...@@ -226,8 +244,9 @@ def list_global_func_names():
plist = ctypes.POINTER(ctypes.c_char_p)() plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
check_call(_LIB.DGLFuncListGlobalNames(ctypes.byref(size), check_call(
ctypes.byref(plist))) _LIB.DGLFuncListGlobalNames(ctypes.byref(size), ctypes.byref(plist))
)
fnames = [] fnames = []
for i in range(size.value): for i in range(size.value):
fnames.append(py_str(plist[i])) fnames.append(py_str(plist[i]))
...@@ -249,8 +268,10 @@ def extract_ext_funcs(finit): ...@@ -249,8 +268,10 @@ def extract_ext_funcs(finit):
The extracted functions The extracted functions
""" """
fdict = {} fdict = {}
def _list(name, func): def _list(name, func):
fdict[name] = func fdict[name] = func
myf = convert_to_dgl_func(_list) myf = convert_to_dgl_func(_list)
ret = finit(myf.handle) ret = finit(myf.handle)
_ = myf _ = myf
...@@ -258,11 +279,13 @@ def extract_ext_funcs(finit): ...@@ -258,11 +279,13 @@ def extract_ext_funcs(finit):
raise RuntimeError("cannot initialize with %s" % finit) raise RuntimeError("cannot initialize with %s" % finit)
return fdict return fdict
def _get_api(f): def _get_api(f):
flocal = f flocal = f
flocal.is_global = True flocal.is_global = True
return flocal return flocal
def _init_api(namespace, target_module_name=None): def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name """Initialize api for a given module name
...@@ -272,8 +295,7 @@ def _init_api(namespace, target_module_name=None): ...@@ -272,8 +295,7 @@ def _init_api(namespace, target_module_name=None):
target_module_name : str target_module_name : str
The target module name if different from namespace The target module name if different from namespace
""" """
target_module_name = ( target_module_name = target_module_name if target_module_name else namespace
target_module_name if target_module_name else namespace)
if namespace.startswith("dgl."): if namespace.startswith("dgl."):
_init_api_prefix(target_module_name, namespace[4:]) _init_api_prefix(target_module_name, namespace[4:])
else: else:
...@@ -284,10 +306,10 @@ def _init_api_prefix(module_name, prefix): ...@@ -284,10 +306,10 @@ def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name] module = sys.modules[module_name]
for name in list_global_func_names(): for name in list_global_func_names():
if name.startswith("_") and not name.startswith('_deprecate'): if name.startswith("_") and not name.startswith("_deprecate"):
# internal APIs are ignored # internal APIs are ignored
continue continue
name_split = name.rsplit('.', 1) name_split = name.rsplit(".", 1)
if name_split[0] != prefix: if name_split[0] != prefix:
continue continue
...@@ -300,12 +322,13 @@ def _init_api_prefix(module_name, prefix): ...@@ -300,12 +322,13 @@ def _init_api_prefix(module_name, prefix):
f = get_global_func(name) f = get_global_func(name)
ff = _get_api(f) ff = _get_api(f)
ff.__name__ = fname ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname) ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
def _init_internal_api(): def _init_internal_api():
for name in list_global_func_names(): for name in list_global_func_names():
if not name.startswith("_") or name.startswith('_deprecate'): if not name.startswith("_") or name.startswith("_deprecate"):
# normal APIs are ignored # normal APIs are ignored
continue continue
target_module = sys.modules["dgl._api_internal"] target_module = sys.modules["dgl._api_internal"]
...@@ -316,7 +339,8 @@ def _init_internal_api(): ...@@ -316,7 +339,8 @@ def _init_internal_api():
f = get_global_func(name) f = get_global_func(name)
ff = _get_api(f) ff = _get_api(f)
ff.__name__ = fname ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname) ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
_set_class_function(Function) _set_class_function(Function)
"""Library information.""" """Library information."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os import os
import pathlib import pathlib
import sys
def find_lib_path(name=None, search_path=None, optional=False): def find_lib_path(name=None, search_path=None, optional=False):
...@@ -30,13 +31,21 @@ def find_lib_path(name=None, search_path=None, optional=False): ...@@ -30,13 +31,21 @@ def find_lib_path(name=None, search_path=None, optional=False):
dll_path = [] dll_path = []
if os.environ.get('DGL_LIBRARY_PATH', None): if os.environ.get("DGL_LIBRARY_PATH", None):
dll_path.append(os.environ['DGL_LIBRARY_PATH']) dll_path.append(os.environ["DGL_LIBRARY_PATH"])
if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None): if sys.platform.startswith("linux") and os.environ.get(
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) "LD_LIBRARY_PATH", None
elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None): ):
dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")]) dll_path.extend(
[p.strip() for p in os.environ["LD_LIBRARY_PATH"].split(":")]
)
elif sys.platform.startswith("darwin") and os.environ.get(
"DYLD_LIBRARY_PATH", None
):
dll_path.extend(
[p.strip() for p in os.environ["DYLD_LIBRARY_PATH"].split(":")]
)
# Pip lib directory # Pip lib directory
dll_path.append(os.path.join(ffi_dir, "..")) dll_path.append(os.path.join(ffi_dir, ".."))
...@@ -54,17 +63,21 @@ def find_lib_path(name=None, search_path=None, optional=False): ...@@ -54,17 +63,21 @@ def find_lib_path(name=None, search_path=None, optional=False):
elif isinstance(search_path, str): elif isinstance(search_path, str):
dll_path.append(search_path) dll_path.append(search_path)
else: else:
raise ValueError("type(search_path)={} is invalid".format(type(search_path))) raise ValueError(
dll_path = [str(x.absolute()) if isinstance(x, pathlib.Path) "type(search_path)={} is invalid".format(type(search_path))
else os.path.abspath(x) for x in dll_path] )
dll_path = [
str(x.absolute()) if isinstance(x, pathlib.Path) else os.path.abspath(x)
for x in dll_path
]
if name is None: if name is None:
if sys.platform.startswith('win32'): if sys.platform.startswith("win32"):
name = ['libdgl.dll', 'dgl.dll'] name = ["libdgl.dll", "dgl.dll"]
elif sys.platform.startswith('darwin'): elif sys.platform.startswith("darwin"):
name = 'libdgl.dylib' name = "libdgl.dylib"
else: else:
name = 'libdgl.so' name = "libdgl.so"
if isinstance(name, str): if isinstance(name, str):
name = [name] name = [name]
...@@ -76,9 +89,11 @@ def find_lib_path(name=None, search_path=None, optional=False): ...@@ -76,9 +89,11 @@ def find_lib_path(name=None, search_path=None, optional=False):
lib_found = [p for p in lib_dll_path if os.path.isfile(p)] lib_found = [p for p in lib_dll_path if os.path.isfile(p)]
if not lib_found: if not lib_found:
message = ('Cannot find the files.\n' + message = (
'List of candidates:\n' + "Cannot find the files.\n"
str('\n'.join(lib_dll_path))) + "List of candidates:\n"
+ str("\n".join(lib_dll_path))
)
if not optional: if not optional:
raise RuntimeError(message) raise RuntimeError(message)
return None return None
......
...@@ -2,13 +2,20 @@ ...@@ -2,13 +2,20 @@
"""Runtime NDArray api""" """Runtime NDArray api"""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import ctypes import ctypes
import sys
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, dgl_shape_index_t
from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
from .runtime_ctypes import (
DGLArray,
DGLArrayHandle,
DGLContext,
DGLDataType,
TypeCode,
dgl_shape_index_t,
)
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
...@@ -17,15 +24,31 @@ try: ...@@ -17,15 +24,31 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
else: else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
"""Construct a DGL context with given device type and id. """Construct a DGL context with given device type and id.
...@@ -63,10 +86,9 @@ def context(dev_type, dev_id=0): ...@@ -63,10 +86,9 @@ def context(dev_type, dev_id=0):
def numpyasarray(np_data): def numpyasarray(np_data):
"""Return a DGLArray representation of a numpy array. """Return a DGLArray representation of a numpy array."""
"""
data = np_data data = np_data
assert data.flags['C_CONTIGUOUS'] assert data.flags["C_CONTIGUOUS"]
arr = DGLArray() arr = DGLArray()
shape = c_array(dgl_shape_index_t, data.shape) shape = c_array(dgl_shape_index_t, data.shape)
arr.data = data.ctypes.data_as(ctypes.c_void_p) arr.data = data.ctypes.data_as(ctypes.c_void_p)
...@@ -102,14 +124,18 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): ...@@ -102,14 +124,18 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ndim = ctypes.c_int(len(shape)) ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle() handle = DGLArrayHandle()
dtype = DGLDataType(dtype) dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAlloc( check_call(
shape, ndim, _LIB.DGLArrayAlloc(
ctypes.c_int(dtype.type_code), shape,
ctypes.c_int(dtype.bits), ndim,
ctypes.c_int(dtype.lanes), ctypes.c_int(dtype.type_code),
ctx.device_type, ctypes.c_int(dtype.bits),
ctx.device_id, ctypes.c_int(dtype.lanes),
ctypes.byref(handle))) ctx.device_type,
ctx.device_id,
ctypes.byref(handle),
)
)
return _make_array(handle, False) return _make_array(handle, False)
...@@ -135,18 +161,23 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"): ...@@ -135,18 +161,23 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
arr : dgl.nd.NDArray arr : dgl.nd.NDArray
The array dgl supported. The array dgl supported.
""" """
name = ctypes.c_char_p(name.encode('utf-8')) name = ctypes.c_char_p(name.encode("utf-8"))
shape = c_array(dgl_shape_index_t, shape) shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape)) ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle() handle = DGLArrayHandle()
dtype = DGLDataType(dtype) dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAllocSharedMem( check_call(
name, shape, ndim, _LIB.DGLArrayAllocSharedMem(
ctypes.c_int(dtype.type_code), name,
ctypes.c_int(dtype.bits), shape,
ctypes.c_int(dtype.lanes), ndim,
is_create, ctypes.c_int(dtype.type_code),
ctypes.byref(handle))) ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
is_create,
ctypes.byref(handle),
)
)
return _make_array(handle, False) return _make_array(handle, False)
...@@ -171,10 +202,14 @@ def from_dlpack(dltensor): ...@@ -171,10 +202,14 @@ def from_dlpack(dltensor):
class NDArrayBase(_NDArrayBase): class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
@property @property
def shape(self): def shape(self):
"""Shape of this array""" """Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim)) return tuple(
self.handle.contents.shape[i]
for i in range(self.handle.contents.ndim)
)
@property @property
def dtype(self): def dtype(self):
...@@ -219,17 +254,19 @@ class NDArrayBase(_NDArrayBase): ...@@ -219,17 +254,19 @@ class NDArrayBase(_NDArrayBase):
def __setitem__(self, in_slice, value): def __setitem__(self, in_slice, value):
"""Set ndarray value""" """Set ndarray value"""
if (not isinstance(in_slice, slice) or if (
in_slice.start is not None not isinstance(in_slice, slice)
or in_slice.stop is not None): or in_slice.start is not None
raise ValueError('Array only support set from numpy array') or in_slice.stop is not None
):
raise ValueError("Array only support set from numpy array")
if isinstance(value, NDArrayBase): if isinstance(value, NDArrayBase):
if value.handle is not self.handle: if value.handle is not self.handle:
value.copyto(self) value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)): elif isinstance(value, (np.ndarray, np.generic)):
self.copyfrom(value) self.copyfrom(value)
else: else:
raise TypeError('type %s not supported' % str(type(value))) raise TypeError("type %s not supported" % str(type(value)))
def copyfrom(self, source_array): def copyfrom(self, source_array):
"""Perform a synchronized copy from the array. """Perform a synchronized copy from the array.
...@@ -252,8 +289,10 @@ class NDArrayBase(_NDArrayBase): ...@@ -252,8 +289,10 @@ class NDArrayBase(_NDArrayBase):
try: try:
source_array = np.asarray(source_array, dtype=self.dtype) source_array = np.asarray(source_array, dtype=self.dtype)
except: except:
raise TypeError('array must be an array_like data,' + raise TypeError(
'type %s is not supported' % str(type(source_array))) "array must be an array_like data,"
+ "type %s is not supported" % str(type(source_array))
)
t = DGLDataType(self.dtype) t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype shape, dtype = self.shape, self.dtype
if t.lanes > 1: if t.lanes > 1:
...@@ -262,12 +301,17 @@ class NDArrayBase(_NDArrayBase): ...@@ -262,12 +301,17 @@ class NDArrayBase(_NDArrayBase):
dtype = str(t) dtype = str(t)
if source_array.shape != shape: if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format( raise ValueError(
source_array.shape, shape)) "array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape
)
)
source_array = np.ascontiguousarray(source_array, dtype=dtype) source_array = np.ascontiguousarray(source_array, dtype=dtype)
assert source_array.flags['C_CONTIGUOUS'] assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p) data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) nbytes = ctypes.c_size_t(
source_array.size * source_array.dtype.itemsize
)
check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes)) check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes))
return self return self
...@@ -293,7 +337,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -293,7 +337,7 @@ class NDArrayBase(_NDArrayBase):
t.lanes = 1 t.lanes = 1
dtype = str(t) dtype = str(t)
np_arr = np.empty(shape, dtype=dtype) np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS'] assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p) data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize) nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes)) check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes))
...@@ -310,20 +354,17 @@ class NDArrayBase(_NDArrayBase): ...@@ -310,20 +354,17 @@ class NDArrayBase(_NDArrayBase):
if isinstance(target, DGLContext): if isinstance(target, DGLContext):
target = empty(self.shape, self.dtype, target) target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase): if isinstance(target, NDArrayBase):
check_call(_LIB.DGLArrayCopyFromTo( check_call(_LIB.DGLArrayCopyFromTo(self.handle, target.handle))
self.handle, target.handle))
else: else:
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def pin_memory_(self): def pin_memory_(self):
"""Pin host memory and map into GPU address space (in-place) """Pin host memory and map into GPU address space (in-place)"""
"""
check_call(_LIB.DGLArrayPinData(self.handle)) check_call(_LIB.DGLArrayPinData(self.handle))
def unpin_memory_(self): def unpin_memory_(self):
"""Unpin host memory pinned by pin_memory_() """Unpin host memory pinned by pin_memory_()"""
"""
check_call(_LIB.DGLArrayUnpinData(self.handle)) check_call(_LIB.DGLArrayUnpinData(self.handle))
def record_stream(self, stream): def record_stream(self, stream):
...@@ -340,6 +381,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -340,6 +381,7 @@ class NDArrayBase(_NDArrayBase):
""" """
check_call(_LIB.DGLArrayRecordStream(self.handle, stream)) check_call(_LIB.DGLArrayRecordStream(self.handle, stream))
def free_extension_handle(handle, type_code): def free_extension_handle(handle, type_code):
"""Free c++ extension type handle """Free c++ extension type handle
...@@ -353,6 +395,7 @@ def free_extension_handle(handle, type_code): ...@@ -353,6 +395,7 @@ def free_extension_handle(handle, type_code):
""" """
check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code))) check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None): def register_extension(cls, fcreate=None):
"""Register a extension class to DGL. """Register a extension class to DGL.
...@@ -398,6 +441,8 @@ def register_extension(cls, fcreate=None): ...@@ -398,6 +441,8 @@ def register_extension(cls, fcreate=None):
return self.handle.value return self.handle.value
""" """
if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN: if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin") raise ValueError(
"Cannot register create when extension tcode is same as buildin"
)
_reg_extension(cls, fcreate) _reg_extension(cls, fcreate)
return cls return cls
...@@ -4,23 +4,27 @@ from __future__ import absolute_import ...@@ -4,23 +4,27 @@ from __future__ import absolute_import
import ctypes import ctypes
import sys import sys
from .. import _api_internal from .. import _api_internal
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
from .object_generic import ObjectGeneric, convert_to_object from .object_generic import ObjectGeneric, convert_to_object
from .base import _LIB, check_call, c_str, py_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" \ # pylint: disable=invalid-name
else ImportError # pylint: disable=invalid-name IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try: try:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _register_object, ObjectBase as _ObjectBase from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else: else:
from ._cy2.core import _register_object, ObjectBase as _ObjectBase from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.object import _register_object, ObjectBase as _ObjectBase from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
def _new_object(cls): def _new_object(cls):
...@@ -36,11 +40,15 @@ class ObjectBase(_ObjectBase): ...@@ -36,11 +40,15 @@ class ObjectBase(_ObjectBase):
Note that the same handle **CANNOT** be shared across multiple ObjectBase instances. Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.
""" """
def __dir__(self): def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)() plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
check_call(_LIB.DGLObjectListAttrNames( check_call(
self.handle, ctypes.byref(size), ctypes.byref(plist))) _LIB.DGLObjectListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)
)
)
names = [] names = []
for i in range(size.value): for i in range(size.value):
names.append(py_str(plist[i])) names.append(py_str(plist[i]))
...@@ -57,7 +65,7 @@ class ObjectBase(_ObjectBase): ...@@ -57,7 +65,7 @@ class ObjectBase(_ObjectBase):
def __reduce__(self): def __reduce__(self):
cls = type(self) cls = type(self)
return (_new_object, (cls, ), self.__getstate__()) return (_new_object, (cls,), self.__getstate__())
def __getstate__(self): def __getstate__(self):
# TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized # TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized
...@@ -100,7 +108,9 @@ def register_object(type_key=None): ...@@ -100,7 +108,9 @@ def register_object(type_key=None):
def register(cls): def register(cls):
"""internal register function""" """internal register function"""
tindex = ctypes.c_int() tindex = ctypes.c_int()
ret = _LIB.DGLObjectTypeKey2Index(c_str(object_name), ctypes.byref(tindex)) ret = _LIB.DGLObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tindex)
)
if ret == 0: if ret == 0:
_register_object(tindex.value, cls) _register_object(tindex.value, cls)
return cls return cls
......
...@@ -2,23 +2,28 @@ ...@@ -2,23 +2,28 @@
# pylint: disable=unused-import # pylint: disable=unused-import
from __future__ import absolute_import from __future__ import absolute_import
from numbers import Number, Integral from numbers import Integral, Number
from .. import _api_internal from .. import _api_internal
from .base import string_types from .base import string_types
# Object base class # Object base class
_CLASS_OBJECT_BASE = None _CLASS_OBJECT_BASE = None
def _set_class_object_base(cls): def _set_class_object_base(cls):
global _CLASS_OBJECT_BASE global _CLASS_OBJECT_BASE
_CLASS_OBJECT_BASE = cls _CLASS_OBJECT_BASE = cls
class ObjectGeneric(object): class ObjectGeneric(object):
"""Base class for all classes that can be converted to object.""" """Base class for all classes that can be converted to object."""
def asobject(self): def asobject(self):
"""Convert value to object""" """Convert value to object"""
raise NotImplementedError() raise NotImplementedError()
def convert_to_object(value): def convert_to_object(value):
"""Convert a python value to corresponding object type. """Convert a python value to corresponding object type.
...@@ -40,9 +45,12 @@ def convert_to_object(value): ...@@ -40,9 +45,12 @@ def convert_to_object(value):
if isinstance(value, dict): if isinstance(value, dict):
vlist = [] vlist = []
for item in value.items(): for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECT_BASE) and if not isinstance(item[0], _CLASS_OBJECT_BASE) and not isinstance(
not isinstance(item[0], string_types)): item[0], string_types
raise ValueError("key of map must already been a container type") ):
raise ValueError(
"key of map must already been a container type"
)
vlist.append(item[0]) vlist.append(item[0])
vlist.append(convert_to_object(item[1])) vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist) return _api_internal._Map(*vlist)
......
...@@ -4,14 +4,18 @@ from __future__ import absolute_import ...@@ -4,14 +4,18 @@ from __future__ import absolute_import
import ctypes import ctypes
import json import json
import numpy as np import numpy as np
from .base import _LIB, check_call
from .. import _api_internal from .. import _api_internal
from .base import _LIB, check_call
dgl_shape_index_t = ctypes.c_int64 dgl_shape_index_t = ctypes.c_int64
class TypeCode(object): class TypeCode(object):
"""Type code used in API calls""" """Type code used in API calls"""
INT = 0 INT = 0
UINT = 1 UINT = 1
FLOAT = 2 FLOAT = 2
...@@ -28,22 +32,25 @@ class TypeCode(object): ...@@ -28,22 +32,25 @@ class TypeCode(object):
NDARRAY_CONTAINER = 13 NDARRAY_CONTAINER = 13
EXT_BEGIN = 15 EXT_BEGIN = 15
class DGLByteArray(ctypes.Structure): class DGLByteArray(ctypes.Structure):
"""Temp data structure for byte array.""" """Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)] _fields_ = [
("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t),
]
class DGLDataType(ctypes.Structure): class DGLDataType(ctypes.Structure):
"""DGL datatype structure""" """DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8), _fields_ = [
("lanes", ctypes.c_uint16)] ("type_code", ctypes.c_uint8),
CODE2STR = { ("bits", ctypes.c_uint8),
0 : 'int', ("lanes", ctypes.c_uint16),
1 : 'uint', ]
2 : 'float', CODE2STR = {0: "int", 1: "uint", 2: "float", 4: "handle"}
4 : 'handle'
}
_cache = {} _cache = {}
def __new__(cls, type_str): def __new__(cls, type_str):
...@@ -90,50 +97,54 @@ class DGLDataType(ctypes.Structure): ...@@ -90,50 +97,54 @@ class DGLDataType(ctypes.Structure):
return x return x
def __eq__(self, other): def __eq__(self, other):
return (self.bits == other.bits and return (
self.type_code == other.type_code and self.bits == other.bits
self.lanes == other.lanes) and self.type_code == other.type_code
and self.lanes == other.lanes
)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
RPC_SESS_MASK = 128 RPC_SESS_MASK = 128
class DGLContext(ctypes.Structure): class DGLContext(ctypes.Structure):
"""DGL context strucure.""" """DGL context strucure."""
_fields_ = [("device_type", ctypes.c_int),
("device_id", ctypes.c_int)] _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
MASK2STR = { MASK2STR = {
1 : 'cpu', 1: "cpu",
2 : 'gpu', 2: "gpu",
4 : 'opencl', 4: "opencl",
5 : 'aocl', 5: "aocl",
6 : 'sdaccel', 6: "sdaccel",
7 : 'vulkan', 7: "vulkan",
8 : 'metal', 8: "metal",
9 : 'vpi', 9: "vpi",
10: 'rocm', 10: "rocm",
11: 'opengl', 11: "opengl",
12: 'ext_dev', 12: "ext_dev",
} }
STR2MASK = { STR2MASK = {
'llvm': 1, "llvm": 1,
'stackvm': 1, "stackvm": 1,
'cpu': 1, "cpu": 1,
'gpu': 2, "gpu": 2,
'cuda': 2, "cuda": 2,
'nvptx': 2, "nvptx": 2,
'cl': 4, "cl": 4,
'opencl': 4, "opencl": 4,
'aocl' : 5, "aocl": 5,
'aocl_sw_emu' : 5, "aocl_sw_emu": 5,
'sdaccel': 6, "sdaccel": 6,
'vulkan': 7, "vulkan": 7,
'metal': 8, "metal": 8,
'vpi': 9, "vpi": 9,
'rocm': 10, "rocm": 10,
'opengl': 11, "opengl": 11,
'ext_dev': 12, "ext_dev": 12,
} }
_cache = {} _cache = {}
...@@ -155,26 +166,25 @@ class DGLContext(ctypes.Structure): ...@@ -155,26 +166,25 @@ class DGLContext(ctypes.Structure):
@property @property
def exist(self): def exist(self):
"""Whether this device exist.""" """Whether this device exist."""
return _api_internal._GetDeviceAttr( return (
self.device_type, self.device_id, 0) != 0 _api_internal._GetDeviceAttr(self.device_type, self.device_id, 0)
!= 0
)
@property @property
def max_threads_per_block(self): def max_threads_per_block(self):
"""Maximum number of threads on each block.""" """Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 1)
self.device_type, self.device_id, 1)
@property @property
def warp_size(self): def warp_size(self):
"""Number of threads that executes in concurrent.""" """Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 2)
self.device_type, self.device_id, 2)
@property @property
def max_shared_memory_per_block(self): def max_shared_memory_per_block(self):
"""Total amount of shared memory per block in bytes.""" """Total amount of shared memory per block in bytes."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 3)
self.device_type, self.device_id, 3)
@property @property
def compute_version(self): def compute_version(self):
...@@ -187,26 +197,22 @@ class DGLContext(ctypes.Structure): ...@@ -187,26 +197,22 @@ class DGLContext(ctypes.Structure):
version : str version : str
The version string in `major.minor` format. The version string in `major.minor` format.
""" """
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 4)
self.device_type, self.device_id, 4)
@property @property
def device_name(self): def device_name(self):
"""Return the string name of device.""" """Return the string name of device."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 5)
self.device_type, self.device_id, 5)
@property @property
def max_clock_rate(self): def max_clock_rate(self):
"""Return the max clock frequency of device.""" """Return the max clock frequency of device."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 6)
self.device_type, self.device_id, 6)
@property @property
def multi_processor_count(self): def multi_processor_count(self):
"""Return the number of compute units of device.""" """Return the number of compute units of device."""
return _api_internal._GetDeviceAttr( return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 7)
self.device_type, self.device_id, 7)
@property @property
def max_thread_dimensions(self): def max_thread_dimensions(self):
...@@ -217,17 +223,20 @@ class DGLContext(ctypes.Structure): ...@@ -217,17 +223,20 @@ class DGLContext(ctypes.Structure):
dims: List of int dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
""" """
return json.loads(_api_internal._GetDeviceAttr( return json.loads(
self.device_type, self.device_id, 8)) _api_internal._GetDeviceAttr(self.device_type, self.device_id, 8)
)
def sync(self): def sync(self):
"""Synchronize until jobs finished at the context.""" """Synchronize until jobs finished at the context."""
check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None)) check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None))
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, DGLContext) and return (
self.device_id == other.device_id and isinstance(other, DGLContext)
self.device_type == other.device_type) and self.device_id == other.device_id
and self.device_type == other.device_type
)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
...@@ -237,9 +246,14 @@ class DGLContext(ctypes.Structure): ...@@ -237,9 +246,14 @@ class DGLContext(ctypes.Structure):
tbl_id = self.device_type / RPC_SESS_MASK - 1 tbl_id = self.device_type / RPC_SESS_MASK - 1
dev_type = self.device_type % RPC_SESS_MASK dev_type = self.device_type % RPC_SESS_MASK
return "remote[%d]:%s(%d)" % ( return "remote[%d]:%s(%d)" % (
tbl_id, DGLContext.MASK2STR[dev_type], self.device_id) tbl_id,
DGLContext.MASK2STR[dev_type],
self.device_id,
)
return "%s(%d)" % ( return "%s(%d)" % (
DGLContext.MASK2STR[self.device_type], self.device_id) DGLContext.MASK2STR[self.device_type],
self.device_id,
)
def __hash__(self): def __hash__(self):
return hash((self.device_type, self.device_id)) return hash((self.device_type, self.device_id))
...@@ -247,13 +261,17 @@ class DGLContext(ctypes.Structure): ...@@ -247,13 +261,17 @@ class DGLContext(ctypes.Structure):
class DGLArray(ctypes.Structure): class DGLArray(ctypes.Structure):
"""DGLValue in C API""" """DGLValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("ctx", DGLContext), _fields_ = [
("ndim", ctypes.c_int), ("data", ctypes.c_void_p),
("dtype", DGLDataType), ("ctx", DGLContext),
("shape", ctypes.POINTER(dgl_shape_index_t)), ("ndim", ctypes.c_int),
("strides", ctypes.POINTER(dgl_shape_index_t)), ("dtype", DGLDataType),
("byte_offset", ctypes.c_uint64)] ("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64),
]
DGLArrayHandle = ctypes.POINTER(DGLArray) DGLArrayHandle = ctypes.POINTER(DGLArray)
......
...@@ -5,13 +5,13 @@ For applications, please use PyTorch's stream management, of which DGL is aware. ...@@ -5,13 +5,13 @@ For applications, please use PyTorch's stream management, of which DGL is aware.
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
from .base import _LIB, check_call, _FFI_MODE
from .base import _FFI_MODE, _LIB, check_call
from .runtime_ctypes import DGLStreamHandle from .runtime_ctypes import DGLStreamHandle
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
def to_dgl_stream_handle(cuda_stream): def to_dgl_stream_handle(cuda_stream):
""" Convert torch.cuda.Stream to DGL stream handle """Convert torch.cuda.Stream to DGL stream handle
Parameters Parameters
---------- ----------
...@@ -24,6 +24,7 @@ def to_dgl_stream_handle(cuda_stream): ...@@ -24,6 +24,7 @@ def to_dgl_stream_handle(cuda_stream):
""" """
return ctypes.c_void_p(cuda_stream.cuda_stream) return ctypes.c_void_p(cuda_stream.cuda_stream)
def _dgl_get_stream(ctx): def _dgl_get_stream(ctx):
"""Get the current CUDA stream of the given DGL context. """Get the current CUDA stream of the given DGL context.
...@@ -37,6 +38,9 @@ def _dgl_get_stream(ctx): ...@@ -37,6 +38,9 @@ def _dgl_get_stream(ctx):
DGLStreamHandle of the current CUDA stream. DGLStreamHandle of the current CUDA stream.
""" """
current_cuda_stream = DGLStreamHandle() current_cuda_stream = DGLStreamHandle()
check_call(_LIB.DGLGetStream( check_call(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream))) _LIB.DGLGetStream(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)
)
)
return current_cuda_stream return current_cuda_stream
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os
import json
import importlib import importlib
import json
import logging import logging
import os
import sys
from . import backend from . import backend
from .set_default_backend import set_default_backend from .set_default_backend import set_default_backend
...@@ -13,13 +13,18 @@ _enabled_apis = set() ...@@ -13,13 +13,18 @@ _enabled_apis = set()
logger = logging.getLogger("dgl-core") logger = logging.getLogger("dgl-core")
def _gen_missing_api(api, mod_name): def _gen_missing_api(api, mod_name):
def _missing_api(*args, **kwargs): def _missing_api(*args, **kwargs):
raise ImportError('API "%s" is not supported by backend "%s".' raise ImportError(
' You can switch to other backends by setting' 'API "%s" is not supported by backend "%s".'
' the DGLBACKEND environment.' % (api, mod_name)) " You can switch to other backends by setting"
" the DGLBACKEND environment." % (api, mod_name)
)
return _missing_api return _missing_api
def load_backend(mod_name): def load_backend(mod_name):
# Load backend does four things: # Load backend does four things:
# (1) Import backend framework (PyTorch, MXNet, Tensorflow, etc.) # (1) Import backend framework (PyTorch, MXNet, Tensorflow, etc.)
...@@ -28,40 +33,46 @@ def load_backend(mod_name): ...@@ -28,40 +33,46 @@ def load_backend(mod_name):
# (3) Sets up the tensoradapter library path. # (3) Sets up the tensoradapter library path.
# (4) Import the Python wrappers of the backend framework. DGL does this last because # (4) Import the Python wrappers of the backend framework. DGL does this last because
# it already depends on both the backend framework and the DGL C library. # it already depends on both the backend framework and the DGL C library.
if mod_name == 'pytorch': if mod_name == "pytorch":
import torch import torch
mod = torch mod = torch
elif mod_name == 'mxnet': elif mod_name == "mxnet":
import mxnet import mxnet
mod = mxnet mod = mxnet
elif mod_name == 'tensorflow': elif mod_name == "tensorflow":
import tensorflow import tensorflow
mod = tensorflow mod = tensorflow
else: else:
raise NotImplementedError('Unsupported backend: %s' % mod_name) raise NotImplementedError("Unsupported backend: %s" % mod_name)
from .._ffi.base import load_tensor_adapter # imports DGL C library
from .._ffi.base import load_tensor_adapter # imports DGL C library
version = mod.__version__ version = mod.__version__
load_tensor_adapter(mod_name, version) load_tensor_adapter(mod_name, version)
logger.debug('Using backend: %s' % mod_name) logger.debug("Using backend: %s" % mod_name)
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__] thismod = sys.modules[__name__]
for api in backend.__dict__.keys(): for api in backend.__dict__.keys():
if api.startswith('__'): if api.startswith("__"):
# ignore python builtin attributes # ignore python builtin attributes
continue continue
if api == 'data_type_dict': if api == "data_type_dict":
# load data type # load data type
if api not in mod.__dict__: if api not in mod.__dict__:
raise ImportError('API "data_type_dict" is required but missing for' raise ImportError(
' backend "%s".' % (mod_name)) 'API "data_type_dict" is required but missing for'
' backend "%s".' % (mod_name)
)
data_type_dict = mod.__dict__[api]() data_type_dict = mod.__dict__[api]()
for name, dtype in data_type_dict.items(): for name, dtype in data_type_dict.items():
setattr(thismod, name, dtype) setattr(thismod, name, dtype)
# override data type dict function # override data type dict function
setattr(thismod, 'data_type_dict', data_type_dict) setattr(thismod, "data_type_dict", data_type_dict)
# for data types with aliases, treat the first listed type as # for data types with aliases, treat the first listed type as
# the true one # the true one
...@@ -69,11 +80,9 @@ def load_backend(mod_name): ...@@ -69,11 +80,9 @@ def load_backend(mod_name):
for k, v in data_type_dict.items(): for k, v in data_type_dict.items():
if not v in rev_data_type_dict.keys(): if not v in rev_data_type_dict.keys():
rev_data_type_dict[v] = k rev_data_type_dict[v] = k
setattr(thismod, setattr(thismod, "reverse_data_type_dict", rev_data_type_dict)
'reverse_data_type_dict',
rev_data_type_dict)
# log backend name # log backend name
setattr(thismod, 'backend_name', mod_name) setattr(thismod, "backend_name", mod_name)
else: else:
# load functions # load functions
if api in mod.__dict__: if api in mod.__dict__:
...@@ -82,28 +91,32 @@ def load_backend(mod_name): ...@@ -82,28 +91,32 @@ def load_backend(mod_name):
else: else:
setattr(thismod, api, _gen_missing_api(api, mod_name)) setattr(thismod, api, _gen_missing_api(api, mod_name))
def get_preferred_backend(): def get_preferred_backend():
default_dir = None default_dir = None
if "DGLDEFAULTDIR" in os.environ: if "DGLDEFAULTDIR" in os.environ:
default_dir = os.getenv('DGLDEFAULTDIR') default_dir = os.getenv("DGLDEFAULTDIR")
else: else:
default_dir = os.path.join(os.path.expanduser('~'), '.dgl') default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
config_path = os.path.join(default_dir, 'config.json') config_path = os.path.join(default_dir, "config.json")
backend_name = None backend_name = None
if "DGLBACKEND" in os.environ: if "DGLBACKEND" in os.environ:
backend_name = os.getenv('DGLBACKEND') backend_name = os.getenv("DGLBACKEND")
elif os.path.exists(config_path): elif os.path.exists(config_path):
with open(config_path, "r") as config_file: with open(config_path, "r") as config_file:
config_dict = json.load(config_file) config_dict = json.load(config_file)
backend_name = config_dict.get('backend', '').lower() backend_name = config_dict.get("backend", "").lower()
if (backend_name in ['tensorflow', 'mxnet', 'pytorch']): if backend_name in ["tensorflow", "mxnet", "pytorch"]:
return backend_name return backend_name
else: else:
print("DGL backend not selected or invalid. " print(
"Assuming PyTorch for now.", file=sys.stderr) "DGL backend not selected or invalid. "
set_default_backend(default_dir, 'pytorch') "Assuming PyTorch for now.",
return 'pytorch' file=sys.stderr,
)
set_default_backend(default_dir, "pytorch")
return "pytorch"
load_backend(get_preferred_backend()) load_backend(get_preferred_backend())
...@@ -124,8 +137,10 @@ def is_enabled(api): ...@@ -124,8 +137,10 @@ def is_enabled(api):
""" """
return api in _enabled_apis return api in _enabled_apis
def to_dgl_nd(data): def to_dgl_nd(data):
return zerocopy_to_dgl_ndarray(data) return zerocopy_to_dgl_ndarray(data)
def from_dgl_nd(data): def from_dgl_nd(data):
return zerocopy_from_dgl_ndarray(data) return zerocopy_from_dgl_ndarray(data)
This diff is collapsed.
from .tensor import *
from .sparse import * from .sparse import *
from .tensor import *
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
from mxnet import nd from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
from ...heterograph_index import create_unitgraph_from_csr
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add', from ...base import ALL, dgl_warning, is_all
'csrmm', 'csrsum', 'csrmask'] from ...heterograph_index import create_unitgraph_from_csr
from ...sparse import (
_bwd_segment_cmp,
_csrmask,
_csrmm,
_csrsum,
_gsddmm,
_gspmm,
_scatter_add,
_segment_reduce,
)
from .tensor import (
asnumpy,
context,
copy_to,
to_backend_ctx,
zerocopy_from_numpy,
)
__all__ = [
"gspmm",
"gsddmm",
"edge_softmax",
"segment_reduce",
"scatter_add",
"csrmm",
"csrsum",
"csrmask",
]
def _scatter_nd(index, src, n_rows): def _scatter_nd(index, src, n_rows):
...@@ -26,7 +49,10 @@ def _scatter_nd(index, src, n_rows): ...@@ -26,7 +49,10 @@ def _scatter_nd(index, src, n_rows):
di = shp[i] di = shp[i]
offset_i = np.arange(di, dtype=index.dtype) offset_i = np.arange(di, dtype=index.dtype)
offsets.append( offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i))) (stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di stride *= di
if ndim > 1: if ndim > 1:
new_idx = index * stride + sum(offsets) new_idx = index * stride + sum(offsets)
...@@ -52,7 +78,10 @@ def _gather_nd(index, src): ...@@ -52,7 +78,10 @@ def _gather_nd(index, src):
di = shp[i] di = shp[i]
offset_i = nd.arange(di, dtype=index.dtype) offset_i = nd.arange(di, dtype=index.dtype)
offsets.append( offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i))) (stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di stride *= di
if ndim > 1: if ndim > 1:
new_idx = index * stride + copy_to(sum(offsets), ctx) new_idx = index * stride + copy_to(sum(offsets), ctx)
...@@ -107,11 +136,11 @@ def _need_reduce_last_dim(ufeat, efeat): ...@@ -107,11 +136,11 @@ def _need_reduce_last_dim(ufeat, efeat):
def _muldiv(op, x): def _muldiv(op, x):
return 1. / x if op == 'div' else x return 1.0 / x if op == "div" else x
def _addsub(op, x): def _addsub(op, x):
return -x if op == 'sub' else x return -x if op == "sub" else x
def _expand(x, shape): def _expand(x, shape):
...@@ -134,45 +163,48 @@ class GSpMM(mx.autograd.Function): ...@@ -134,45 +163,48 @@ class GSpMM(mx.autograd.Function):
ctx = context(dZ) ctx = context(dZ)
X, Y, argX, argY = self.saved_tensors X, Y, argX, argY = self.saved_tensors
gidx, op, reduce_op = self.gidx, self.op, self.reduce_op gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
if op != 'copy_rhs': if op != "copy_rhs":
g_rev = gidx.reverse() g_rev = gidx.reverse()
if reduce_op == 'sum': if reduce_op == "sum":
if op in ['mul', 'div']: if op in ["mul", "div"]:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0] dX = _gspmm(g_rev, "mul", "sum", dZ, _muldiv(op, Y))[0]
elif op in ['add', 'sub']: elif op in ["add", "sub"]:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0] dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, Y)[0]
elif op == 'copy_lhs': elif op == "copy_lhs":
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0] dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, None)[0]
else: else:
if op in ['mul', 'div']: if op in ["mul", "div"]:
dX = _scatter_nd( dX = _scatter_nd(
argX, argX,
_muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:]))) * dZ, _muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))
X.shape[0]) * dZ,
elif op in ['add', 'sub', 'copy_lhs']: X.shape[0],
)
elif op in ["add", "sub", "copy_lhs"]:
dX = _scatter_nd(argX, dZ, X.shape[0]) dX = _scatter_nd(argX, dZ, X.shape[0])
dX = _reduce_grad(dX, X.shape) dX = _reduce_grad(dX, X.shape)
else: else:
dX = nd.zeros_like(X) dX = nd.zeros_like(X)
if op != 'copy_lhs': if op != "copy_lhs":
if reduce_op == 'sum': if reduce_op == "sum":
if op == 'mul' and _need_reduce_last_dim(X, Y): if op == "mul" and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, 'dot', X, dZ) dY = _gsddmm(gidx, "dot", X, dZ)
elif op in ['mul', 'div']: elif op in ["mul", "div"]:
dY = _gsddmm(gidx, 'mul', X, dZ) dY = _gsddmm(gidx, "mul", X, dZ)
if op == 'div': if op == "div":
dY = -dY / (Y ** 2) dY = -dY / (Y**2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ["add", "sub", "copy_rhs"]:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ)) dY = _gsddmm(gidx, "copy_rhs", X, _addsub(op, dZ))
else: else:
if op in ['mul', 'div']: if op in ["mul", "div"]:
dY = _scatter_nd( dY = _scatter_nd(
argY, argY,
_gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ, _gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
Y.shape[0]) Y.shape[0],
if op == 'div': )
dY = -dY / (Y ** 2) if op == "div":
elif op in ['add', 'sub', 'copy_rhs']: dY = -dY / (Y**2)
elif op in ["add", "sub", "copy_rhs"]:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0]) dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else: else:
...@@ -207,7 +239,9 @@ class GSDDMM(mx.autograd.Function): ...@@ -207,7 +239,9 @@ class GSDDMM(mx.autograd.Function):
self.rhs_target = rhs_target self.rhs_target = rhs_target
def forward(self, X, Y): def forward(self, X, Y):
out = _gsddmm(self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target) out = _gsddmm(
self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target
)
self.save_for_backward(X, Y) self.save_for_backward(X, Y)
return out return out
...@@ -216,47 +250,55 @@ class GSDDMM(mx.autograd.Function): ...@@ -216,47 +250,55 @@ class GSDDMM(mx.autograd.Function):
X, Y = self.saved_tensors X, Y = self.saved_tensors
gidx, op = self.gidx, self.op gidx, op = self.gidx, self.op
lhs_target, rhs_target = self.lhs_target, self.rhs_target lhs_target, rhs_target = self.lhs_target, self.rhs_target
if op != 'copy_rhs': if op != "copy_rhs":
if lhs_target in ['u', 'v']: if lhs_target in ["u", "v"]:
_gidx = gidx if self.lhs_target == 'v' else gidx.reverse() _gidx = gidx if self.lhs_target == "v" else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']: if op in ["add", "sub", "copy_lhs"]:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0]
else: # mul, div, dot else: # mul, div, dot
if rhs_target == lhs_target: if rhs_target == lhs_target:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y) dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[
elif self.rhs_target == 'e': 0
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0] ] * _muldiv(op, Y)
elif self.rhs_target == "e":
dX = _gspmm(
_gidx, "copy_rhs", "sum", None, dZ * _muldiv(op, Y)
)[0]
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0] dX = _gspmm(_gidx, "mul", "sum", _muldiv(op, Y), dZ)[0]
else: # lhs_target == 'e' else: # lhs_target == 'e'
if op in ['add', 'sub', 'copy_lhs']: if op in ["add", "sub", "copy_lhs"]:
dX = dZ dX = dZ
else: # mul, div, dot else: # mul, div, dot
dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target) dX = _gsddmm(
gidx, "mul", dZ, _muldiv(op, Y), "e", rhs_target
)
dX = _reduce_grad(dX, X.shape) dX = _reduce_grad(dX, X.shape)
else: else:
dX = nd.zeros_like(X) dX = nd.zeros_like(X)
if op != 'copy_lhs': if op != "copy_lhs":
if self.rhs_target in ['u', 'v']: if self.rhs_target in ["u", "v"]:
_gidx = gidx if rhs_target == 'v' else gidx.reverse() _gidx = gidx if rhs_target == "v" else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']: if op in ["add", "sub", "copy_rhs"]:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0] dY = _gspmm(
_gidx, "copy_rhs", "sum", None, _addsub(op, dZ)
)[0]
else: # mul, div, dot else: # mul, div, dot
if lhs_target == rhs_target: if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0] * X
elif self.lhs_target == 'e': elif self.lhs_target == "e":
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0] dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ * X)[0]
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0] dY = _gspmm(_gidx, "mul", "sum", X, dZ)[0]
if op == 'div': if op == "div":
dY = -dY / (Y ** 2) dY = -dY / (Y**2)
else: else:
if op in ['add', 'sub', 'copy_rhs']: if op in ["add", "sub", "copy_rhs"]:
dY = _addsub(op, dZ) dY = _addsub(op, dZ)
else: # mul, div, dot else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target) dY = _gsddmm(gidx, "mul", dZ, X, "e", lhs_target)
if op == 'div': if op == "div":
dY = -dY / (Y ** 2) dY = -dY / (Y**2)
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else: else:
dY = nd.zeros_like(Y) dY = nd.zeros_like(Y)
...@@ -264,7 +306,7 @@ class GSDDMM(mx.autograd.Function): ...@@ -264,7 +306,7 @@ class GSDDMM(mx.autograd.Function):
return dX, dY return dX, dY
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
func = GSDDMM(gidx, op, lhs_target, rhs_target) func = GSDDMM(gidx, op, lhs_target, rhs_target)
ctx = to_backend_ctx(gidx.ctx) ctx = to_backend_ctx(gidx.ctx)
if lhs_data is None: if lhs_data is None:
...@@ -279,7 +321,7 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -279,7 +321,7 @@ class EdgeSoftmax(mx.autograd.Function):
super(EdgeSoftmax, self).__init__() super(EdgeSoftmax, self).__init__()
if not is_all(eids): if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src': if norm_by == "src":
gidx = gidx.reverse() gidx = gidx.reverse()
self.gidx = gidx self.gidx = gidx
...@@ -298,10 +340,10 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -298,10 +340,10 @@ class EdgeSoftmax(mx.autograd.Function):
return out.data return out.data
""" """
gidx = self.gidx gidx = self.gidx
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0] score_max = _gspmm(gidx, "copy_rhs", "max", None, score)[0]
score = mx.nd.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v')) score = mx.nd.exp(_gsddmm(gidx, "sub", score, score_max, "e", "v"))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0] score_sum = _gspmm(gidx, "copy_rhs", "sum", None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v') out = _gsddmm(gidx, "div", score, score_sum, "e", "v")
self.save_for_backward(out) self.save_for_backward(out)
return out return out
...@@ -319,16 +361,16 @@ class EdgeSoftmax(mx.autograd.Function): ...@@ -319,16 +361,16 @@ class EdgeSoftmax(mx.autograd.Function):
sds_sum = sds.dst_sum() # type dgl.NData sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions grad_score = sds - sds * sds_sum # multiple expressions
""" """
out, = self.saved_tensors (out,) = self.saved_tensors
gidx = self.gidx gidx = self.gidx
sds = out * grad_out sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds) accum = gspmm(gidx, "copy_rhs", "sum", None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v') grad_score = sds - gsddmm(gidx, "mul", out, accum, "e", "v")
self.save_tensors = None self.save_tensors = None
return grad_score return grad_score
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'): def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
softmax_op = EdgeSoftmax(gidx, eids, norm_by) softmax_op = EdgeSoftmax(gidx, eids, norm_by)
return softmax_op(logits) return softmax_op(logits)
...@@ -345,10 +387,10 @@ class SegmentReduce(mx.autograd.Function): ...@@ -345,10 +387,10 @@ class SegmentReduce(mx.autograd.Function):
return y return y
def backward(self, dy): def backward(self, dy):
arg, = self.saved_tensors (arg,) = self.saved_tensors
offsets = self.offsets offsets = self.offsets
m = offsets[-1].asscalar() m = offsets[-1].asscalar()
if self.op == 'sum': if self.op == "sum":
offsets_np = asnumpy(offsets[1:]) offsets_np = asnumpy(offsets[1:])
indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype) indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np)) np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
...@@ -374,7 +416,7 @@ class ScatterAdd(mx.autograd.Function): ...@@ -374,7 +416,7 @@ class ScatterAdd(mx.autograd.Function):
def forward(self, x): def forward(self, x):
y = _scatter_add(x, self.idx, self.m) y = _scatter_add(x, self.idx, self.m)
return y return y
def backward(self, dy): def backward(self, dy):
return dy[self.idx] return dy[self.idx]
...@@ -392,36 +434,66 @@ class CSRMM(mx.autograd.Function): ...@@ -392,36 +434,66 @@ class CSRMM(mx.autograd.Function):
self.num_vtypes = num_vtypes self.num_vtypes = num_vtypes
def forward(self, A_weights, B_weights): def forward(self, A_weights, B_weights):
gidxC, C_weights = _csrmm(self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes) gidxC, C_weights = _csrmm(
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, False, 'csr') self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes
)
(
nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC. # as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC self.backward_cache = gidxC
self.save_for_backward(A_weights, B_weights) self.save_for_backward(A_weights, B_weights)
nrows = nd.array([nrows], dtype='int64') nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype='int64') ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): def backward(
self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful. # Only the last argument is meaningful.
gidxC = self.backward_cache gidxC = self.backward_cache
A_weights, B_weights = self.saved_tensors A_weights, B_weights = self.saved_tensors
dgidxA, dA_weights = _csrmm( dgidxA, dA_weights = _csrmm(
gidxC, dC_weights, self.gidxB.reverse(), B_weights, self.gidxA.number_of_ntypes()) gidxC,
dC_weights,
self.gidxB.reverse(),
B_weights,
self.gidxA.number_of_ntypes(),
)
dgidxB, dB_weights = _csrmm( dgidxB, dB_weights = _csrmm(
self.gidxA.reverse(), A_weights, gidxC, dC_weights, self.gidxB.number_of_ntypes()) self.gidxA.reverse(),
A_weights,
gidxC,
dC_weights,
self.gidxB.number_of_ntypes(),
)
dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA) dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA)
dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB) dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB)
return dA_weights, dB_weights return dA_weights, dB_weights
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes): def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
op = CSRMM(gidxA, gidxB, num_vtypes) op = CSRMM(gidxA, gidxB, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(A_weights, B_weights) nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(
A_weights, B_weights
)
gidxC = create_unitgraph_from_csr( gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids, num_vtypes,
["coo", "csr", "csc"]) nrows.asscalar(),
ncols.asscalar(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights return gidxC, C_weights
class CSRSum(mx.autograd.Function): class CSRSum(mx.autograd.Function):
def __init__(self, gidxs): def __init__(self, gidxs):
super().__init__() super().__init__()
...@@ -429,29 +501,44 @@ class CSRSum(mx.autograd.Function): ...@@ -429,29 +501,44 @@ class CSRSum(mx.autograd.Function):
def forward(self, *weights): def forward(self, *weights):
gidxC, C_weights = _csrsum(self.gidxs, weights) gidxC, C_weights = _csrsum(self.gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors( (
0, False, 'csr') nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same # Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC. # as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC self.backward_cache = gidxC
nrows = nd.array([nrows], dtype='int64') nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype='int64') ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights): def backward(
self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful. # Only the last argument is meaningful.
gidxC = self.backward_cache gidxC = self.backward_cache
return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs) return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs)
def csrsum(gidxs, weights): def csrsum(gidxs, weights):
op = CSRSum(gidxs) op = CSRSum(gidxs)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights) nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights)
num_vtypes = gidxs[0].number_of_ntypes() num_vtypes = gidxs[0].number_of_ntypes()
gidxC = create_unitgraph_from_csr( gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids, num_vtypes,
["coo", "csr", "csc"]) nrows.asscalar(),
ncols.asscalar(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights return gidxC, C_weights
class CSRMask(mx.autograd.Function): class CSRMask(mx.autograd.Function):
def __init__(self, gidxA, gidxB): def __init__(self, gidxA, gidxB):
super().__init__() super().__init__()
...@@ -464,6 +551,7 @@ class CSRMask(mx.autograd.Function): ...@@ -464,6 +551,7 @@ class CSRMask(mx.autograd.Function):
def backward(self, dB_weights): def backward(self, dB_weights):
return _csrmask(self.gidxB, dB_weights, self.gidxA) return _csrmask(self.gidxB, dB_weights, self.gidxA)
def csrmask(gidxA, A_weights, gidxB): def csrmask(gidxA, A_weights, gidxB):
op = CSRMask(gidxA, gidxB) op = CSRMask(gidxA, gidxB)
return op(A_weights) return op(A_weights)
"""Sparse optimizer is not supported for mxnet""" """Sparse optimizer is not supported for mxnet"""
\ No newline at end of file
This diff is collapsed.
from .tensor import *
from .sparse import * from .sparse import *
from .tensor import *
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment