Unverified Commit 8801154b authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

Merge pull request #1 from jermainewang/cpp

Cpp
parents b46abb09 b2c1c4fa
# pylint: disable=invalid-name, unused-import
"""Function namespace."""
from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
    """The PackedFunc object used in TVM.

    Function plays a key role in bridging the frontend and backend in TVM.
    It provides a type-erased interface: you can call a Function with
    positional arguments.

    The compiled module returns Functions, and the TVM backend also
    registers and exposes its API as Functions.  For example, the developer
    functions exposed in ``tvm.ir_pass`` are actually C++ functions that
    were registered as PackedFunc.

    Common usage scenarios of tvm.Function:

    - Automatic exposure of C++ API into python
    - To call PackedFunc from python side
    - To call python callbacks to inspect results in generated code
    - Bring python hook into C++ backend

    See Also
    --------
    tvm.register_func: How to register global function.
    tvm.get_global_func: How to get global function.
    """
    pass
class ModuleBase(object):
    """Base class wrapping a TVM runtime module handle."""
    __slots__ = ["handle", "_entry", "entry_name"]

    def __init__(self, handle):
        self.handle = handle
        self._entry = None
        self.entry_name = "__tvm_main__"

    def __del__(self):
        # Release the underlying C handle when the wrapper is collected.
        check_call(_LIB.TVMModFree(self.handle))

    @property
    def entry_func(self):
        """Get the entry function

        Returns
        -------
        f : Function
            The entry function if exist
        """
        if not self._entry:
            # Lazily resolve and cache the entry function.
            self._entry = self.get_function(self.entry_name)
        return self._entry

    def get_function(self, name, query_imports=False):
        """Get function from the module.

        Parameters
        ----------
        name : str
            The name of the function

        query_imports : bool
            Whether also query modules imported by this module.

        Returns
        -------
        f : Function
            The result function.
        """
        fhandle = FunctionHandle()
        check_call(_LIB.TVMModGetFunction(
            self.handle, c_str(name),
            ctypes.c_int(query_imports),
            ctypes.byref(fhandle)))
        if not fhandle.value:
            raise AttributeError(
                "Module has no function '%s'" % name)
        return Function(fhandle, False)

    def import_module(self, module):
        """Add module to the import list of current one.

        Parameters
        ----------
        module : Module
            The other module.
        """
        check_call(_LIB.TVMModImport(self.handle, module.handle))

    def __getitem__(self, name):
        if not isinstance(name, string_types):
            raise ValueError("Can only take string as function name")
        return self.get_function(name)

    def __call__(self, *args):
        # Calling the module invokes its (cached) entry function.
        return self.entry_func(*args)
def register_func(func_name, f=None, override=False):
    """Register global function

    Parameters
    ----------
    func_name : str or function
        The function name.  When a callable is passed directly (bare
        decorator usage), its ``__name__`` is used as the name.

    f : function, optional
        The function to be registered.

    override : bool, optional
        Whether override existing entry.

    Returns
    -------
    fregister : function
        Register function if f is not specified.

    Examples
    --------
    The following code registers my_packed_func as global function.
    Note that we simply get it back from global function table to invoke
    it from python side. However, we can also invoke the same function
    from C++ backend, or in the compiled TVM code.

    .. code-block:: python

      targs = (10, 10.0, "hello")
      @tvm.register_func
      def my_packed_func(*args):
          assert(tuple(args) == targs)
          return 10
      # Get it out from global function table
      f = tvm.get_global_func("my_packed_func")
      assert isinstance(f, tvm.nd.Function)
      y = f(*targs)
      assert y == 10
    """
    if callable(func_name):
        f = func_name
        func_name = f.__name__
    # Use string_types (not bare str) so unicode names work on Python 2,
    # consistent with the rest of this module (e.g. ModuleBase.__getitem__).
    if not isinstance(func_name, string_types):
        raise ValueError("expect string function name")
    ioverride = ctypes.c_int(override)

    def register(myf):
        """internal register function"""
        if not isinstance(myf, Function):
            myf = convert_to_tvm_func(myf)
        check_call(_LIB.TVMFuncRegisterGlobal(
            c_str(func_name), myf.handle, ioverride))
        return myf

    # Explicit None check: a falsy-but-valid callable must still register.
    if f is not None:
        return register(f)
    return register
def get_global_func(name, allow_missing=False):
    """Look up a global function by name.

    Parameters
    ----------
    name : str
        The name of the global function

    allow_missing : bool
        Whether allow missing function or raise an error.

    Returns
    -------
    func : tvm.Function
        The function to be returned, None if function is missing.
    """
    handle = FunctionHandle()
    check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
    if not handle.value:
        # No such entry in the global registry.
        if allow_missing:
            return None
        raise ValueError("Cannot find global function %s" % name)
    return Function(handle, False)
def list_global_func_names():
    """List names of all globally registered functions.

    Returns
    -------
    names : list
        List of global functions names.
    """
    plist = ctypes.POINTER(ctypes.c_char_p)()
    size = ctypes.c_uint()
    check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
                                           ctypes.byref(plist)))
    # Convert the returned C string array into python strings.
    return [py_str(plist[idx]) for idx in range(size.value)]
def extract_ext_funcs(finit):
    """
    Extract the extension PackedFuncs from a C module.

    Parameters
    ----------
    finit : ctypes function
        a ctypes that takes signature of TVMExtensionDeclarer

    Returns
    -------
    fdict : dict of str to Function
        The extracted functions
    """
    fdict = {}
    def _list(name, func):
        # Callback invoked from the C side once per declared function.
        fdict[name] = func
    myf = convert_to_tvm_func(_list)
    ret = finit(myf.handle)
    # Keep a live reference to myf: the C side only holds the raw handle,
    # so without this the callback could be garbage collected mid-call.
    _ = myf
    if ret != 0:
        raise RuntimeError("cannot initialize with %s" % finit)
    return fdict
def _get_api(f):
flocal = f
flocal.is_global = True
return flocal
def _init_api(namespace, target_module_name=None):
    """Initialize api for a given module name

    namespace : str
       The namespace of the source registry

    target_module_name : str
       The target module name if different from namespace
    """
    module_name = target_module_name or namespace
    if namespace.startswith("dgl."):
        # Registered names drop the leading "dgl." package prefix.
        _init_api_prefix(module_name, namespace[4:])
    else:
        _init_api_prefix(module_name, namespace)
def _init_api_prefix(module_name, prefix):
    # Expose every registered global function whose name matches `prefix`
    # as an attribute of the target python module.
    module = sys.modules[module_name]
    for name in list_global_func_names():
        if prefix == "api":
            fname = name
            # Names starting with "_" go into the internal API module.
            if name.startswith("_"):
                target_module = sys.modules["dgl._api_internal"]
            else:
                target_module = module
        else:
            if not name.startswith(prefix):
                continue
            # Strip "<prefix>." from the registered name.
            fname = name[len(prefix)+1:]
            target_module = module
        if fname.find(".") != -1:
            # Skip nested names; only top-level functions are exposed.
            continue
        f = get_global_func(name)
        ff = _get_api(f)
        ff.__name__ = fname
        ff.__doc__ = ("DGL PackedFunc %s. " % fname)
        setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
"""Library information."""
from __future__ import absolute_import
import sys
import os
def find_lib_path(name=None, search_path=None, optional=False):
    """Find dynamic library files.

    Parameters
    ----------
    name : str or list of str, optional
        Name(s) of the library file to look for.  Defaults to the
        platform-specific DGL library name.

    search_path : str or list of str, optional
        Additional directory (or list of directories) to search.

    optional : bool, optional
        If True, return None instead of raising when nothing is found.

    Returns
    -------
    lib_path : list of str or None
        List of all found paths to the libraries, or None when nothing
        is found and ``optional`` is True.

    Raises
    ------
    RuntimeError
        If no library is found and ``optional`` is False.
    """
    # See https://github.com/dmlc/tvm/issues/281 for some background.
    # NB: This will either be the source directory (if TVM is run
    # inplace) or the install directory (if TVM is installed).
    # An installed TVM's curr_path will look something like:
    #   $PREFIX/lib/python3.6/site-packages/tvm/_ffi
    ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    source_dir = os.path.join(ffi_dir, "..", "..", "..")
    install_lib_dir = os.path.join(ffi_dir, "..", "..", "..", "..")

    dll_path = []
    if os.environ.get('DGL_LIBRARY_PATH', None):
        dll_path.append(os.environ['DGL_LIBRARY_PATH'])
    if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None):
        dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
    elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None):
        dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")])
    # Pip lib directory
    dll_path.append(os.path.join(ffi_dir, ".."))
    # Default cmake build directory
    dll_path.append(os.path.join(source_dir, "build"))
    dll_path.append(os.path.join(source_dir, "build", "Release"))
    # Default make build directory
    dll_path.append(os.path.join(source_dir, "lib"))
    dll_path.append(install_lib_dir)
    dll_path = [os.path.abspath(x) for x in dll_path]

    if search_path is not None:
        # BUGFIX: the original tested `search_path is list`, which is never
        # true; use isinstance to detect a list of extra directories.
        if isinstance(search_path, list):
            dll_path = dll_path + search_path
        else:
            dll_path.append(search_path)

    if name is not None:
        if isinstance(name, list):
            lib_dll_path = []
            for n in name:
                lib_dll_path += [os.path.join(p, n) for p in dll_path]
        else:
            lib_dll_path = [os.path.join(p, name) for p in dll_path]
    else:
        if sys.platform.startswith('win32'):
            lib_dll_path = [os.path.join(p, 'libdgl.dll') for p in dll_path] +\
                           [os.path.join(p, 'dgl.dll') for p in dll_path]
        elif sys.platform.startswith('darwin'):
            lib_dll_path = [os.path.join(p, 'libdgl.dylib') for p in dll_path]
        else:
            lib_dll_path = [os.path.join(p, 'libdgl.so') for p in dll_path]

    # Keep only candidates that exist as regular files.
    lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
    if not lib_found:
        # BUGFIX: the original referenced an undefined `runtime_dll_path`
        # here, raising NameError instead of the intended error/None result.
        message = ('Cannot find the files.\n' +
                   'List of candidates:\n' +
                   str('\n'.join(lib_dll_path)))
        if not optional:
            raise RuntimeError(message)
        return None
    return lib_found
# current version
# We use the version of the incoming release for code
# that is under development.
# The following line is set by tvm/python/update_version.py
__version__ = "0.5.dev"
# pylint: disable=invalid-name, unused-import
"""Runtime NDArray api"""
from __future__ import absolute_import
import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
def context(dev_type, dev_id=0):
    """Construct a TVM context with given device type and id.

    Parameters
    ----------
    dev_type: int or str
        The device type mask or name of the device.

    dev_id : int, optional
        The integer device id

    Returns
    -------
    ctx: TVMContext
        The corresponding context.

    Examples
    --------
    Context can be used to create reflection of context by
    string representation of the device type.

    .. code-block:: python

      assert tvm.context("cpu", 1) == tvm.cpu(1)
      assert tvm.context("gpu", 0) == tvm.gpu(0)
      assert tvm.context("cuda", 0) == tvm.gpu(0)
    """
    if isinstance(dev_type, string_types):
        # Accept strings such as "gpu 0"; only the first token names the device.
        token = dev_type.split()[0]
        if token not in TVMContext.STR2MASK:
            raise ValueError("Unknown device type %s" % token)
        dev_type = TVMContext.STR2MASK[token]
    return TVMContext(dev_type, dev_id)
def numpyasarray(np_data):
    """Return a TVMArray representation of a numpy array.

    Parameters
    ----------
    np_data : numpy.ndarray
        A C-contiguous numpy array; the TVMArray aliases its memory.

    Returns
    -------
    (arr, shape) : (TVMArray, ctypes array)
        The TVMArray view and the ctypes shape array.  The shape array is
        returned so the caller can keep it alive; ``arr.shape`` points into
        it and would dangle if it were collected.
    """
    data = np_data
    assert data.flags['C_CONTIGUOUS']
    arr = TVMArray()
    shape = c_array(tvm_shape_index_t, data.shape)
    arr.data = data.ctypes.data_as(ctypes.c_void_p)
    arr.shape = shape
    arr.strides = None
    arr.dtype = TVMType(np.dtype(data.dtype).name)
    arr.ndim = data.ndim
    # CPU device
    arr.ctx = context(1, 0)
    return arr, shape
def empty(shape, dtype="float32", ctx=None):
    """Create an empty array given shape and device

    Parameters
    ----------
    shape : tuple of int
        The shape of the array

    dtype : type or str
        The data type of the array.

    ctx : TVMContext, optional
        The context of the array.  Defaults to CPU.

    Returns
    -------
    arr : tvm.nd.NDArray
        The array tvm supported.
    """
    # NOTE: the original default was `ctx=context(1, 0)`, evaluated once at
    # import time and sharing one mutable ctypes structure across all calls;
    # resolving the default lazily is safer and backward compatible.
    if ctx is None:
        ctx = context(1, 0)
    shape = c_array(tvm_shape_index_t, shape)
    ndim = ctypes.c_int(len(shape))
    handle = TVMArrayHandle()
    dtype = TVMType(dtype)
    check_call(_LIB.TVMArrayAlloc(
        shape, ndim,
        ctypes.c_int(dtype.type_code),
        ctypes.c_int(dtype.bits),
        ctypes.c_int(dtype.lanes),
        ctx.device_type,
        ctx.device_id,
        ctypes.byref(handle)))
    return _make_array(handle, False)
def from_dlpack(dltensor):
    """Produce an array from a DLPack tensor without memory copy.

    Retrieves the underlying DLPack tensor's pointer to create an array from
    the data.  Removes the original DLPack tensor's destructor, as the new
    array is now responsible for destruction.

    Parameters
    ----------
    dltensor : DLPack tensor
        Input DLManagedTensor, can only be consumed once.

    Returns
    -------
    arr: tvm.nd.NDArray
        The array view of the tensor data.
    """
    return _from_dlpack(dltensor)
class NDArrayBase(_NDArrayBase):
    """A simple Device/CPU Array object in runtime."""

    @property
    def shape(self):
        """Shape of this array"""
        return tuple(self.handle.contents.shape[i]
                     for i in range(self.handle.contents.ndim))

    @property
    def dtype(self):
        """Type of this array"""
        return str(self.handle.contents.dtype)

    @property
    def ctx(self):
        """context of this array"""
        return self.handle.contents.ctx

    @property
    def context(self):
        """context of this array"""
        return self.ctx

    def __hash__(self):
        # Identity hash: the address of the underlying C handle.
        return ctypes.cast(self.handle, ctypes.c_void_p).value

    def __eq__(self, other):
        return self.same_as(other)

    def __ne__(self, other):
        return not self.__eq__(other)

    def same_as(self, other):
        """Check object identity equality

        Parameters
        ----------
        other : object
            The other object to compare to

        Returns
        -------
        same : bool
            Whether other is same as self.
        """
        if not isinstance(other, NDArrayBase):
            return False
        return self.__hash__() == other.__hash__()

    def __setitem__(self, in_slice, value):
        """Set ndarray value"""
        # Only whole-array assignment (arr[:] = ...) is supported.
        if (not isinstance(in_slice, slice) or
                in_slice.start is not None
                or in_slice.stop is not None):
            raise ValueError('Array only support set from numpy array')
        if isinstance(value, NDArrayBase):
            if value.handle is not self.handle:
                value.copyto(self)
        elif isinstance(value, (np.ndarray, np.generic)):
            self.copyfrom(value)
        else:
            raise TypeError('type %s not supported' % str(type(value)))

    def copyfrom(self, source_array):
        """Perform a synchronous copy from the array.

        Parameters
        ----------
        source_array : array_like
            The data source we should like to copy from.

        Returns
        -------
        arr : NDArray
            Reference to self.
        """
        if isinstance(source_array, NDArrayBase):
            source_array.copyto(self)
            return self
        if not isinstance(source_array, np.ndarray):
            try:
                source_array = np.array(source_array, dtype=self.dtype)
            except Exception:
                # BUGFIX: was a bare `except:` (also swallowed
                # KeyboardInterrupt/SystemExit); the message was missing a
                # space between its two concatenated halves.
                raise TypeError('array must be an array_like data, ' +
                                'type %s is not supported' % str(type(source_array)))
        t = TVMType(self.dtype)
        shape, dtype = self.shape, self.dtype
        if t.lanes > 1:
            # Vector types appear to numpy as an extra innermost axis.
            shape = shape + (t.lanes,)
            t.lanes = 1
            dtype = str(t)
        if source_array.shape != shape:
            raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
                source_array.shape, shape))
        source_array = np.ascontiguousarray(source_array, dtype=dtype)
        assert source_array.flags['C_CONTIGUOUS']
        data = source_array.ctypes.data_as(ctypes.c_void_p)
        nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
        check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
        return self

    def __repr__(self):
        res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
        res += self.asnumpy().__repr__()
        return res

    def __str__(self):
        return str(self.asnumpy())

    def asnumpy(self):
        """Convert this array to numpy array

        Returns
        -------
        np_arr : numpy.ndarray
            The corresponding numpy array.
        """
        t = TVMType(self.dtype)
        shape, dtype = self.shape, self.dtype
        if t.lanes > 1:
            # Expose vector lanes as an extra innermost axis.
            shape = shape + (t.lanes,)
            t.lanes = 1
            dtype = str(t)
        np_arr = np.empty(shape, dtype=dtype)
        assert np_arr.flags['C_CONTIGUOUS']
        data = np_arr.ctypes.data_as(ctypes.c_void_p)
        nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
        check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
        return np_arr

    def copyto(self, target):
        """Copy array to target

        Parameters
        ----------
        target : NDArray or TVMContext
            The target array to be copied, must have same shape as this
            array; or a context, in which case a new array is allocated.

        Returns
        -------
        target : NDArray
            The target array.
        """
        if isinstance(target, TVMContext):
            target = empty(self.shape, self.dtype, target)
        if isinstance(target, NDArrayBase):
            check_call(_LIB.TVMArrayCopyFromTo(
                self.handle, target.handle, None))
        else:
            raise ValueError("Unsupported target type %s" % str(type(target)))
        return target
def free_extension_handle(handle, type_code):
    """Free c++ extension type handle

    Parameters
    ----------
    handle : ctypes.c_void_p
        The handle to the extension type.

    type_code : int
        The type code
    """
    check_call(_LIB.TVMExtTypeFree(handle, ctypes.c_int(type_code)))
def register_extension(cls, fcreate=None):
    """Register a extension class to TVM.

    After the class is registered, the class will be able
    to directly pass as Function argument generated by TVM.

    Parameters
    ----------
    cls : class
        The class object to be registered as extension.

    fcreate : function, optional
        The creation function to create a class object given handle value.

    Returns
    -------
    cls : class
        The class being registered.

    Note
    ----
    The registered class requires one property ``_tvm_handle`` and a class
    attribute ``_tvm_tcode``.

    - ``_tvm_handle`` returns an integer representing the address of the handle.
    - ``_tvm_tcode`` gives an integer representing the type code of the class.

    Example
    -------
    The following code registers user defined class
    MyTensor to be DLTensor compatible.

    .. code-block:: python

       @tvm.register_extension
       class MyTensor(object):
           _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE

           def __init__(self):
               self.handle = _LIB.NewDLTensor()

           @property
           def _tvm_handle(self):
               return self.handle.value
    """
    # A custom creation function is only meaningful for true extension
    # type codes, not the builtin ones below EXT_BEGIN.
    if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
        raise ValueError("Cannot register create when extension tcode is same as buildin")
    _reg_extension(cls, fcreate)
    return cls
"""Common runtime ctypes."""
# pylint: disable=invalid-name
from __future__ import absolute_import
import ctypes
import json
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
tvm_shape_index_t = ctypes.c_int64
class TypeCode(object):
    """Type code used in API calls"""
    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    NULL = 4
    TVM_TYPE = 5
    TVM_CONTEXT = 6
    ARRAY_HANDLE = 7
    NODE_HANDLE = 8
    MODULE_HANDLE = 9
    FUNC_HANDLE = 10
    STR = 11
    BYTES = 12
    NDARRAY_CONTAINER = 13
    # Codes at or above EXT_BEGIN are reserved for user-registered
    # extension types (see register_extension).
    EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure):
    """Temp data structure for byte array."""
    # Matches the C-side TVMByteArray layout: raw pointer plus length.
    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
                ("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure):
    """TVM datatype structure.

    Fields
    ------
    type_code : int
        One of the CODE2STR keys (int/uint/float/handle).
    bits : int
        Bit width of one lane.
    lanes : int
        Number of vector lanes; 1 for scalar types.
    """
    _fields_ = [("type_code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]
    # Reverse mapping from type_code to its string prefix.
    CODE2STR = {
        0 : 'int',
        1 : 'uint',
        2 : 'float',
        4 : 'handle'
    }

    def __init__(self, type_str):
        """Parse a type string such as "float32" or "int8x4"."""
        super(TVMType, self).__init__()
        if isinstance(type_str, np.dtype):
            type_str = str(type_str)
        arr = type_str.split("x")
        head = arr[0]
        # Vector types are written "<base>x<lanes>", e.g. "float32x4".
        self.lanes = int(arr[1]) if len(arr) > 1 else 1
        bits = 32
        if head.startswith("int"):
            self.type_code = 0
            head = head[3:]
        elif head.startswith("uint"):
            self.type_code = 1
            head = head[4:]
        elif head.startswith("float"):
            self.type_code = 2
            head = head[5:]
        elif head.startswith("handle"):
            self.type_code = 4
            bits = 64
            head = ""
        else:
            # BUGFIX: fixed "Donot" typo in the error message.
            raise ValueError("Do not know how to handle type %s" % type_str)
        bits = int(head) if head else bits
        self.bits = bits

    def __repr__(self):
        x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
        if self.lanes != 1:
            x += "x%d" % self.lanes
        return x

    def __eq__(self, other):
        # Robustness: comparing against a non-TVMType used to raise
        # AttributeError; now it is simply unequal.
        if not isinstance(other, TVMType):
            return False
        return (self.bits == other.bits and
                self.type_code == other.type_code and
                self.lanes == other.lanes)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ disables the inherited hash on Python 3;
        # provide one consistent with equality.
        return hash((self.type_code, self.bits, self.lanes))
# Device types >= RPC_SESS_MASK encode a remote RPC session index.
RPC_SESS_MASK = 128

class TVMContext(ctypes.Structure):
    """TVM context structure: a device type mask plus a device id."""
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]
    # device type mask -> canonical device name
    MASK2STR = {
        1 : 'cpu',
        2 : 'gpu',
        4 : 'opencl',
        5 : 'aocl',
        6 : 'sdaccel',
        7 : 'vulkan',
        8 : 'metal',
        9 : 'vpi',
        10: 'rocm',
        11: 'opengl',
        12: 'ext_dev',
    }
    # device name (including aliases) -> device type mask
    STR2MASK = {
        'llvm': 1,
        'stackvm': 1,
        'cpu': 1,
        'gpu': 2,
        'cuda': 2,
        'nvptx': 2,
        'cl': 4,
        'opencl': 4,
        'aocl' : 5,
        'aocl_sw_emu' : 5,
        'sdaccel': 6,
        'vulkan': 7,
        'metal': 8,
        'vpi': 9,
        'rocm': 10,
        'opengl': 11,
        'ext_dev': 12,
    }

    def __init__(self, device_type, device_id):
        super(TVMContext, self).__init__()
        self.device_type = device_type
        self.device_id = device_id

    @property
    def exist(self):
        """Whether this device exist."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 0) != 0

    @property
    def max_threads_per_block(self):
        """Maximum number of threads on each block."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 1)

    @property
    def warp_size(self):
        """Number of threads that executes in concurrent."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 2)

    @property
    def max_shared_memory_per_block(self):
        """Total amount of shared memory per block in bytes."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 3)

    @property
    def compute_version(self):
        """Get compute version number in string.

        Currently used to get compute capability of CUDA device.

        Returns
        -------
        version : str
            The version string in `major.minor` format.
        """
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 4)

    @property
    def device_name(self):
        """Return the string name of device."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 5)

    @property
    def max_clock_rate(self):
        """Return the max clock frequency of device."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 6)

    @property
    def multi_processor_count(self):
        """Return the number of compute units of device."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 7)

    @property
    def max_thread_dimensions(self):
        """Return the maximum size of each thread axis

        Returns
        -------
        dims: List of int
            The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
        """
        return json.loads(_api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 8))

    def sync(self):
        """Synchronize until jobs finished at the context."""
        check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))

    def __eq__(self, other):
        return (isinstance(other, TVMContext) and
                self.device_id == other.device_id and
                self.device_type == other.device_type)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        if self.device_type >= RPC_SESS_MASK:
            # BUGFIX: use floor division; plain `/` yields a float on
            # Python 3 (the intent is an integer session-table index).
            tbl_id = self.device_type // RPC_SESS_MASK - 1
            dev_type = self.device_type % RPC_SESS_MASK
            return "remote[%d]:%s(%d)" % (
                tbl_id, TVMContext.MASK2STR[dev_type], self.device_id)
        return "%s(%d)" % (
            TVMContext.MASK2STR[self.device_type], self.device_id)

    def __hash__(self):
        return hash((self.device_type, self.device_id))
class TVMArray(ctypes.Structure):
    """TVMValue in C API"""
    # Mirrors the C-side DLTensor layout field by field.
    _fields_ = [("data", ctypes.c_void_p),
                ("ctx", TVMContext),
                ("ndim", ctypes.c_int),
                ("dtype", TVMType),
                ("shape", ctypes.POINTER(tvm_shape_index_t)),
                ("strides", ctypes.POINTER(tvm_shape_index_t)),
                ("byte_offset", ctypes.c_uint64)]

# Pointer type used for all array handles in the C API.
TVMArrayHandle = ctypes.POINTER(TVMArray)
from __future__ import absolute_import
import os import os
__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower() __backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy': if __backend__ == 'numpy':
from dgl.backend.numpy import * from .numpy import *
elif __backend__ == 'pytorch': elif __backend__ == 'pytorch':
from dgl.backend.pytorch import * from .pytorch import *
elif __backend__ == 'mxnet':
from .mxnet import *
else: else:
raise Exception("Unsupported backend %s" % __backend__) raise Exception("Unsupported backend %s" % __backend__)
from __future__ import absolute_import
import numpy as np
import mxnet as mx
import mxnet.ndarray as F
import scipy.sparse
import ctypes
from .._ffi.base import _LIB, check_call, c_array
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
# Tensor types
# `mx`, `F` (mxnet.ndarray), and `np` come from the module-level imports.
Tensor = mx.nd.NDArray
SparseTensor = mx.nd.sparse.CSRNDArray
# Data types (aliased to numpy scalar types)
float16 = np.float16
float32 = np.float32
float64 = np.float64
uint8 = np.uint8
int8 = np.int8
int16 = np.int16
int32 = np.int32
int64 = np.int64
# Operators
tensor = mx.nd.array
sum = F.sum
def max(x):
    # Reduce to a python scalar (asnumpy forces a synchronous copy to CPU).
    return F.max(x).asnumpy()[0]

def sparse_tensor(idx, data, shape):
    # idx is (row_indices, col_indices) in COO form; builds a CSR matrix.
    return mx.nd.sparse.csr_matrix((data, (idx[0], idx[1])), tuple(shape))

def astype(a, ty):
    # Cast tensor a to dtype ty.
    return F.cast(a, ty)

def asnumpy(a):
    # Copy to a numpy array on CPU.
    return a.asnumpy()

def from_numpy(np_data):
    # Build an mxnet tensor preserving the numpy dtype.
    return mx.nd.array(np_data, dtype=np_data.dtype)

def pack(tensors):
    # Concatenate along the first axis.
    return F.concat(*tensors, dim=0)
def unpack(x, indices_or_sections=1):
    """Split ``x`` along the first axis.

    BUGFIX: the original body called ``th.split`` — ``th`` (torch) is not
    defined in this mxnet backend, so every call raised NameError.
    """
    # NOTE(review): mxnet's split takes the number of output sections;
    # confirm callers only pass an integer section count here.
    return F.split(x, num_outputs=indices_or_sections)
# TODO this doesn't exist for symbol.
def shape(x):
    # Tuple of dimension sizes.
    return x.shape

def dtype(x):
    # Numpy dtype of the underlying data.
    return x.dtype
def isinteger(x):
    """Return True when ``x`` has a signed integer dtype.

    Dropped the deprecated ``np.int`` alias (removed in numpy 1.24): it was
    simply the builtin ``int``, which numpy maps to a fixed-width type
    already covered by the list below.
    """
    return x.dtype in [np.int8, np.int16, np.int32, np.int64]
def unique(x):
    # TODO this isn't the best way of running unique.
    # Round-trips through numpy because mx.nd lacks a unique op here.
    tmp = x.asnumpy()
    tmp = np.unique(tmp)
    return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype)

def gather_row(data, row_index):
    # Row-wise advanced indexing.
    return data[row_index,]

scatter_row = mx.nd.contrib.index_copy

def broadcast_to(x, to_array):
    # Broadcast x to to_array's shape by adding a zero tensor of that shape.
    return x + F.zeros_like(to_array)

squeeze = F.squeeze
unsqueeze = F.expand_dims
# TODO this doesn't exist for symbol.
reshape = F.reshape
ones = F.ones
zeros = F.zeros
arange = F.arange

def spmm(spm, mat):
    # Sparse (CSR) x dense matrix product.
    return mx.nd.dot(spm, mat)
def sort(x, dim=None, descending=False):
    """Sort along ``dim`` (last axis by default); return (values, int64 indices)."""
    axis = -1 if dim is None else dim
    ascending = not descending
    # TODO this isn't an ideal implementation.
    values = F.sort(x, axis=axis, is_ascend=ascending)
    indices = F.argsort(x, axis=axis, is_ascend=ascending)
    return values, F.cast(indices, dtype='int64')
def to_context(x, ctx):
    # Move x to the device described by a TVMContext (None is a no-op).
    if ctx is None:
        return x
    elif ctx.device_type == TVMContext.STR2MASK['cuda']:
        return x.as_in_context(mx.gpu(ctx.device_id))
    elif ctx.device_type == TVMContext.STR2MASK['cpu']:
        return x.as_in_context(mx.cpu())
    else:
        raise RuntimeError('Invalid context', ctx)

def get_context(x):
    # Translate the mxnet device of x into a TVMContext.
    if x.context.device_type == 'cpu':
        return TVMContext(TVMContext.STR2MASK['cpu'], 0)
    else:
        return TVMContext(
            TVMContext.STR2MASK[x.context.device_type], x.context.device_id)
def _typestr(arr_dtype):
    # mxnet dtypes are numpy dtypes, which already serve as the type string.
    return arr_dtype

def zerocopy_to_dlpack(arr):
    """Return a dlpack compatible array using zero copy."""
    return arr.to_dlpack_for_read()

def zerocopy_from_dlpack(dlpack_arr):
    """Return a tensor using zero copy."""
    return mx.nd.from_dlpack(dlpack_arr)

def zerocopy_to_numpy(arr):
    """Return a numpy array that shares the data."""
    # NOTE(review): asnumpy() copies; this is not actually zero-copy.
    return arr.asnumpy()

def zerocopy_from_numpy(np_data):
    """Return a tensor that shares the numpy data."""
    # NOTE(review): mx.nd.array copies; this is not actually zero-copy.
    return mx.nd.array(np_data, dtype=np_data.dtype)
...@@ -22,3 +22,7 @@ def unpack(a, split_size_or_sections=None): ...@@ -22,3 +22,7 @@ def unpack(a, split_size_or_sections=None):
def shape(a): def shape(a):
return a.shape return a.shape
def nonzero_1d(a):
    """Return the indices of nonzero entries of a 1D array.

    BUGFIX: the assertion checked ``ndim == 2`` even though this helper
    (like its pytorch counterpart) operates on 1D arrays.
    """
    assert a.ndim == 1
    return np.nonzero(a)[0]
from __future__ import absolute_import from __future__ import absolute_import
import ctypes
import scipy as sp
import torch as th import torch as th
import scipy.sparse from torch.utils import dlpack
import dgl.context as context
from .._ffi.base import _LIB, check_call, c_array
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
from .. import ndarray as nd
# Tensor types # Tensor types
Tensor = th.Tensor Tensor = th.Tensor
...@@ -23,6 +29,7 @@ tensor = th.tensor ...@@ -23,6 +29,7 @@ tensor = th.tensor
sparse_tensor = th.sparse.FloatTensor sparse_tensor = th.sparse.FloatTensor
sum = th.sum sum = th.sum
max = th.max max = th.max
stack = th.stack
def astype(a, ty): def astype(a, ty):
return a.type(ty) return a.type(ty)
...@@ -30,8 +37,11 @@ def astype(a, ty): ...@@ -30,8 +37,11 @@ def astype(a, ty):
def asnumpy(a): def asnumpy(a):
return a.cpu().numpy() return a.cpu().numpy()
def pack(tensors): def from_numpy(np_data):
return th.cat(tensors) return th.from_numpy(np_data)
def pack(tensors, dim=0):
return th.cat(tensors, dim)
def unpack(x, indices_or_sections=1): def unpack(x, indices_or_sections=1):
return th.split(x, indices_or_sections) return th.split(x, indices_or_sections)
...@@ -39,8 +49,8 @@ def unpack(x, indices_or_sections=1): ...@@ -39,8 +49,8 @@ def unpack(x, indices_or_sections=1):
def shape(x): def shape(x):
return x.shape return x.shape
def isinteger(x): def dtype(x):
return x.dtype in [th.int, th.int8, th.int16, th.int32, th.int64] return x.dtype
unique = th.unique unique = th.unique
...@@ -65,19 +75,76 @@ sort = th.sort ...@@ -65,19 +75,76 @@ sort = th.sort
arange = th.arange arange = th.arange
mul = th.mul mul = th.mul
def to_context(x, ctx): def to_context(arr, ctx):
if ctx is None: if ctx is None:
return x return arr
elif ctx.device == 'gpu': elif ctx.device_type == TVMContext.STR2MASK['cuda']:
th.cuda.set_device(ctx.device_id) th.cuda.set_device(ctx.device_id)
return x.cuda() return arr.cuda()
elif ctx.device == 'cpu': elif ctx.device_type == TVMContext.STR2MASK['cpu']:
return x.cpu() return arr.cpu()
else: else:
raise RuntimeError('Invalid context', ctx) raise RuntimeError('Invalid context', ctx)
def get_context(x): def get_context(arr):
if x.device.type == 'cpu': if arr.device.type == 'cpu':
return context.cpu() return TVMContext(TVMContext.STR2MASK['cpu'], 0)
else:
return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype):
if arr_dtype in (th.float16, th.half):
return 'float16'
elif arr_dtype in (th.float32, th.float):
return 'float32'
elif arr_dtype in (th.float64, th.double):
return 'float64'
elif arr_dtype in (th.int16, th.short):
return 'int16'
elif arr_dtype in (th.int32, th.int):
return 'int32'
elif arr_dtype in (th.int64, th.long):
return 'int64'
elif arr_dtype == th.int8:
return 'int8'
elif arr_dtype == th.uint8:
return 'uint8'
else: else:
return context.gpu(x.device.index) raise RuntimeError('Unsupported data type:', arr_dtype)
def zerocopy_to_dlpack(arr):
    """Return a dlpack compatible array using zero copy."""
    return dlpack.to_dlpack(arr)

def zerocopy_from_dlpack(dlpack_arr):
    """Return a tensor using zero copy."""
    return dlpack.from_dlpack(dlpack_arr)

def zerocopy_to_numpy(arr):
    """Return a numpy array that shares the data."""
    # TODO(minjie): zero copy
    return arr.numpy()
def zerocopy_from_numpy(np_data):
    """Return a tensor that shares the numpy data."""
    # th.from_numpy shares the underlying buffer: true zero-copy.
    return th.from_numpy(np_data)
# Removed: a leftover triple-quoted block of dead TVMArray-conversion code
# that followed this function.
def nonzero_1d(arr):
    """Return a 1D tensor with nonzero element indices in a 1D vector"""
    assert arr.dim() == 1
    # th.nonzero yields an (n, 1) index matrix for a 1D input; take column 0.
    indices = th.nonzero(arr)
    return indices[:, 0]
"""Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import
import numpy as np
from dgl.graph import DGLGraph
import dgl.backend as F
import dgl
class BatchedDGLGraph(DGLGraph):
    """A DGLGraph holding several graphs batched into one.

    Nodes and edges of the input graphs are relabeled with consecutive ids;
    per-graph offsets allow mapping original ids to batched ids.
    """
    def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr):
        super(BatchedDGLGraph, self).__init__(**attr)
        self.graph_list = graph_list
        # graph object -> position in graph_list
        self.graph_idx = {}
        for idx, g in enumerate(self.graph_list):
            self.graph_idx[g] = idx
        self.num_nodes = [len(g) for g in self.graph_list]
        self.num_edges = [g.size() for g in self.graph_list]
        # calc index offset
        self.node_offset = np.cumsum([0] + self.num_nodes)
        self.edge_offset = np.cumsum([0] + self.num_edges)
        # in-order add relabeled nodes
        self.add_nodes_from(range(self.node_offset[-1]))
        # in-order add relabeled edges
        self.new_edge_list = [np.array(g.edge_list) + offset
                              for g, offset in zip(self.graph_list, self.node_offset[:-1])]
        self.new_edges = np.concatenate(self.new_edge_list)
        self.add_edges_from(self.new_edges)
        assert self.size() == self.edge_offset[-1]
        # set new node attr
        if node_attrs:
            attrs = {}
            for key in node_attrs:
                vals = [g.pop_n_repr(key) for g in self.graph_list]
                attrs[key] = F.pack(vals)
            self.set_n_repr(attrs)
        else:
            for g in self.graph_list:
                self._node_frame.append(g._node_frame)
        # set new edge attr
        if edge_attrs:
            attrs = {}
            for key in edge_attrs:
                vals = [g.pop_e_repr(key) for g in self.graph_list]
                attrs[key] = F.pack(vals)
            self.set_e_repr(attrs)
        else:
            for g in self.graph_list:
                self._edge_frame.append(g._edge_frame)

    def query_new_node(self, g, u):
        """Map node id(s) ``u`` of subgraph ``g`` to batched node id(s)."""
        idx = self.graph_idx[g]
        offset = self.node_offset[idx]
        # BUGFIX: the original tested isinstance against `np.array` (a
        # function, which makes isinstance raise TypeError); use
        # `np.ndarray`, matching query_new_edge below.
        if isinstance(u, (int, np.ndarray, F.Tensor)):
            return u + offset
        else:
            return np.array(u) + offset

    def query_new_edge(self, g, src, dst):
        """Map edge endpoints of subgraph ``g`` to batched node ids."""
        idx = self.graph_idx[g]
        offset = self.node_offset[idx]
        if isinstance(src, (int, np.ndarray, F.Tensor)) and \
           isinstance(dst, (int, np.ndarray, F.Tensor)):
            return src + offset, dst + offset
        else:
            return np.array(src) + offset, np.array(dst) + offset

    def query_node_start_offset(self):
        """Return the starting batched node id of each subgraph."""
        return self.node_offset[:-1].copy()

    def query_edge_start_offset(self):
        """Return the starting batched edge id of each subgraph."""
        return self.edge_offset[:-1].copy()
def unbatch(graph_batch):
    """Unbatch the graph and return a list of subgraphs.

    Parameters
    ----------
    graph_batch : DGLGraph
        The batched graph.
    """
    graphs = graph_batch.graph_list
    count = len(graphs)

    def _split_attrs(schemes, pop, sizes):
        # Build, for each subgraph, a dict of attribute name -> split value.
        per_graph = [{} for _ in range(count)]
        for name in schemes:
            pieces = F.unpack(pop(name), sizes)
            for bucket, piece in zip(per_graph, pieces):
                bucket[name] = piece
        return per_graph

    # split and set node attrs
    node_dicts = _split_attrs(graph_batch.node_attr_schemes(),
                              graph_batch.pop_n_repr, graph_batch.num_nodes)
    for g, d in zip(graphs, node_dicts):
        g.set_n_repr(d)
    # split and set edge attrs
    edge_dicts = _split_attrs(graph_batch.edge_attr_schemes(),
                              graph_batch.pop_e_repr, graph_batch.num_edges)
    for g, d in zip(graphs, edge_dicts):
        g.set_e_repr(d)
    return graphs
# FIXME (lingfan): Do we really need the batch API?
# Can't we let user call BatchedDGLGraph(graph_list) directly
# and make unbatch a member function of BatchedDGLGraph
def batch(graph_list, node_attrs=None, edge_attrs=None):
    """Batch a list of DGLGraphs into one single graph.

    Once batch is called, the structure of both the merged graph and the
    graphs in graph_list must not be mutated, or unbatch's behavior will be
    undefined.

    Parameters
    ----------
    graph_list : iterable
        A list of DGLGraphs to be batched.
    node_attrs : str or iterable
        A list of node attributes needed for the merged graph.
        It is the user's responsibility to make sure node_attrs exist.
    edge_attrs : str or iterable
        A list of edge attributes needed for the merged graph.
        It is the user's responsibility to make sure edge_attrs exist.

    Returns
    -------
    newgrh : DGLGraph
        One single merged graph.
    """
    return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
"""Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import
import numpy as np
from .base import ALL, is_all
from .frame import FrameRef
from .graph import DGLGraph
from . import graph_index as gi
from . import backend as F
from . import utils
__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split']
class BatchedDGLGraph(DGLGraph):
    """The batched DGL graph.

    The batched graph is read-only: the mutating APIs inherited from
    DGLGraph are overridden below to raise RuntimeError.

    Parameters
    ----------
    graph_list : iterable
        A list of DGLGraphs to be batched.
    node_attrs : str or iterable
        The node attributes to also be batched.
    edge_attrs : str or iterable, optional
        The edge attributes to also be batched.
    """
    def __init__(self, graph_list, node_attrs, edge_attrs):
        # create batched graph index (disjoint union relabels ids consecutively)
        batched_index = gi.disjoint_union([g._graph for g in graph_list])
        # create batched node and edge frames
        # NOTE: following code will materialize the columns of the input graphs.
        batched_node_frame = FrameRef()
        for gr in graph_list:
            cols = {key : gr._node_frame[key] for key in node_attrs}
            batched_node_frame.append(cols)
        batched_edge_frame = FrameRef()
        for gr in graph_list:
            cols = {key : gr._edge_frame[key] for key in edge_attrs}
            batched_edge_frame.append(cols)
        super(BatchedDGLGraph, self).__init__(
            graph_data=batched_index,
            node_frame=batched_node_frame,
            edge_frame=batched_edge_frame)
        # extra members: per-input-graph sizes, flattened if an input is
        # itself a batch.
        self._batch_size = 0
        self._batch_num_nodes = []
        self._batch_num_edges = []
        for gr in graph_list:
            if isinstance(gr, BatchedDGLGraph):
                # handle the input is again a batched graph.
                self._batch_size += gr._batch_size
                self._batch_num_nodes += gr._batch_num_nodes
                self._batch_num_edges += gr._batch_num_edges
            else:
                self._batch_size += 1
                self._batch_num_nodes.append(gr.number_of_nodes())
                self._batch_num_edges.append(gr.number_of_edges())

    @property
    def batch_size(self):
        """Number of graphs in this batch."""
        return self._batch_size

    @property
    def batch_num_nodes(self):
        """Number of nodes of each graph in this batch."""
        return self._batch_num_nodes

    @property
    def batch_num_edges(self):
        """Number of edges of each graph in this batch."""
        return self._batch_num_edges

    # override APIs
    def add_nodes(self, num, reprs=None):
        """Add nodes. Disabled because BatchedDGLGraph is read-only."""
        raise RuntimeError('Readonly graph. Mutation is not allowed.')

    def add_edge(self, u, v, reprs=None):
        """Add one edge. Disabled because BatchedDGLGraph is read-only."""
        raise RuntimeError('Readonly graph. Mutation is not allowed.')

    def add_edges(self, u, v, reprs=None):
        """Add many edges. Disabled because BatchedDGLGraph is read-only."""
        raise RuntimeError('Readonly graph. Mutation is not allowed.')

    # new APIs
    def __getitem__(self, idx):
        """Slice the batch and return the batch of graphs specified by the idx."""
        # TODO
        pass

    def __setitem__(self, idx, val):
        """Set the value of the slice. The graph size cannot be changed."""
        # TODO
        pass

    def readout(self, reduce_func):
        """Perform readout for each graph in the batch.

        The readout value is a tensor of shape (B, D1, D2, ...) where B is the
        batch size.

        Parameters
        ----------
        reduce_func : callable
            The reduce function for readout.

        Returns
        -------
        dict of tensors
            The readout values.
        """
        # TODO
        pass
'''
def query_new_node(self, g, u):
idx = self.graph_idx[g]
offset = self.node_offset[idx]
if isinstance(u, (int, np.array, F.Tensor)):
return u + offset
else:
return np.array(u) + offset
def query_new_edge(self, g, src, dst):
idx = self.graph_idx[g]
offset = self.node_offset[idx]
if isinstance(src, (int, np.ndarray, F.Tensor)) and \
isinstance(dst, (int, np.ndarray, F.Tensor)):
return src + offset, dst + offset
else:
return np.array(src) + offset, np.array(dst) + offset
def query_node_start_offset(self):
return self.node_offset[:-1].copy()
def query_edge_start_offset(self):
return self.edge_offset[:-1].copy()
'''
def split(graph_batch, num_or_size_splits):
    """Split the batch.

    Not implemented yet.

    Parameters
    ----------
    graph_batch : BatchedDGLGraph
        The batched graph.
    num_or_size_splits : int or iterable
        Number of splits, or the size of each split (exact semantics
        to be decided).
    """
    # TODO(minjie): could follow torch.split syntax
    pass
def unbatch(graph):
    """Unbatch and return the list of graphs in this batch.

    Parameters
    ----------
    graph : BatchedDGLGraph
        The batched graph.
    """
    assert isinstance(graph, BatchedDGLGraph)
    num_graphs = graph.batch_size
    nodes_per_graph = graph.batch_num_nodes
    edges_per_graph = graph.batch_num_edges
    # Partition the batched graph index back into one index per graph.
    partitions = gi.disjoint_partition(graph._graph, utils.toindex(nodes_per_graph))
    # Split every column of the batched frames and scatter the pieces into
    # a fresh frame per graph.
    node_frames = [FrameRef() for _ in range(num_graphs)]
    edge_frames = [FrameRef() for _ in range(num_graphs)]
    for attr, col in graph._node_frame.items():
        for frame, piece in zip(node_frames, F.unpack(col, nodes_per_graph)):
            frame[attr] = piece
    for attr, col in graph._edge_frame.items():
        for frame, piece in zip(edge_frames, F.unpack(col, edges_per_graph)):
            frame[attr] = piece
    return [DGLGraph(graph_data=part, node_frame=nframe, edge_frame=eframe)
            for part, nframe, eframe in zip(partitions, node_frames, edge_frames)]
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
    """Batch a list of DGLGraphs into one single graph.

    Once batch is called, the structure of both the merged graph and the
    graphs in graph_list must not be mutated, or unbatch's behavior will be
    undefined.

    Parameters
    ----------
    graph_list : iterable
        A list of DGLGraphs to be batched. Must be non-empty and indexable
        when an attribute argument is ALL, since the attribute schemes are
        taken from the first graph.
    node_attrs : str or iterable, optional
        The node attributes to also be batched. Specify None to not batch any
        attributes.
    edge_attrs : str or iterable, optional
        The edge attributes to also be batched. Specify None to not batch any
        attributes.

    Returns
    -------
    newgrh : BatchedDGLGraph
        One single batched graph.
    """
    def _normalize(attrs, schemes_of_first):
        # Map the three accepted spellings (None / ALL / single name) onto a
        # concrete list of attribute names; the same rule applies to both
        # node and edge attributes.
        if attrs is None:
            return []
        if is_all(attrs):
            return schemes_of_first()
        if isinstance(attrs, str):
            return [attrs]
        return attrs

    node_attrs = _normalize(node_attrs, lambda: graph_list[0].node_attr_schemes())
    edge_attrs = _normalize(edge_attrs, lambda: graph_list[0].edge_attr_schemes())
    return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
"""High-performance graph structure query component.
TODO: Currently implemented by igraph. Should replace with more efficient
solution later.
"""
from __future__ import absolute_import
import igraph
import dgl.backend as F
from dgl.backend import Tensor
import dgl.utils as utils
class CachedGraph:
    """Graph-structure query cache backed by a directed igraph graph.

    Mutations are allowed until freeze() is called; afterwards any
    add_nodes/add_edge(s) raises RuntimeError.
    """
    def __init__(self):
        self._graph = igraph.Graph(directed=True)
        # once True, all mutating methods raise
        self._freeze = False

    def add_nodes(self, num_nodes):
        """Add `num_nodes` new vertices."""
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        self._graph.add_vertices(num_nodes)

    def add_edge(self, u, v):
        """Add one edge u -> v."""
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        self._graph.add_edge(u, v)

    def add_edges(self, u, v):
        """Add many edges pairing u with v elementwise."""
        if self._freeze:
            raise RuntimeError('Freezed cached graph cannot be mutated.')
        # The edge will be assigned ids equal to the order.
        uvs = list(utils.edge_iter(u, v))
        self._graph.add_edges(uvs)

    def get_edge_id(self, u, v):
        """Return the edge ids of the (u, v) pairs as a utils.Index."""
        uvs = list(utils.edge_iter(u, v))
        eids = self._graph.get_eids(uvs)
        return utils.toindex(eids)

    def in_edges(self, v):
        """Get in-edges of the vertices.

        Parameters
        ----------
        v : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no in-edges.
        """
        src = []
        dst = []
        orphan = []
        for vv in utils.node_iter(v):
            uu = self._graph.predecessors(vv)
            if len(uu) == 0:
                orphan.append(vv)
            else:
                src += uu
                dst += [vv] * len(uu)
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan

    def out_edges(self, u):
        """Get out-edges of the vertices.

        Parameters
        ----------
        u : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no out-edges.
        """
        src = []
        dst = []
        orphan = []
        for uu in utils.node_iter(u):
            vv = self._graph.successors(uu)
            if len(vv) == 0:
                orphan.append(uu)
            else:
                src += [uu] * len(vv)
                dst += vv
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan

    def in_degrees(self, v):
        """Return the in-degree of each given vertex as a utils.Index."""
        degs = self._graph.indegree(list(v))
        return utils.toindex(degs)

    def num_edges(self):
        """Return the total number of edges."""
        return self._graph.ecount()

    @utils.cached_member
    def edges(self):
        """Return (src, dst) of all edges; result is cached after first call."""
        elist = self._graph.get_edgelist()
        src = [u for u, _ in elist]
        dst = [v for _, v in elist]
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        return src, dst

    @utils.cached_member
    def adjmat(self):
        """Return a sparse adjacency matrix.

        The row dimension represents the dst nodes; the column dimension
        represents the src nodes.
        """
        elist = self._graph.get_edgelist()
        src = F.tensor([u for u, _ in elist], dtype=F.int64)
        dst = F.tensor([v for _, v in elist], dtype=F.int64)
        src = F.unsqueeze(src, 0)
        dst = F.unsqueeze(dst, 0)
        idx = F.pack([dst, src])
        n = self._graph.vcount()
        dat = F.ones((len(elist),))
        mat = F.sparse_tensor(idx, dat, [n, n])
        # One materialized copy of the matrix is cached per device context.
        return utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))

    def freeze(self):
        """Mark the graph immutable; later mutations raise RuntimeError."""
        self._freeze = True
def create_cached_graph(dglgraph):
    """Build a frozen CachedGraph mirroring the structure of `dglgraph`."""
    cached = CachedGraph()
    cached.add_nodes(dglgraph.number_of_nodes())
    # NOTE(review): edges go straight to the underlying igraph object,
    # bypassing CachedGraph.add_edges — presumably because edge_list is
    # already a list of (u, v) pairs; confirm against DGLGraph.edge_list.
    cached._graph.add_edges(dglgraph.edge_list)
    cached.freeze()
    return cached
"""DGL's device context shim."""
class Context(object):
    """A device context: a device type string plus an optional device id.

    Instances are value objects — equality and hashing are based on the
    (device, device_id) pair, so contexts can be used as dict keys.

    Parameters
    ----------
    dev : str
        Device type name (e.g. 'cpu' or 'gpu').
    devid : int, optional
        Device index; -1 when not applicable (e.g. for CPU).
    """
    def __init__(self, dev, devid=-1):
        self.device = dev
        self.device_id = devid

    def __str__(self):
        return '{}:{}'.format(self.device, self.device_id)

    def __repr__(self):
        return 'Context({!r}, {!r})'.format(self.device, self.device_id)

    def __eq__(self, other):
        # BUG FIX: comparing against a non-Context used to raise
        # AttributeError; returning NotImplemented lets Python fall back to
        # its default handling (== yields False for unrelated types).
        if not isinstance(other, Context):
            return NotImplemented
        return self.device == other.device and self.device_id == other.device_id

    def __hash__(self):
        return hash((self.device, self.device_id))
def gpu(gpuid):
    """Return a Context for the GPU with the given device id."""
    return Context('gpu', gpuid)
def cpu():
    """Return a Context for the CPU (device id defaults to -1)."""
    return Context('cpu')
...@@ -4,6 +4,7 @@ from __future__ import absolute_import ...@@ -4,6 +4,7 @@ from __future__ import absolute_import
from . import citation_graph as citegrh from . import citation_graph as citegrh
from .tree import * from .tree import *
from .utils import * from .utils import *
from .sbm import SBMMixture
def register_data_args(parser): def register_data_args(parser):
parser.add_argument("--dataset", type=str, required=True, parser.add_argument("--dataset", type=str, required=True,
......
...@@ -11,7 +11,7 @@ import networkx as nx ...@@ -11,7 +11,7 @@ import networkx as nx
import scipy.sparse as sp import scipy.sparse as sp
import os, sys import os, sys
from dgl.data.utils import download, extract_archive, get_download_dir from .utils import download, extract_archive, get_download_dir
_urls = { _urls = {
'cora' : 'https://www.dropbox.com/s/3ggdpkj7ou8svoc/cora.zip?dl=1', 'cora' : 'https://www.dropbox.com/s/3ggdpkj7ou8svoc/cora.zip?dl=1',
......
import math
import os
import pickle
import numpy as np
import numpy.random as npr
import scipy as sp
import networkx as nx
from torch.utils.data import Dataset
from .. import backend as F
from ..batched_graph import batch
from ..graph import DGLGraph
from ..utils import Index
def sbm(n_blocks, block_size, p, q, rng=None):
    """(Symmetric) Stochastic Block Model.

    Nodes are split into `n_blocks` blocks of `block_size` nodes each.
    With n = n_blocks * block_size total nodes, a pair inside the same block
    is connected with density p / n and a pair in different blocks with
    density q / n.  The returned matrix is symmetrized, so the graph is
    undirected.

    Parameters
    ----------
    n_blocks : int
        Number of blocks.
    block_size : int
        Block size.
    p : float
        Probability parameter for intra-community edges (scaled by 1/n).
    q : float
        Probability parameter for inter-community edges (scaled by 1/n).
    rng : numpy.random.RandomState, optional
        Random number generator; a fresh unseeded RandomState is used when
        omitted.

    Returns
    -------
    scipy sparse matrix
        The adjacency matrix of the generated graph.
    """
    n = n_blocks * block_size
    p /= n
    q /= n
    rng = np.random.RandomState() if rng is None else rng
    rows = []
    cols = []
    # Sample each upper-triangular block independently at its density.
    for i in range(n_blocks):
        for j in range(i, n_blocks):
            density = p if i == j else q
            block = sp.sparse.random(block_size, block_size, density,
                                     random_state=rng,
                                     data_rvs=lambda k: np.ones(k))
            rows.append(block.row + i * block_size)
            cols.append(block.col + j * block_size)
    rows = np.hstack(rows)
    cols = np.hstack(cols)
    a = sp.sparse.coo_matrix((np.ones(rows.shape[0]), (rows, cols)), shape=(n, n))
    # Keep the upper triangle and mirror it below the diagonal.
    adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
    return adj
class SBMMixture(Dataset):
    """Symmetric Stochastic Block Model Mixture.

    Please refer to Appendix C of "Supervised Community Detection with
    Hierarchical Graph Neural Networks" (https://arxiv.org/abs/1705.08415)
    for details.

    Parameters
    ----------
    n_graphs : int
        Number of graphs.
    n_nodes : int
        Number of nodes.
    n_communities : int
        Number of communities.
    k : int, optional
        Multiplier.
    avg_deg : int, optional
        Average degree.
    p : callable or str, optional
        Random density generator: a callable returning a (p, q) pair, or the
        name of a built-in generator ('Appendix C').
    rng : numpy.random.RandomState, optional
        Random number generator used for graph generation.
    """
    def __init__(self, n_graphs, n_nodes, n_communities,
                 k=2, avg_deg=3, p='Appendix C', rng=None):
        super(SBMMixture, self).__init__()
        self._n_nodes = n_nodes
        assert n_nodes % n_communities == 0
        block_size = n_nodes // n_communities
        if type(p) is str:
            p = {'Appendix C' : self._appendix_c}[p]
        self._k = k
        self._avg_deg = avg_deg
        self._gs = [DGLGraph() for i in range(n_graphs)]
        # BUG FIX: `rng` was accepted and documented but never used; forward
        # it to sbm() so graph generation is reproducible with a seeded
        # RandomState.  (The 'Appendix C' density generator still draws from
        # numpy's global RNG.)
        adjs = [sbm(n_communities, block_size, *p(), rng=rng)
                for i in range(n_graphs)]
        for g, adj in zip(self._gs, adjs):
            g.from_scipy_sparse_matrix(adj)
        self._lgs = [g.line_graph() for g in self._gs]
        # Column vector of float in-degrees for every node of a graph.
        in_degrees = lambda g: g.in_degrees(Index(F.arange(g.number_of_nodes(),
                                                           dtype=F.int64))).unsqueeze(1).float()
        self._g_degs = [in_degrees(g) for g in self._gs]
        self._lg_degs = [in_degrees(lg) for lg in self._lgs]
        # Per graph, the first component of edges(sorted=True) —
        # NOTE(review): presumably the source node id of each edge; confirm
        # against DGLGraph.edges.
        self._eid2nids = list(zip(*[g.edges(sorted=True) for g in self._gs]))[0]

    def __len__(self):
        """Number of graphs in the mixture."""
        return len(self._gs)

    def __getitem__(self, idx):
        """Return (graph, line graph, graph degrees, line-graph degrees, eid2nid)."""
        return self._gs[idx], self._lgs[idx], \
               self._g_degs[idx], self._lg_degs[idx], self._eid2nids[idx]

    def _appendix_c(self):
        # Density generator from Appendix C of the paper: draw q uniformly,
        # then set p so that p + q == k * avg_deg.
        q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
        p = self._k * self._avg_deg - q
        return p, q

    def collate_fn(self, x):
        """Collate a list of dataset items into batched inputs for a DataLoader."""
        g, lg, deg_g, deg_lg, eid2nid = zip(*x)
        g_batch = batch(g)
        lg_batch = batch(lg)
        degg_batch = F.pack(deg_g)
        deglg_batch = F.pack(deg_lg)
        # Offset each graph's node ids so they index into the batched graph.
        eid2nid_batch = F.pack([x + i * self._n_nodes for i, x in enumerate(eid2nid)])
        return g_batch, lg_batch, degg_batch, deglg_batch, eid2nid_batch
...@@ -10,9 +10,9 @@ from nltk.tree import Tree ...@@ -10,9 +10,9 @@ from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader from nltk.corpus.reader import BracketParseCorpusReader
import networkx as nx import networkx as nx
import dgl from .. import backend as F
import dgl.backend as F from ..graph import DGLGraph
from dgl.data.utils import download, extract_archive, get_download_dir from .utils import download, extract_archive, get_download_dir
_urls = { _urls = {
'sst' : 'https://www.dropbox.com/s/dw8kr2vuq7k4dqi/sst.zip?dl=1', 'sst' : 'https://www.dropbox.com/s/dw8kr2vuq7k4dqi/sst.zip?dl=1',
...@@ -66,7 +66,9 @@ class SST(object): ...@@ -66,7 +66,9 @@ class SST(object):
# add root # add root
g.add_node(0, x=SST.PAD_WORD, y=int(root.label())) g.add_node(0, x=SST.PAD_WORD, y=int(root.label()))
_rec_build(0, root) _rec_build(0, root)
return dgl.DGLGraph(g) ret = DGLGraph()
ret.from_networkx(g, node_attrs=['x', 'y'])
return ret
def __getitem__(self, idx): def __getitem__(self, idx):
return self.trees[idx] return self.trees[idx]
...@@ -77,22 +79,3 @@ class SST(object): ...@@ -77,22 +79,3 @@ class SST(object):
@property @property
def num_vocabs(self): def num_vocabs(self):
return len(self.vocab) return len(self.vocab)
@staticmethod
def batcher(batch):
nid_with_word = []
wordid = []
label = []
gnid = 0
for tree in batch:
for nid in range(tree.number_of_nodes()):
if tree.nodes[nid]['x'] != SST.PAD_WORD:
nid_with_word.append(gnid)
wordid.append(tree.nodes[nid]['x'])
label.append(tree.nodes[nid]['y'])
gnid += 1
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
nid_with_word=F.tensor(nid_with_word, dtype=F.int64),
wordid=F.tensor(wordid, dtype=F.int64),
label=F.tensor(label, dtype=F.int64))
...@@ -4,9 +4,9 @@ from __future__ import absolute_import ...@@ -4,9 +4,9 @@ from __future__ import absolute_import
from collections import MutableMapping from collections import MutableMapping
import numpy as np import numpy as np
import dgl.backend as F from . import backend as F
from dgl.backend import Tensor from .backend import Tensor
import dgl.utils as utils from . import utils
class Frame(MutableMapping): class Frame(MutableMapping):
def __init__(self, data=None): def __init__(self, data=None):
...@@ -123,7 +123,7 @@ class FrameRef(MutableMapping): ...@@ -123,7 +123,7 @@ class FrameRef(MutableMapping):
def select_rows(self, query): def select_rows(self, query):
rowids = self._getrowid(query) rowids = self._getrowid(query)
def _lazy_select(key): def _lazy_select(key):
idx = rowids.totensor(F.get_context(self._frame[key])) idx = rowids.tousertensor(F.get_context(self._frame[key]))
return F.gather_row(self._frame[key], idx) return F.gather_row(self._frame[key], idx)
return utils.LazyDict(_lazy_select, keys=self.schemes) return utils.LazyDict(_lazy_select, keys=self.schemes)
...@@ -132,7 +132,7 @@ class FrameRef(MutableMapping): ...@@ -132,7 +132,7 @@ class FrameRef(MutableMapping):
if self.is_span_whole_column(): if self.is_span_whole_column():
return col return col
else: else:
idx = self.index().totensor(F.get_context(col)) idx = self.index().tousertensor(F.get_context(col))
return F.gather_row(col, idx) return F.gather_row(col, idx)
def __setitem__(self, key, val): def __setitem__(self, key, val):
...@@ -141,7 +141,7 @@ class FrameRef(MutableMapping): ...@@ -141,7 +141,7 @@ class FrameRef(MutableMapping):
else: else:
self.update_rows(key, val) self.update_rows(key, val)
def add_column(self, name, col): def add_column(self, name, col, inplace=False):
shp = F.shape(col) shp = F.shape(col)
if self.is_span_whole_column(): if self.is_span_whole_column():
if self.num_columns == 0: if self.num_columns == 0:
...@@ -156,19 +156,26 @@ class FrameRef(MutableMapping): ...@@ -156,19 +156,26 @@ class FrameRef(MutableMapping):
else: else:
fcol = F.zeros((self._frame.num_rows,) + shp[1:]) fcol = F.zeros((self._frame.num_rows,) + shp[1:])
fcol = F.to_context(fcol, colctx) fcol = F.to_context(fcol, colctx)
idx = self.index().totensor(colctx) idx = self.index().tousertensor(colctx)
newfcol = F.scatter_row(fcol, idx, col) if inplace:
self._frame[name] = newfcol self._frame[name] = fcol
self._frame[name][idx] = col
else:
newfcol = F.scatter_row(fcol, idx, col)
self._frame[name] = newfcol
def update_rows(self, query, other): def update_rows(self, query, other, inplace=False):
rowids = self._getrowid(query) rowids = self._getrowid(query)
for key, col in other.items(): for key, col in other.items():
if key not in self: if key not in self:
# add new column # add new column
tmpref = FrameRef(self._frame, rowids) tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col) tmpref.add_column(key, col, inplace)
idx = rowids.totensor(F.get_context(self._frame[key])) idx = rowids.tousertensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col) if inplace:
self._frame[key][idx] = col
else:
self._frame[key] = F.scatter_row(self._frame[key], idx, col)
def __delitem__(self, key): def __delitem__(self, key):
if isinstance(key, str): if isinstance(key, str):
...@@ -223,8 +230,8 @@ class FrameRef(MutableMapping): ...@@ -223,8 +230,8 @@ class FrameRef(MutableMapping):
# shortcut for identical mapping # shortcut for identical mapping
return query return query
else: else:
idxtensor = self.index().totensor() idxtensor = self.index().tousertensor()
return utils.toindex(F.gather_row(idxtensor, query.totensor())) return utils.toindex(F.gather_row(idxtensor, query.tousertensor()))
def index(self): def index(self):
if self._index is None: if self._index is None:
......
from .message import * """DGL builtin functors"""
from __future__ import absolute_import
from .message import *
from .reducer import * from .reducer import *
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
from __future__ import absolute_import from __future__ import absolute_import
import operator import operator
import dgl.backend as F
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"] __all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object): class MessageFunction(object):
def __call__(self, src, edge): def __call__(self, src, edge):
raise NotImplementedError raise NotImplementedError
...@@ -12,10 +14,28 @@ class MessageFunction(object): ...@@ -12,10 +14,28 @@ class MessageFunction(object):
def name(self): def name(self):
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self, g):
raise NotImplementedError
class BundledMessageFunction(MessageFunction): class BundledMessageFunction(MessageFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
# cannot perform check for udf
if isinstance(fn, MessageFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple message is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self, g):
for fn in self.fn_list:
if not isinstance(fn, MessageFunction) or not fn.is_spmv_supported(g):
return False
return True
def __call__(self, src, edge): def __call__(self, src, edge):
ret = None ret = None
for fn in self.fn_list: for fn in self.fn_list:
...@@ -24,16 +44,34 @@ class BundledMessageFunction(MessageFunction): ...@@ -24,16 +44,34 @@ class BundledMessageFunction(MessageFunction):
ret = msg ret = msg
else: else:
try: try:
# ret and msg must be dict
ret.update(msg) ret.update(msg)
except e: except:
raise RuntimeError("Failed to merge results of two builtin" raise RuntimeError("Must specify out field for multiple message")
" message functions. Please specify out_field"
" for the builtin message function.")
return ret return ret
def name(self): def name(self):
return "bundled" return "bundled"
def _is_spmv_supported_node_feat(g, field):
if field is None:
feat = g.get_n_repr()
else:
feat = g.get_n_repr()[field]
shape = F.shape(feat)
return len(shape) == 1 or len(shape) == 2
def _is_spmv_supported_edge_feat(g, field):
# check shape, only scalar edge feature can be optimized at the moment
if field is None:
feat = g.get_e_repr()
else:
feat = g.get_e_repr()[field]
shape = F.shape(feat)
return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
class SrcMulEdgeMessageFunction(MessageFunction): class SrcMulEdgeMessageFunction(MessageFunction):
def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None): def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
self.mul_op = mul_op self.mul_op = mul_op
...@@ -41,6 +79,10 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -41,6 +79,10 @@ class SrcMulEdgeMessageFunction(MessageFunction):
self.edge_field = edge_field self.edge_field = edge_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g):
return _is_spmv_supported_node_feat(g, self.src_field) \
and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: if self.src_field is not None:
src = src[self.src_field] src = src[self.src_field]
...@@ -60,6 +102,9 @@ class CopySrcMessageFunction(MessageFunction): ...@@ -60,6 +102,9 @@ class CopySrcMessageFunction(MessageFunction):
self.src_field = src_field self.src_field = src_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g):
return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: if self.src_field is not None:
ret = src[self.src_field] ret = src[self.src_field]
...@@ -78,6 +123,11 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -78,6 +123,11 @@ class CopyEdgeMessageFunction(MessageFunction):
self.edge_field = edge_field self.edge_field = edge_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self, g):
# TODO: support this with g-spmv
return False
# return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.edge_field is not None: if self.edge_field is not None:
ret = edge[self.edge_field] ret = edge[self.edge_field]
...@@ -90,7 +140,8 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -90,7 +140,8 @@ class CopyEdgeMessageFunction(MessageFunction):
def name(self): def name(self):
return "copy_edge" return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None): def src_mul_edge(src=None, edge=None, out=None):
"""TODO(minjie): docstring """ """TODO(minjie): docstring """
return SrcMulEdgeMessageFunction(operator.mul, src, edge, out) return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
......
"""Built-in reducer function.""" """Built-in reducer function."""
from __future__ import absolute_import from __future__ import absolute_import
import dgl.backend as F from .. import backend as F
__all__ = ["ReduceFunction", "sum", "max"] __all__ = ["ReduceFunction", "sum", "max"]
...@@ -12,10 +12,26 @@ class ReduceFunction(object): ...@@ -12,10 +12,26 @@ class ReduceFunction(object):
def name(self): def name(self):
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self):
raise NotImplementedError
class BundledReduceFunction(ReduceFunction): class BundledReduceFunction(ReduceFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
if isinstance(fn, ReduceFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self):
for fn in self.fn_list:
if not isinstance(fn, ReduceFunction) or not fn.is_spmv_supported():
return False
return True
def __call__(self, node, msgs): def __call__(self, node, msgs):
ret = None ret = None
for fn in self.fn_list: for fn in self.fn_list:
...@@ -24,46 +40,50 @@ class BundledReduceFunction(ReduceFunction): ...@@ -24,46 +40,50 @@ class BundledReduceFunction(ReduceFunction):
ret = rpr ret = rpr
else: else:
try: try:
# ret and rpr must be dict
ret.update(rpr) ret.update(rpr)
except e: except:
raise RuntimeError("Failed to merge results of two builtin" raise RuntimeError("Must specify out field for multiple reudce")
" reduce functions. Please specify out_field"
" for the builtin reduce function.")
return ret return ret
def name(self): def name(self):
return "bundled" return "bundled"
class SumReducerFunction(ReduceFunction): class ReducerFunctionTemplate(ReduceFunction):
def __init__(self, batch_sum_op, nonbatch_sum_op, msg_field=None, out_field=None): def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None):
self.batch_sum_op = batch_sum_op self.name = name
self.nonbatch_sum_op = nonbatch_sum_op self.batch_op = batch_op
self.nonbatch_op = nonbatch_op
self.msg_field = msg_field self.msg_field = msg_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self):
# TODO: support max
return self.name == "sum"
def __call__(self, node, msgs): def __call__(self, node, msgs):
if isinstance(msgs, list): if isinstance(msgs, list):
if self.msg_field is None: if self.msg_field is None:
ret = self.nonbatch_sum_op(msgs) ret = self.nonbatch_op(msgs)
else: else:
ret = self.nonbatch_sum_op([msg[self.msg_field] for msg in msgs]) ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
else: else:
if self.msg_field is None: if self.msg_field is None:
ret = self.batch_sum_op(msgs, 1) ret = self.batch_op(msgs, 1)
else: else:
ret = self.batch_sum_op(msgs[self.msg_field], 1) ret = self.batch_op(msgs[self.msg_field], 1)
if self.out_field is None: if self.out_field is None:
return ret return ret
else: else:
return {self.out_field : ret} return {self.out_field : ret}
def name(self): def name(self):
return "sum" return self.name
_python_sum = sum _python_sum = sum
def sum(msgs=None, out=None): def sum(msgs=None, out=None):
return SumReducerFunction(F.sum, _python_sum, msgs, out) return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)
_python_max = max _python_max = max
def max(msgs=None, out=None): def max(msgs=None, out=None):
return SumReducerFunction(F.max, _python_max, msgs, out) return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment