"docs/vscode:/vscode.git/clone" did not exist on "90cfb10dc49187842247d3bffb25a06af0b1e826"
Unverified Commit a208e886 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Black auto fix. (#4680)



* [Misc] Black auto fix.

* fix pylint disable
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 29434e65
......@@ -5,33 +5,41 @@ from __future__ import absolute_import
import ctypes
import traceback
from numbers import Number, Integral
from numbers import Integral, Number
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLDataType, DGLByteArray, DGLContext
from ..base import _LIB, c_str, check_call, string_types
from ..object_generic import ObjectGeneric, convert_to_object
from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType
from . import ndarray as _nd
from . import object as _object
from .ndarray import NDArrayBase, _make_array
from .types import DGLValue, TypeCode
from .types import DGLPackedCFunc, DGLCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .object import ObjectBase
from . import object as _object
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLCFuncFinalizer,
DGLPackedCFunc,
DGLValue,
TypeCode,
_wrap_arg_func,
)
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
DGLRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
DGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ))
def convert_to_dgl_func(pyfunc):
"""Convert a python function to DGL function
......@@ -46,10 +54,15 @@ def convert_to_dgl_func(pyfunc):
The converted dgl function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args))
"""ctypes function"""
num_args = (
num_args.value if isinstance(num_args, ctypes.c_int) else num_args
)
pyargs = (
C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)
)
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
......@@ -60,12 +73,16 @@ def convert_to_dgl_func(pyfunc):
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
raise ValueError(
"PackedFunction can only support one return value"
)
temp_args = []
values, tcodes, _ = _make_dgl_args((rv,), temp_args)
if not isinstance(ret, DGLRetValueHandle):
ret = DGLRetValueHandle(ret)
check_call(_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)))
check_call(
_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))
)
_ = temp_args
_ = rv
return 0
......@@ -76,8 +93,11 @@ def convert_to_dgl_func(pyfunc):
# DGL_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.DGLFuncCreateFromCFunc(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)))
check_call(
_LIB.DGLFuncCreateFromCFunc(
f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)
)
)
return _CLASS_FUNCTION(handle, False)
......@@ -104,8 +124,11 @@ def _make_dgl_args(args, temp_args):
temp_args.append(arg)
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_CONTAINER
if not arg.is_view else TypeCode.ARRAY_HANDLE)
type_codes[i] = (
TypeCode.NDARRAY_CONTAINER
if not arg.is_view
else TypeCode.ARRAY_HANDLE
)
elif isinstance(arg, _nd._DGL_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._dgl_handle)
type_codes[i] = arg.__class__._dgl_tcode
......@@ -125,7 +148,8 @@ def _make_dgl_args(args, temp_args):
arr = DGLByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
ctypes.POINTER(ctypes.c_byte),
)
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
......@@ -134,7 +158,7 @@ def _make_dgl_args(args, temp_args):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
# NOTE(minjie): module is not used in DGL
#elif isinstance(arg, _CLASS_MODULE):
# elif isinstance(arg, _CLASS_MODULE):
# values[i].v_handle = arg.handle
# type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase):
......@@ -155,6 +179,7 @@ def _make_dgl_args(args, temp_args):
class FunctionBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global):
......@@ -185,9 +210,16 @@ class FunctionBase(object):
values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
check_call(
_LIB.DGLFuncCall(
self.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
......@@ -199,9 +231,16 @@ def __init_handle_by_constructor__(fconstructor, args):
values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = DGLValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.DGLFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
check_call(
_LIB.DGLFuncCall(
fconstructor.handle,
values,
tcodes,
ctypes.c_int(num_args),
ctypes.byref(ret_val),
ctypes.byref(ret_tcode),
)
)
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.OBJECT_HANDLE
......@@ -216,6 +255,7 @@ def _return_module(x):
handle = ModuleHandle(handle)
return _CLASS_MODULE(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
......@@ -228,22 +268,32 @@ def _handle_return_func(x):
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
_handle_return_func, TypeCode.FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
_return_module, TypeCode.MODULE_HANDLE
)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(
x.v_handle, True
)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(
x.v_handle, False
)
_CLASS_MODULE = None
_CLASS_FUNCTION = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
......@@ -3,18 +3,23 @@
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import DGLArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
from ..base import _LIB, c_str, check_call
from ..runtime_ctypes import DGLArrayHandle
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
_return_handle,
_wrap_arg_func,
)
DGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
_c_str_dltensor = c_str("dltensor")
_c_str_used_dltensor = c_str("used_dltensor")
# used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'):
if hasattr(ctypes, "pythonapi"):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
......@@ -31,9 +36,13 @@ def _from_dlpack(dltensor):
handle = DGLArrayHandle()
check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, DGLPyCapsuleDestructor(0))
ctypes.pythonapi.PyCapsule_SetDestructor(
dltensor, DGLPyCapsuleDestructor(0)
)
return _make_array(handle, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
raise ValueError(
"Expect a dltensor field, PyCapsule can only be consumed once"
)
def _dlpack_deleter(pycapsule):
......@@ -45,13 +54,17 @@ def _dlpack_deleter(pycapsule):
# work out always.
ptr = ctypes.cast(ptr, ctypes.c_void_p)
_LIB.DGLDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, DGLPyCapsuleDestructor(0))
ctypes.pythonapi.PyCapsule_SetDestructor(
pycapsule, DGLPyCapsuleDestructor(0)
)
_c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle, is_view=False):
......@@ -89,27 +102,36 @@ class NDArrayBase(object):
dlpack : DLPack tensor view of the array data
"""
ptr = ctypes.c_void_p()
check_call(_LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment))
return ctypes.pythonapi.PyCapsule_New(ptr, _c_str_dltensor, _c_dlpack_deleter)
check_call(
_LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr), alignment)
)
return ctypes.pythonapi.PyCapsule_New(
ptr, _c_str_dltensor, _c_dlpack_deleter
)
def _make_array(handle, is_view):
handle = ctypes.cast(handle, DGLArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_DGL_COMPATS = ()
def _reg_extension(cls, fcreate):
global _DGL_COMPATS
_DGL_COMPATS += (cls,)
if fcreate:
fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._dgl_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(fret, cls._dgl_tcode)
C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(
fret, cls._dgl_tcode
)
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
......@@ -2,10 +2,16 @@
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..base import _LIB, c_str, check_call
from ..object_generic import _set_class_object_base
from .types import DGLValue, TypeCode
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .types import (
C_TO_PY_ARG_SWITCH,
RETURN_SWITCH,
DGLValue,
TypeCode,
_wrap_arg_func,
)
ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None
......@@ -13,10 +19,12 @@ __init_by_constructor__ = None
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
def _register_object(index, cls):
"""register object class in python"""
OBJECT_TYPE[index] = cls
def _return_object(x):
"""Construct a object object from the given DGLValue object"""
handle = x.v_handle
......@@ -34,32 +42,41 @@ def _return_object(x):
RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)
_return_object, TypeCode.OBJECT_HANDLE
)
class ObjectBase(object):
"""Object base class"""
__slots__ = ["handle"]
# pylint: disable=no-member
def __del__(self):
if _LIB is not None and hasattr(self, 'handle'):
if _LIB is not None and hasattr(self, "handle"):
check_call(_LIB.DGLObjectFree(self.handle))
def __getattr__(self, name):
if name == 'handle':
raise AttributeError("'handle' is a reserved attribute name that should not be used")
if name == "handle":
raise AttributeError(
"'handle' is a reserved attribute name that should not be used"
)
ret_val = DGLValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.DGLObjectGetAttr(
self.handle, c_str(name),
check_call(
_LIB.DGLObjectGetAttr(
self.handle,
c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
ctypes.byref(ret_success),
)
)
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
"'%s' object has no attribute '%s'" % (str(type(self)), name)
)
return RETURN_SWITCH[ret_type_code.value](ret_val)
def __init_handle_by_constructor__(self, fconstructor, *args):
......@@ -81,9 +98,12 @@ class ObjectBase(object):
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args) # pylint: disable=not-callable
handle = __init_by_constructor__(
fconstructor, args
) # pylint: disable=not-callable
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
_set_class_object_base(ObjectBase)
......@@ -3,17 +3,22 @@
from __future__ import absolute_import as _abs
import ctypes
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLDataType, DGLContext
from ..base import _LIB, check_call, py_str
from ..runtime_ctypes import DGLByteArray, DGLContext, DGLDataType, TypeCode
class DGLValue(ctypes.Union):
"""DGLValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
_fields_ = [
("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p),
("v_type", DGLDataType),
("v_ctx", DGLContext)]
("v_ctx", DGLContext),
]
DGLPackedCFunc = ctypes.CFUNCTYPE(
......@@ -22,12 +27,11 @@ DGLPackedCFunc = ctypes.CFUNCTYPE(
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_void_p)
ctypes.c_void_p,
)
DGLCFuncFinalizer = ctypes.CFUNCTYPE(
None,
ctypes.c_void_p)
DGLCFuncFinalizer = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _return_handle(x):
......@@ -37,6 +41,7 @@ def _return_handle(x):
handle = ctypes.c_void_p(handle)
return handle
def _return_bytes(x):
"""return handle"""
handle = x.v_handle
......@@ -47,16 +52,20 @@ def _return_bytes(x):
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
raise RuntimeError("memmove failed")
return res
def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code)
def _wrap_func(x):
check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x)
return _wrap_func
RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
......@@ -64,7 +73,9 @@ RETURN_SWITCH = {
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id),
TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
}
C_TO_PY_ARG_SWITCH = {
......@@ -74,5 +85,7 @@ C_TO_PY_ARG_SWITCH = {
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes,
TypeCode.DGL_CONTEXT: lambda x: DGLContext(x.v_ctx.device_type, x.v_ctx.device_id),
TypeCode.DGL_CONTEXT: lambda x: DGLContext(
x.v_ctx.device_type, x.v_ctx.device_id
),
}
......@@ -3,22 +3,24 @@
"""ctypes library and helper functions """
from __future__ import absolute_import
import sys
import os
import ctypes
import logging
import os
import sys
import numpy as np
from . import libinfo
#----------------------------
# ----------------------------
# library loading
#----------------------------
# ----------------------------
if sys.version_info[0] == 3:
string_types = (str,)
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
py_str = lambda x: x.decode("utf-8")
else:
string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32)
......@@ -27,8 +29,10 @@ else:
class DGLError(Exception):
"""Error thrown by DGL function"""
pass # pylint: disable=unnecessary-pass
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
......@@ -39,6 +43,7 @@ def _load_lib():
lib.DGLGetLastError.restype = ctypes.c_char_p
return lib, basename, dirname
# version number
__version__ = libinfo.__version__
# library instance of nnvm
......@@ -47,9 +52,9 @@ _LIB, _LIB_NAME, _DIR_NAME = _load_lib()
# The FFI mode of DGL
_FFI_MODE = os.environ.get("DGL_FFI", "auto")
#----------------------------
# ----------------------------
# helper function in ctypes.
#----------------------------
# ----------------------------
def check_call(ret):
"""Check the return value of C API call
......@@ -77,7 +82,7 @@ def c_str(string):
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))
return ctypes.c_char_p(string.encode("utf-8"))
def c_array(ctype, values):
......@@ -111,10 +116,13 @@ def decorate(func, fwrapped):
The wrapped function
"""
import decorator
return decorator.decorate(func, fwrapped)
tensor_adapter_loaded = False
def load_tensor_adapter(backend, version):
"""Tell DGL to load a tensoradapter library for given backend and version.
......@@ -126,17 +134,17 @@ def load_tensor_adapter(backend, version):
The version number of the backend.
"""
global tensor_adapter_loaded
version = version.split('+')[0]
if sys.platform.startswith('linux'):
basename = 'libtensoradapter_%s_%s.so' % (backend, version)
elif sys.platform.startswith('darwin'):
basename = 'libtensoradapter_%s_%s.dylib' % (backend, version)
elif sys.platform.startswith('win'):
basename = 'tensoradapter_%s_%s.dll' % (backend, version)
version = version.split("+")[0]
if sys.platform.startswith("linux"):
basename = "libtensoradapter_%s_%s.so" % (backend, version)
elif sys.platform.startswith("darwin"):
basename = "libtensoradapter_%s_%s.dylib" % (backend, version)
elif sys.platform.startswith("win"):
basename = "tensoradapter_%s_%s.dll" % (backend, version)
else:
raise NotImplementedError('Unsupported system: %s' % sys.platform)
path = os.path.join(_DIR_NAME, 'tensoradapter', backend, basename)
tensor_adapter_loaded = (_LIB.DGLLoadTensorAdapter(path.encode('utf-8')) == 0)
raise NotImplementedError("Unsupported system: %s" % sys.platform)
path = os.path.join(_DIR_NAME, "tensoradapter", backend, basename)
tensor_adapter_loaded = _LIB.DGLLoadTensorAdapter(path.encode("utf-8")) == 0
if not tensor_adapter_loaded:
logger = logging.getLogger("dgl-core")
logger.debug("Memory optimization with PyTorch is not enabled.")
......@@ -2,9 +2,10 @@
"""Function namespace."""
from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
import sys
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str, string_types
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -13,21 +14,31 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_dgl_func
from ._cy3.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_dgl_func
from ._cy2.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_dgl_func
from ._ctypes.function import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
)
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
"""The PackedFunc object.
......@@ -51,11 +62,13 @@ class Function(_FunctionBase):
dgl.register_func: How to register global function.
dgl.get_global_func: How to get global function.
"""
pass # pylint: disable=unnecessary-pass
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
......@@ -97,13 +110,16 @@ class ModuleBase(object):
The result function.
"""
ret_handle = FunctionHandle()
check_call(_LIB.DGLModGetFunction(
self.handle, c_str(name),
check_call(
_LIB.DGLModGetFunction(
self.handle,
c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
ctypes.byref(ret_handle),
)
)
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
raise AttributeError("Module has no function '%s'" % name)
return Function(ret_handle, False)
def import_module(self, module):
......@@ -175,13 +191,16 @@ def register_func(func_name, f=None, override=False):
raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
if not isinstance(myf, Function):
myf = convert_to_dgl_func(myf)
check_call(_LIB.DGLFuncRegisterGlobal(
c_str(func_name), myf.handle, ioverride))
check_call(
_LIB.DGLFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)
)
return myf
if f:
return register(f)
return register
......@@ -214,7 +233,6 @@ def get_global_func(name, allow_missing=False):
raise ValueError("Cannot find global function %s" % name)
def list_global_func_names():
"""Get list of global functions registered.
......@@ -226,8 +244,9 @@ def list_global_func_names():
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.DGLFuncListGlobalNames(ctypes.byref(size),
ctypes.byref(plist)))
check_call(
_LIB.DGLFuncListGlobalNames(ctypes.byref(size), ctypes.byref(plist))
)
fnames = []
for i in range(size.value):
fnames.append(py_str(plist[i]))
......@@ -249,8 +268,10 @@ def extract_ext_funcs(finit):
The extracted functions
"""
fdict = {}
def _list(name, func):
fdict[name] = func
myf = convert_to_dgl_func(_list)
ret = finit(myf.handle)
_ = myf
......@@ -258,11 +279,13 @@ def extract_ext_funcs(finit):
raise RuntimeError("cannot initialize with %s" % finit)
return fdict
def _get_api(f):
flocal = f
flocal.is_global = True
return flocal
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
......@@ -272,8 +295,7 @@ def _init_api(namespace, target_module_name=None):
target_module_name : str
The target module name if different from namespace
"""
target_module_name = (
target_module_name if target_module_name else namespace)
target_module_name = target_module_name if target_module_name else namespace
if namespace.startswith("dgl."):
_init_api_prefix(target_module_name, namespace[4:])
else:
......@@ -284,10 +306,10 @@ def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
if name.startswith("_") and not name.startswith('_deprecate'):
if name.startswith("_") and not name.startswith("_deprecate"):
# internal APIs are ignored
continue
name_split = name.rsplit('.', 1)
name_split = name.rsplit(".", 1)
if name_split[0] != prefix:
continue
......@@ -300,12 +322,13 @@ def _init_api_prefix(module_name, prefix):
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname)
ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff)
def _init_internal_api():
for name in list_global_func_names():
if not name.startswith("_") or name.startswith('_deprecate'):
if not name.startswith("_") or name.startswith("_deprecate"):
# normal APIs are ignored
continue
target_module = sys.modules["dgl._api_internal"]
......@@ -316,7 +339,8 @@ def _init_internal_api():
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("DGL PackedFunc %s. " % fname)
ff.__doc__ = "DGL PackedFunc %s. " % fname
setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
"""Library information."""
from __future__ import absolute_import
import sys
import os
import pathlib
import sys
def find_lib_path(name=None, search_path=None, optional=False):
......@@ -30,13 +31,21 @@ def find_lib_path(name=None, search_path=None, optional=False):
dll_path = []
if os.environ.get('DGL_LIBRARY_PATH', None):
dll_path.append(os.environ['DGL_LIBRARY_PATH'])
if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")])
if os.environ.get("DGL_LIBRARY_PATH", None):
dll_path.append(os.environ["DGL_LIBRARY_PATH"])
if sys.platform.startswith("linux") and os.environ.get(
"LD_LIBRARY_PATH", None
):
dll_path.extend(
[p.strip() for p in os.environ["LD_LIBRARY_PATH"].split(":")]
)
elif sys.platform.startswith("darwin") and os.environ.get(
"DYLD_LIBRARY_PATH", None
):
dll_path.extend(
[p.strip() for p in os.environ["DYLD_LIBRARY_PATH"].split(":")]
)
# Pip lib directory
dll_path.append(os.path.join(ffi_dir, ".."))
......@@ -54,17 +63,21 @@ def find_lib_path(name=None, search_path=None, optional=False):
elif isinstance(search_path, str):
dll_path.append(search_path)
else:
raise ValueError("type(search_path)={} is invalid".format(type(search_path)))
dll_path = [str(x.absolute()) if isinstance(x, pathlib.Path)
else os.path.abspath(x) for x in dll_path]
raise ValueError(
"type(search_path)={} is invalid".format(type(search_path))
)
dll_path = [
str(x.absolute()) if isinstance(x, pathlib.Path) else os.path.abspath(x)
for x in dll_path
]
if name is None:
if sys.platform.startswith('win32'):
name = ['libdgl.dll', 'dgl.dll']
elif sys.platform.startswith('darwin'):
name = 'libdgl.dylib'
if sys.platform.startswith("win32"):
name = ["libdgl.dll", "dgl.dll"]
elif sys.platform.startswith("darwin"):
name = "libdgl.dylib"
else:
name = 'libdgl.so'
name = "libdgl.so"
if isinstance(name, str):
name = [name]
......@@ -76,9 +89,11 @@ def find_lib_path(name=None, search_path=None, optional=False):
lib_found = [p for p in lib_dll_path if os.path.isfile(p)]
if not lib_found:
message = ('Cannot find the files.\n' +
'List of candidates:\n' +
str('\n'.join(lib_dll_path)))
message = (
"Cannot find the files.\n"
+ "List of candidates:\n"
+ str("\n".join(lib_dll_path))
)
if not optional:
raise RuntimeError(message)
return None
......
......@@ -2,13 +2,20 @@
"""Runtime NDArray api"""
from __future__ import absolute_import
import sys
import ctypes
import sys
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, dgl_shape_index_t
from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
from .runtime_ctypes import (
DGLArray,
DGLArrayHandle,
DGLContext,
DGLDataType,
TypeCode,
dgl_shape_index_t,
)
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -17,15 +24,31 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
)
def context(dev_type, dev_id=0):
"""Construct a DGL context with given device type and id.
......@@ -63,10 +86,9 @@ def context(dev_type, dev_id=0):
def numpyasarray(np_data):
"""Return a DGLArray representation of a numpy array.
"""
"""Return a DGLArray representation of a numpy array."""
data = np_data
assert data.flags['C_CONTIGUOUS']
assert data.flags["C_CONTIGUOUS"]
arr = DGLArray()
shape = c_array(dgl_shape_index_t, data.shape)
arr.data = data.ctypes.data_as(ctypes.c_void_p)
......@@ -102,14 +124,18 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAlloc(
shape, ndim,
check_call(
_LIB.DGLArrayAlloc(
shape,
ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
ctx.device_type,
ctx.device_id,
ctypes.byref(handle)))
ctypes.byref(handle),
)
)
return _make_array(handle, False)
......@@ -135,18 +161,23 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
arr : dgl.nd.NDArray
The array dgl supported.
"""
name = ctypes.c_char_p(name.encode('utf-8'))
name = ctypes.c_char_p(name.encode("utf-8"))
shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAllocSharedMem(
name, shape, ndim,
check_call(
_LIB.DGLArrayAllocSharedMem(
name,
shape,
ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
is_create,
ctypes.byref(handle)))
ctypes.byref(handle),
)
)
return _make_array(handle, False)
......@@ -171,10 +202,14 @@ def from_dlpack(dltensor):
class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
return tuple(
self.handle.contents.shape[i]
for i in range(self.handle.contents.ndim)
)
@property
def dtype(self):
......@@ -219,17 +254,19 @@ class NDArrayBase(_NDArrayBase):
def __setitem__(self, in_slice, value):
"""Set ndarray value"""
if (not isinstance(in_slice, slice) or
in_slice.start is not None
or in_slice.stop is not None):
raise ValueError('Array only support set from numpy array')
if (
not isinstance(in_slice, slice)
or in_slice.start is not None
or in_slice.stop is not None
):
raise ValueError("Array only support set from numpy array")
if isinstance(value, NDArrayBase):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)):
self.copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))
raise TypeError("type %s not supported" % str(type(value)))
def copyfrom(self, source_array):
"""Perform a synchronized copy from the array.
......@@ -252,8 +289,10 @@ class NDArrayBase(_NDArrayBase):
try:
source_array = np.asarray(source_array, dtype=self.dtype)
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))
raise TypeError(
"array must be an array_like data,"
+ "type %s is not supported" % str(type(source_array))
)
t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
......@@ -262,12 +301,17 @@ class NDArrayBase(_NDArrayBase):
dtype = str(t)
if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape))
raise ValueError(
"array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape
)
)
source_array = np.ascontiguousarray(source_array, dtype=dtype)
assert source_array.flags['C_CONTIGUOUS']
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
nbytes = ctypes.c_size_t(
source_array.size * source_array.dtype.itemsize
)
check_call(_LIB.DGLArrayCopyFromBytes(self.handle, data, nbytes))
return self
......@@ -293,7 +337,7 @@ class NDArrayBase(_NDArrayBase):
t.lanes = 1
dtype = str(t)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS']
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.DGLArrayCopyToBytes(self.handle, data, nbytes))
......@@ -310,20 +354,17 @@ class NDArrayBase(_NDArrayBase):
if isinstance(target, DGLContext):
target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase):
check_call(_LIB.DGLArrayCopyFromTo(
self.handle, target.handle))
check_call(_LIB.DGLArrayCopyFromTo(self.handle, target.handle))
else:
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def pin_memory_(self):
"""Pin host memory and map into GPU address space (in-place)
"""
"""Pin host memory and map into GPU address space (in-place)"""
check_call(_LIB.DGLArrayPinData(self.handle))
def unpin_memory_(self):
"""Unpin host memory pinned by pin_memory_()
"""
"""Unpin host memory pinned by pin_memory_()"""
check_call(_LIB.DGLArrayUnpinData(self.handle))
def record_stream(self, stream):
......@@ -340,6 +381,7 @@ class NDArrayBase(_NDArrayBase):
"""
check_call(_LIB.DGLArrayRecordStream(self.handle, stream))
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
......@@ -353,6 +395,7 @@ def free_extension_handle(handle, type_code):
"""
check_call(_LIB.DGLExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None):
"""Register a extension class to DGL.
......@@ -398,6 +441,8 @@ def register_extension(cls, fcreate=None):
return self.handle.value
"""
if fcreate and cls._dgl_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
raise ValueError(
"Cannot register create when extension tcode is same as buildin"
)
_reg_extension(cls, fcreate)
return cls
......@@ -4,23 +4,27 @@ from __future__ import absolute_import
import ctypes
import sys
from .. import _api_internal
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
from .object_generic import ObjectGeneric, convert_to_object
from .base import _LIB, check_call, c_str, py_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" \
else ImportError # pylint: disable=invalid-name
# pylint: disable=invalid-name
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _register_object, ObjectBase as _ObjectBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _register_object, ObjectBase as _ObjectBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.object import _register_object, ObjectBase as _ObjectBase
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
def _new_object(cls):
......@@ -36,11 +40,15 @@ class ObjectBase(_ObjectBase):
Note that the same handle **CANNOT** be shared across multiple ObjectBase instances.
"""
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.DGLObjectListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
check_call(
_LIB.DGLObjectListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)
)
)
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
......@@ -57,7 +65,7 @@ class ObjectBase(_ObjectBase):
def __reduce__(self):
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
return (_new_object, (cls,), self.__getstate__())
def __getstate__(self):
# TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized
......@@ -100,7 +108,9 @@ def register_object(type_key=None):
def register(cls):
"""internal register function"""
tindex = ctypes.c_int()
ret = _LIB.DGLObjectTypeKey2Index(c_str(object_name), ctypes.byref(tindex))
ret = _LIB.DGLObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tindex)
)
if ret == 0:
_register_object(tindex.value, cls)
return cls
......
......@@ -2,23 +2,28 @@
# pylint: disable=unused-import
from __future__ import absolute_import
from numbers import Number, Integral
from numbers import Integral, Number
from .. import _api_internal
from .base import string_types
# Object base class
_CLASS_OBJECT_BASE = None
def _set_class_object_base(cls):
global _CLASS_OBJECT_BASE
_CLASS_OBJECT_BASE = cls
class ObjectGeneric(object):
"""Base class for all classes that can be converted to object."""
def asobject(self):
"""Convert value to object"""
raise NotImplementedError()
def convert_to_object(value):
"""Convert a python value to corresponding object type.
......@@ -40,9 +45,12 @@ def convert_to_object(value):
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECT_BASE) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
if not isinstance(item[0], _CLASS_OBJECT_BASE) and not isinstance(
item[0], string_types
):
raise ValueError(
"key of map must already been a container type"
)
vlist.append(item[0])
vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist)
......
......@@ -4,14 +4,18 @@ from __future__ import absolute_import
import ctypes
import json
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
from .base import _LIB, check_call
dgl_shape_index_t = ctypes.c_int64
class TypeCode(object):
"""Type code used in API calls"""
INT = 0
UINT = 1
FLOAT = 2
......@@ -28,22 +32,25 @@ class TypeCode(object):
NDARRAY_CONTAINER = 13
EXT_BEGIN = 15
class DGLByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
_fields_ = [
("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t),
]
class DGLDataType(ctypes.Structure):
"""DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
_fields_ = [
("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
("lanes", ctypes.c_uint16),
]
CODE2STR = {0: "int", 1: "uint", 2: "float", 4: "handle"}
_cache = {}
def __new__(cls, type_str):
......@@ -90,50 +97,54 @@ class DGLDataType(ctypes.Structure):
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
return (
self.bits == other.bits
and self.type_code == other.type_code
and self.lanes == other.lanes
)
def __ne__(self, other):
return not self.__eq__(other)
RPC_SESS_MASK = 128
class DGLContext(ctypes.Structure):
"""DGL context strucure."""
_fields_ = [("device_type", ctypes.c_int),
("device_id", ctypes.c_int)]
_fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
MASK2STR = {
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
5 : 'aocl',
6 : 'sdaccel',
7 : 'vulkan',
8 : 'metal',
9 : 'vpi',
10: 'rocm',
11: 'opengl',
12: 'ext_dev',
1: "cpu",
2: "gpu",
4: "opencl",
5: "aocl",
6: "sdaccel",
7: "vulkan",
8: "metal",
9: "vpi",
10: "rocm",
11: "opengl",
12: "ext_dev",
}
STR2MASK = {
'llvm': 1,
'stackvm': 1,
'cpu': 1,
'gpu': 2,
'cuda': 2,
'nvptx': 2,
'cl': 4,
'opencl': 4,
'aocl' : 5,
'aocl_sw_emu' : 5,
'sdaccel': 6,
'vulkan': 7,
'metal': 8,
'vpi': 9,
'rocm': 10,
'opengl': 11,
'ext_dev': 12,
"llvm": 1,
"stackvm": 1,
"cpu": 1,
"gpu": 2,
"cuda": 2,
"nvptx": 2,
"cl": 4,
"opencl": 4,
"aocl": 5,
"aocl_sw_emu": 5,
"sdaccel": 6,
"vulkan": 7,
"metal": 8,
"vpi": 9,
"rocm": 10,
"opengl": 11,
"ext_dev": 12,
}
_cache = {}
......@@ -155,26 +166,25 @@ class DGLContext(ctypes.Structure):
@property
def exist(self):
"""Whether this device exist."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 0) != 0
return (
_api_internal._GetDeviceAttr(self.device_type, self.device_id, 0)
!= 0
)
@property
def max_threads_per_block(self):
"""Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 1)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 1)
@property
def warp_size(self):
"""Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 2)
@property
def max_shared_memory_per_block(self):
"""Total amount of shared memory per block in bytes."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 3)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 3)
@property
def compute_version(self):
......@@ -187,26 +197,22 @@ class DGLContext(ctypes.Structure):
version : str
The version string in `major.minor` format.
"""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 4)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 4)
@property
def device_name(self):
"""Return the string name of device."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 5)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 5)
@property
def max_clock_rate(self):
"""Return the max clock frequency of device."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 6)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 6)
@property
def multi_processor_count(self):
"""Return the number of compute units of device."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 7)
return _api_internal._GetDeviceAttr(self.device_type, self.device_id, 7)
@property
def max_thread_dimensions(self):
......@@ -217,17 +223,20 @@ class DGLContext(ctypes.Structure):
dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
"""
return json.loads(_api_internal._GetDeviceAttr(
self.device_type, self.device_id, 8))
return json.loads(
_api_internal._GetDeviceAttr(self.device_type, self.device_id, 8)
)
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.DGLSynchronize(self.device_type, self.device_id, None))
def __eq__(self, other):
return (isinstance(other, DGLContext) and
self.device_id == other.device_id and
self.device_type == other.device_type)
return (
isinstance(other, DGLContext)
and self.device_id == other.device_id
and self.device_type == other.device_type
)
def __ne__(self, other):
return not self.__eq__(other)
......@@ -237,9 +246,14 @@ class DGLContext(ctypes.Structure):
tbl_id = self.device_type / RPC_SESS_MASK - 1
dev_type = self.device_type % RPC_SESS_MASK
return "remote[%d]:%s(%d)" % (
tbl_id, DGLContext.MASK2STR[dev_type], self.device_id)
tbl_id,
DGLContext.MASK2STR[dev_type],
self.device_id,
)
return "%s(%d)" % (
DGLContext.MASK2STR[self.device_type], self.device_id)
DGLContext.MASK2STR[self.device_type],
self.device_id,
)
def __hash__(self):
return hash((self.device_type, self.device_id))
......@@ -247,13 +261,17 @@ class DGLContext(ctypes.Structure):
class DGLArray(ctypes.Structure):
"""DGLValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
_fields_ = [
("data", ctypes.c_void_p),
("ctx", DGLContext),
("ndim", ctypes.c_int),
("dtype", DGLDataType),
("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64)]
("byte_offset", ctypes.c_uint64),
]
DGLArrayHandle = ctypes.POINTER(DGLArray)
......
......@@ -5,13 +5,13 @@ For applications, please use PyTorch's stream management, of which DGL is aware.
from __future__ import absolute_import
import ctypes
from .base import _LIB, check_call, _FFI_MODE
from .base import _FFI_MODE, _LIB, check_call
from .runtime_ctypes import DGLStreamHandle
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
def to_dgl_stream_handle(cuda_stream):
""" Convert torch.cuda.Stream to DGL stream handle
"""Convert torch.cuda.Stream to DGL stream handle
Parameters
----------
......@@ -24,6 +24,7 @@ def to_dgl_stream_handle(cuda_stream):
"""
return ctypes.c_void_p(cuda_stream.cuda_stream)
def _dgl_get_stream(ctx):
"""Get the current CUDA stream of the given DGL context.
......@@ -37,6 +38,9 @@ def _dgl_get_stream(ctx):
DGLStreamHandle of the current CUDA stream.
"""
current_cuda_stream = DGLStreamHandle()
check_call(_LIB.DGLGetStream(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)))
check_call(
_LIB.DGLGetStream(
ctx.device_type, ctx.device_id, ctypes.byref(current_cuda_stream)
)
)
return current_cuda_stream
from __future__ import absolute_import
import sys
import os
import json
import importlib
import json
import logging
import os
import sys
from . import backend
from .set_default_backend import set_default_backend
......@@ -13,13 +13,18 @@ _enabled_apis = set()
logger = logging.getLogger("dgl-core")
def _gen_missing_api(api, mod_name):
def _missing_api(*args, **kwargs):
raise ImportError('API "%s" is not supported by backend "%s".'
' You can switch to other backends by setting'
' the DGLBACKEND environment.' % (api, mod_name))
raise ImportError(
'API "%s" is not supported by backend "%s".'
" You can switch to other backends by setting"
" the DGLBACKEND environment." % (api, mod_name)
)
return _missing_api
def load_backend(mod_name):
# Load backend does four things:
# (1) Import backend framework (PyTorch, MXNet, Tensorflow, etc.)
......@@ -28,40 +33,46 @@ def load_backend(mod_name):
# (3) Sets up the tensoradapter library path.
# (4) Import the Python wrappers of the backend framework. DGL does this last because
# it already depends on both the backend framework and the DGL C library.
if mod_name == 'pytorch':
if mod_name == "pytorch":
import torch
mod = torch
elif mod_name == 'mxnet':
elif mod_name == "mxnet":
import mxnet
mod = mxnet
elif mod_name == 'tensorflow':
elif mod_name == "tensorflow":
import tensorflow
mod = tensorflow
else:
raise NotImplementedError('Unsupported backend: %s' % mod_name)
raise NotImplementedError("Unsupported backend: %s" % mod_name)
from .._ffi.base import load_tensor_adapter # imports DGL C library
version = mod.__version__
load_tensor_adapter(mod_name, version)
logger.debug('Using backend: %s' % mod_name)
mod = importlib.import_module('.%s' % mod_name, __name__)
logger.debug("Using backend: %s" % mod_name)
mod = importlib.import_module(".%s" % mod_name, __name__)
thismod = sys.modules[__name__]
for api in backend.__dict__.keys():
if api.startswith('__'):
if api.startswith("__"):
# ignore python builtin attributes
continue
if api == 'data_type_dict':
if api == "data_type_dict":
# load data type
if api not in mod.__dict__:
raise ImportError('API "data_type_dict" is required but missing for'
' backend "%s".' % (mod_name))
raise ImportError(
'API "data_type_dict" is required but missing for'
' backend "%s".' % (mod_name)
)
data_type_dict = mod.__dict__[api]()
for name, dtype in data_type_dict.items():
setattr(thismod, name, dtype)
# override data type dict function
setattr(thismod, 'data_type_dict', data_type_dict)
setattr(thismod, "data_type_dict", data_type_dict)
# for data types with aliases, treat the first listed type as
# the true one
......@@ -69,11 +80,9 @@ def load_backend(mod_name):
for k, v in data_type_dict.items():
if not v in rev_data_type_dict.keys():
rev_data_type_dict[v] = k
setattr(thismod,
'reverse_data_type_dict',
rev_data_type_dict)
setattr(thismod, "reverse_data_type_dict", rev_data_type_dict)
# log backend name
setattr(thismod, 'backend_name', mod_name)
setattr(thismod, "backend_name", mod_name)
else:
# load functions
if api in mod.__dict__:
......@@ -82,28 +91,32 @@ def load_backend(mod_name):
else:
setattr(thismod, api, _gen_missing_api(api, mod_name))
def get_preferred_backend():
default_dir = None
if "DGLDEFAULTDIR" in os.environ:
default_dir = os.getenv('DGLDEFAULTDIR')
default_dir = os.getenv("DGLDEFAULTDIR")
else:
default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
config_path = os.path.join(default_dir, 'config.json')
default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
config_path = os.path.join(default_dir, "config.json")
backend_name = None
if "DGLBACKEND" in os.environ:
backend_name = os.getenv('DGLBACKEND')
backend_name = os.getenv("DGLBACKEND")
elif os.path.exists(config_path):
with open(config_path, "r") as config_file:
config_dict = json.load(config_file)
backend_name = config_dict.get('backend', '').lower()
backend_name = config_dict.get("backend", "").lower()
if (backend_name in ['tensorflow', 'mxnet', 'pytorch']):
if backend_name in ["tensorflow", "mxnet", "pytorch"]:
return backend_name
else:
print("DGL backend not selected or invalid. "
"Assuming PyTorch for now.", file=sys.stderr)
set_default_backend(default_dir, 'pytorch')
return 'pytorch'
print(
"DGL backend not selected or invalid. "
"Assuming PyTorch for now.",
file=sys.stderr,
)
set_default_backend(default_dir, "pytorch")
return "pytorch"
load_backend(get_preferred_backend())
......@@ -124,8 +137,10 @@ def is_enabled(api):
"""
return api in _enabled_apis
def to_dgl_nd(data):
return zerocopy_to_dgl_ndarray(data)
def from_dgl_nd(data):
return zerocopy_from_dgl_ndarray(data)
This diff is collapsed.
from .tensor import *
from .sparse import *
from .tensor import *
import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
from ...heterograph_index import create_unitgraph_from_csr
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']
from ...base import ALL, dgl_warning, is_all
from ...heterograph_index import create_unitgraph_from_csr
from ...sparse import (
_bwd_segment_cmp,
_csrmask,
_csrmm,
_csrsum,
_gsddmm,
_gspmm,
_scatter_add,
_segment_reduce,
)
from .tensor import (
asnumpy,
context,
copy_to,
to_backend_ctx,
zerocopy_from_numpy,
)
__all__ = [
"gspmm",
"gsddmm",
"edge_softmax",
"segment_reduce",
"scatter_add",
"csrmm",
"csrsum",
"csrmask",
]
def _scatter_nd(index, src, n_rows):
......@@ -26,7 +49,10 @@ def _scatter_nd(index, src, n_rows):
di = shp[i]
offset_i = np.arange(di, dtype=index.dtype)
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
(stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di
if ndim > 1:
new_idx = index * stride + sum(offsets)
......@@ -52,7 +78,10 @@ def _gather_nd(index, src):
di = shp[i]
offset_i = nd.arange(di, dtype=index.dtype)
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
(stride * offset_i).reshape(
(1,) * i + (di,) + (1,) * (ndim - 1 - i)
)
)
stride *= di
if ndim > 1:
new_idx = index * stride + copy_to(sum(offsets), ctx)
......@@ -107,11 +136,11 @@ def _need_reduce_last_dim(ufeat, efeat):
def _muldiv(op, x):
return 1. / x if op == 'div' else x
return 1.0 / x if op == "div" else x
def _addsub(op, x):
return -x if op == 'sub' else x
return -x if op == "sub" else x
def _expand(x, shape):
......@@ -134,45 +163,48 @@ class GSpMM(mx.autograd.Function):
ctx = context(dZ)
X, Y, argX, argY = self.saved_tensors
gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
if op != 'copy_rhs':
if op != "copy_rhs":
g_rev = gidx.reverse()
if reduce_op == 'sum':
if op in ['mul', 'div']:
dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
elif op in ['add', 'sub']:
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
elif op == 'copy_lhs':
dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
if reduce_op == "sum":
if op in ["mul", "div"]:
dX = _gspmm(g_rev, "mul", "sum", dZ, _muldiv(op, Y))[0]
elif op in ["add", "sub"]:
dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, Y)[0]
elif op == "copy_lhs":
dX = _gspmm(g_rev, "copy_lhs", "sum", dZ, None)[0]
else:
if op in ['mul', 'div']:
if op in ["mul", "div"]:
dX = _scatter_nd(
argX,
_muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:]))) * dZ,
X.shape[0])
elif op in ['add', 'sub', 'copy_lhs']:
_muldiv(op, _gather_nd(argY, _expand(Y, dZ.shape[1:])))
* dZ,
X.shape[0],
)
elif op in ["add", "sub", "copy_lhs"]:
dX = _scatter_nd(argX, dZ, X.shape[0])
dX = _reduce_grad(dX, X.shape)
else:
dX = nd.zeros_like(X)
if op != 'copy_lhs':
if reduce_op == 'sum':
if op == 'mul' and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, 'dot', X, dZ)
elif op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ)
if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
if op != "copy_lhs":
if reduce_op == "sum":
if op == "mul" and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, "dot", X, dZ)
elif op in ["mul", "div"]:
dY = _gsddmm(gidx, "mul", X, dZ)
if op == "div":
dY = -dY / (Y**2)
elif op in ["add", "sub", "copy_rhs"]:
dY = _gsddmm(gidx, "copy_rhs", X, _addsub(op, dZ))
else:
if op in ['mul', 'div']:
if op in ["mul", "div"]:
dY = _scatter_nd(
argY,
_gather_nd(argX, _expand(X, dZ.shape[1:])) * dZ,
Y.shape[0])
if op == 'div':
dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']:
Y.shape[0],
)
if op == "div":
dY = -dY / (Y**2)
elif op in ["add", "sub", "copy_rhs"]:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape)
else:
......@@ -207,7 +239,9 @@ class GSDDMM(mx.autograd.Function):
self.rhs_target = rhs_target
def forward(self, X, Y):
out = _gsddmm(self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target)
out = _gsddmm(
self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target
)
self.save_for_backward(X, Y)
return out
......@@ -216,47 +250,55 @@ class GSDDMM(mx.autograd.Function):
X, Y = self.saved_tensors
gidx, op = self.gidx, self.op
lhs_target, rhs_target = self.lhs_target, self.rhs_target
if op != 'copy_rhs':
if lhs_target in ['u', 'v']:
_gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_lhs']:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
if op != "copy_rhs":
if lhs_target in ["u", "v"]:
_gidx = gidx if self.lhs_target == "v" else gidx.reverse()
if op in ["add", "sub", "copy_lhs"]:
dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0]
else: # mul, div, dot
if rhs_target == lhs_target:
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
elif self.rhs_target == 'e':
dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
dX = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[
0
] * _muldiv(op, Y)
elif self.rhs_target == "e":
dX = _gspmm(
_gidx, "copy_rhs", "sum", None, dZ * _muldiv(op, Y)
)[0]
else: # rhs_target = !lhs_target
dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
dX = _gspmm(_gidx, "mul", "sum", _muldiv(op, Y), dZ)[0]
else: # lhs_target == 'e'
if op in ['add', 'sub', 'copy_lhs']:
if op in ["add", "sub", "copy_lhs"]:
dX = dZ
else: # mul, div, dot
dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
dX = _gsddmm(
gidx, "mul", dZ, _muldiv(op, Y), "e", rhs_target
)
dX = _reduce_grad(dX, X.shape)
else:
dX = nd.zeros_like(X)
if op != 'copy_lhs':
if self.rhs_target in ['u', 'v']:
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
if op in ['add', 'sub', 'copy_rhs']:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
if op != "copy_lhs":
if self.rhs_target in ["u", "v"]:
_gidx = gidx if rhs_target == "v" else gidx.reverse()
if op in ["add", "sub", "copy_rhs"]:
dY = _gspmm(
_gidx, "copy_rhs", "sum", None, _addsub(op, dZ)
)[0]
else: # mul, div, dot
if lhs_target == rhs_target:
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
elif self.lhs_target == 'e':
dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ)[0] * X
elif self.lhs_target == "e":
dY = _gspmm(_gidx, "copy_rhs", "sum", None, dZ * X)[0]
else: # rhs_target = !lhs_target
dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
if op == 'div':
dY = -dY / (Y ** 2)
dY = _gspmm(_gidx, "mul", "sum", X, dZ)[0]
if op == "div":
dY = -dY / (Y**2)
else:
if op in ['add', 'sub', 'copy_rhs']:
if op in ["add", "sub", "copy_rhs"]:
dY = _addsub(op, dZ)
else: # mul, div, dot
dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
if op == 'div':
dY = -dY / (Y ** 2)
dY = _gsddmm(gidx, "mul", dZ, X, "e", lhs_target)
if op == "div":
dY = -dY / (Y**2)
dY = _reduce_grad(dY, Y.shape)
else:
dY = nd.zeros_like(Y)
......@@ -264,7 +306,7 @@ class GSDDMM(mx.autograd.Function):
return dX, dY
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
func = GSDDMM(gidx, op, lhs_target, rhs_target)
ctx = to_backend_ctx(gidx.ctx)
if lhs_data is None:
......@@ -279,7 +321,7 @@ class EdgeSoftmax(mx.autograd.Function):
super(EdgeSoftmax, self).__init__()
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src':
if norm_by == "src":
gidx = gidx.reverse()
self.gidx = gidx
......@@ -298,10 +340,10 @@ class EdgeSoftmax(mx.autograd.Function):
return out.data
"""
gidx = self.gidx
score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
score = mx.nd.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
score_max = _gspmm(gidx, "copy_rhs", "max", None, score)[0]
score = mx.nd.exp(_gsddmm(gidx, "sub", score, score_max, "e", "v"))
score_sum = _gspmm(gidx, "copy_rhs", "sum", None, score)[0]
out = _gsddmm(gidx, "div", score, score_sum, "e", "v")
self.save_for_backward(out)
return out
......@@ -319,16 +361,16 @@ class EdgeSoftmax(mx.autograd.Function):
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions
"""
out, = self.saved_tensors
(out,) = self.saved_tensors
gidx = self.gidx
sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
accum = gspmm(gidx, "copy_rhs", "sum", None, sds)
grad_score = sds - gsddmm(gidx, "mul", out, accum, "e", "v")
self.save_tensors = None
return grad_score
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
softmax_op = EdgeSoftmax(gidx, eids, norm_by)
return softmax_op(logits)
......@@ -345,10 +387,10 @@ class SegmentReduce(mx.autograd.Function):
return y
def backward(self, dy):
arg, = self.saved_tensors
(arg,) = self.saved_tensors
offsets = self.offsets
m = offsets[-1].asscalar()
if self.op == 'sum':
if self.op == "sum":
offsets_np = asnumpy(offsets[1:])
indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
......@@ -392,36 +434,66 @@ class CSRMM(mx.autograd.Function):
self.num_vtypes = num_vtypes
def forward(self, A_weights, B_weights):
gidxC, C_weights = _csrmm(self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, False, 'csr')
gidxC, C_weights = _csrmm(
self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes
)
(
nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
self.save_for_backward(A_weights, B_weights)
nrows = nd.array([nrows], dtype='int64')
ncols = nd.array([ncols], dtype='int64')
nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
def backward(
self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful.
gidxC = self.backward_cache
A_weights, B_weights = self.saved_tensors
dgidxA, dA_weights = _csrmm(
gidxC, dC_weights, self.gidxB.reverse(), B_weights, self.gidxA.number_of_ntypes())
gidxC,
dC_weights,
self.gidxB.reverse(),
B_weights,
self.gidxA.number_of_ntypes(),
)
dgidxB, dB_weights = _csrmm(
self.gidxA.reverse(), A_weights, gidxC, dC_weights, self.gidxB.number_of_ntypes())
self.gidxA.reverse(),
A_weights,
gidxC,
dC_weights,
self.gidxB.number_of_ntypes(),
)
dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA)
dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB)
return dA_weights, dB_weights
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
op = CSRMM(gidxA, gidxB, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(A_weights, B_weights)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(
A_weights, B_weights
)
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
num_vtypes,
nrows.asscalar(),
ncols.asscalar(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights
class CSRSum(mx.autograd.Function):
def __init__(self, gidxs):
super().__init__()
......@@ -429,29 +501,44 @@ class CSRSum(mx.autograd.Function):
def forward(self, *weights):
gidxC, C_weights = _csrsum(self.gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
0, False, 'csr')
(
nrows,
ncols,
C_indptr,
C_indices,
C_eids,
) = gidxC.adjacency_matrix_tensors(0, False, "csr")
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
nrows = nd.array([nrows], dtype='int64')
ncols = nd.array([ncols], dtype='int64')
nrows = nd.array([nrows], dtype="int64")
ncols = nd.array([ncols], dtype="int64")
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
def backward(
self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
# Only the last argument is meaningful.
gidxC = self.backward_cache
return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs)
def csrsum(gidxs, weights):
op = CSRSum(gidxs)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights)
num_vtypes = gidxs[0].number_of_ntypes()
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
num_vtypes,
nrows.asscalar(),
ncols.asscalar(),
C_indptr,
C_indices,
C_eids,
["coo", "csr", "csc"],
)
return gidxC, C_weights
class CSRMask(mx.autograd.Function):
def __init__(self, gidxA, gidxB):
super().__init__()
......@@ -464,6 +551,7 @@ class CSRMask(mx.autograd.Function):
def backward(self, dB_weights):
return _csrmask(self.gidxB, dB_weights, self.gidxA)
def csrmask(gidxA, A_weights, gidxB):
op = CSRMask(gidxA, gidxB)
return op(A_weights)
This diff is collapsed.
from .tensor import *
from .sparse import *
from .tensor import *
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment