Unverified Commit a1d50f0f authored by Lingfan Yu's avatar Lingfan Yu Committed by GitHub
Browse files

[Refactor] Rename before release (#261)

* include/dgl/runtime

* include

* src/runtime

* src/graph

* src/scheduler

* src

* clean up CMakeLists

* further clean up in cmake

* install commands

* python/dgl/_ffi/_cython

* python/dgl/_ffi/_ctypes

* python/dgl/_ffi

* python/dgl

* some fix

* copy right
parent aabba9d4
...@@ -4,9 +4,6 @@ ...@@ -4,9 +4,6 @@
cmake_minimum_required(VERSION 2.8) cmake_minimum_required(VERSION 2.8)
project(dgl C CXX) project(dgl C CXX)
# Utility functions
include(cmake/util/Util.cmake)
if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake)
else() else()
...@@ -15,37 +12,6 @@ else() ...@@ -15,37 +12,6 @@ else()
endif() endif()
endif() endif()
# NOTE: do not modify this file to change option values.
# You can create a config.cmake at build folder
# and add set(OPTION VALUE) to override these build options.
# Alernatively, use cmake -DOPTION=VALUE through command-line.
#tvm_option(USE_CUDA "Build with CUDA" OFF)
#tvm_option(USE_OPENCL "Build with OpenCL" OFF)
#tvm_option(USE_VULKAN "Build with Vulkan" OFF)
#tvm_option(USE_OPENGL "Build with OpenGL" OFF)
#tvm_option(USE_METAL "Build with Metal" OFF)
#tvm_option(USE_ROCM "Build with ROCM" OFF)
#tvm_option(ROCM_PATH "The path to rocm" /opt/rocm)
#tvm_option(USE_RPC "Build with RPC" ON)
#tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF)
#tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
#tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
#tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF)
#tvm_option(USE_RTTI "Build with RTTI" ON)
#tvm_option(USE_MSVC_MT "Build with MT" OFF)
#tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
# Contrib library options
#tvm_option(USE_BLAS "The blas library to be linked" none)
#tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
#tvm_option(USE_CUDNN "Build with cuDNN" OFF)
#tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
#tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
#tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
#tvm_option(USE_SORT "Build with sort support" OFF)
#tvm_option(USE_NNPACK "Build with nnpack support" OFF)
#tvm_option(USE_RANDOM "Build with random support" OFF)
# include directories # include directories
include_directories("include") include_directories("include")
include_directories("third_party/dlpack/include") include_directories("third_party/dlpack/include")
...@@ -83,140 +49,15 @@ else(MSVC) ...@@ -83,140 +49,15 @@ else(MSVC)
endif() endif()
endif(MSVC) endif(MSVC)
# add source group
#FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc")
#FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h")
#assign_source_group("Source" ${GROUP_SOURCE})
#assign_source_group("Include" ${GROUP_INCLUDE})
# Source file lists # Source file lists
file(GLOB CORE_SRCS src/graph/*.cc src/*.cc src/scheduler/*.cc) file(GLOB CORE_SRCS src/graph/*.cc src/*.cc src/scheduler/*.cc)
file(GLOB RUNTIME_SRCS src/runtime/*.cc) file(GLOB RUNTIME_SRCS src/runtime/*.cc)
# Package runtime rules
#if(NOT USE_RTTI)
# add_definitions(-DDMLC_ENABLE_RTTI=0)
#endif()
#
#if(USE_RPC)
# message(STATUS "Build with RPC support...")
# file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
# list(APPEND RUNTIME_SRCS ${RUNTIME_RPC_SRCS})
#endif(USE_RPC)
#
#file(GLOB STACKVM_RUNTIME_SRCS src/runtime/stackvm/*.cc)
#file(GLOB STACKVM_CODEGEN_SRCS src/codegen/stackvm/*.cc)
#list(APPEND COMPILER_SRCS ${STACKVM_CODEGEN_SRCS})
#if(USE_STACKVM_RUNTIME)
# message(STATUS "Build with stackvm support in runtime...")
# list(APPEND RUNTIME_SRCS ${STACKVM_RUNTIME_SRCS})
#else()
# list(APPEND COMPILER_SRCS ${STACKVM_RUNTIME_SRCS})
#endif(USE_STACKVM_RUNTIME)
#
#if(USE_GRAPH_RUNTIME)
# message(STATUS "Build with Graph runtime support...")
# file(GLOB RUNTIME_GRAPH_SRCS src/runtime/graph/*.cc)
# list(APPEND RUNTIME_SRCS ${RUNTIME_GRAPH_SRCS})
#
# if(USE_GRAPH_RUNTIME_DEBUG)
# set_source_files_properties(${RUNTIME_GRAPH_SRCS}
# PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_DEBUG")
# endif(USE_GRAPH_RUNTIME_DEBUG)
#endif(USE_GRAPH_RUNTIME)
# Module rules
#include(cmake/modules/VTA.cmake)
#include(cmake/modules/CUDA.cmake)
#include(cmake/modules/OpenCL.cmake)
#include(cmake/modules/OpenGL.cmake)
#include(cmake/modules/Vulkan.cmake)
#include(cmake/modules/Metal.cmake)
#include(cmake/modules/ROCM.cmake)
#include(cmake/modules/LLVM.cmake)
#include(cmake/modules/contrib/BLAS.cmake)
#include(cmake/modules/contrib/Random.cmake)
#include(cmake/modules/contrib/Sort.cmake)
#include(cmake/modules/contrib/NNPack.cmake)
add_library(dgl SHARED ${CORE_SRCS} ${RUNTIME_SRCS}) add_library(dgl SHARED ${CORE_SRCS} ${RUNTIME_SRCS})
#add_library(dgl_runtime SHARED ${RUNTIME_SRCS})
target_link_libraries(dgl ${DGL_LINKER_LIBS} ${DGL_RUNTIME_LINKER_LIBS}) target_link_libraries(dgl ${DGL_LINKER_LIBS} ${DGL_RUNTIME_LINKER_LIBS})
#target_link_libraries(dgl_runtime ${DGL_RUNTIME_LINKER_LIBS})
# Related headers
#target_include_directories(
# dgl
# PUBLIC "HalideIR/src"
# PUBLIC "topi/include")
# Tests
#set(TEST_EXECS "")
#file(GLOB TEST_SRCS tests/cpp/*.cc)
#find_library(GTEST_LIB gtest)
#if(GTEST_LIB)
# foreach(__srcpath ${TEST_SRCS})
# get_filename_component(__srcname ${__srcpath} NAME)
# string(REPLACE ".cc" "" __execname ${__srcname})
# add_executable(${__execname} ${__srcpath})
# list(APPEND TEST_EXECS ${__execname})
# target_link_libraries(${__execname}
# tvm ${GTEST_LIB} pthread)
# set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
# set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
# endforeach()
# add_custom_target(cpptest DEPENDS ${TEST_EXECS})
#endif()
# Custom targets
#add_custom_target(runtime DEPENDS tvm_runtime)
# Installation rules # Installation rules
install(TARGETS dgl DESTINATION lib${LIB_SUFFIX}) install(TARGETS dgl DESTINATION lib${LIB_SUFFIX})
#install(TARGETS dgl_runtime DESTINATION lib${LIB_SUFFIX})
#if (INSTALL_DEV)
# install(
# DIRECTORY "include/." DESTINATION "include"
# FILES_MATCHING
# PATTERN "*.h"
# )
# install(
# DIRECTORY "topi/include/." DESTINATION "include"
# FILES_MATCHING
# PATTERN "*.h"
# )
# install(
# DIRECTORY "HalideIR/src/." DESTINATION "include/HalideIR"
# FILES_MATCHING
# PATTERN "*.h"
# )
# install(
# DIRECTORY "dlpack/include/." DESTINATION "include"
# FILES_MATCHING
# PATTERN "*.h"
# )
# install(
# DIRECTORY "nnvm/include/." DESTINATION "include"
# FILES_MATCHING
# PATTERN "*.h"
# )
#else(INSTALL_DEV)
# install(
# DIRECTORY "include/tvm/runtime/." DESTINATION "include/tvm/runtime"
# FILES_MATCHING
# PATTERN "*.h"
# )
#endif(INSTALL_DEV)
# More target definitions
#if(MSVC)
# target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
# target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
# target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
# target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
# target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
#endif()
########################################
# Borrowed and adapted from TVM project
########################################
macro(__dgl_option variable description value)
if(NOT DEFINED ${variable})
set(${variable} ${value} CACHE STRING ${description})
endif()
endmacro()
#######################################################
# An option that the user can select. Can accept condition to control when option is available for user.
# Usage:
# dgl_option(<option_variable> "doc string" <initial value or boolean expression> [IF <condition>])
macro(dgl_option variable description value)
set(__value ${value})
set(__condition "")
set(__varname "__value")
foreach(arg ${ARGN})
if(arg STREQUAL "IF" OR arg STREQUAL "if")
set(__varname "__condition")
else()
list(APPEND ${__varname} ${arg})
endif()
endforeach()
unset(__varname)
if("${__condition}" STREQUAL "")
set(__condition 2 GREATER 1)
endif()
if(${__condition})
if("${__value}" MATCHES ";")
if(${__value})
__dgl_option(${variable} "${description}" ON)
else()
__dgl_option(${variable} "${description}" OFF)
endif()
elseif(DEFINED ${__value})
if(${__value})
__dgl_option(${variable} "${description}" ON)
else()
__dgl_option(${variable} "${description}" OFF)
endif()
else()
__dgl_option(${variable} "${description}" "${__value}")
endif()
else()
unset(${variable} CACHE)
endif()
endmacro()
function(assign_source_group group)
foreach(_source IN ITEMS ${ARGN})
if (IS_ABSOLUTE "${_source}")
file(RELATIVE_PATH _source_rel "${CMAKE_CURRENT_SOURCE_DIR}" "${_source}")
else()
set(_source_rel "${_source}")
endif()
get_filename_component(_source_path "${_source_rel}" PATH)
string(REPLACE "/" "\\" _source_path_msvc "${_source_path}")
source_group("${group}\\${_source_path_msvc}" FILES "${_source}")
endforeach()
endfunction(assign_source_group)
...@@ -2,7 +2,7 @@ git submodule init ...@@ -2,7 +2,7 @@ git submodule init
git submodule update git submodule update
md build md build
cd build cd build
cmake -DCMAKE_CXX_FLAGS="-DDMLC_LOG_STACK_TRACE=0 -DTVM_EXPORTS" -DCMAKE_MAKE_PROGRAM=mingw32-make .. -G "MSYS Makefiles" cmake -DCMAKE_CXX_FLAGS="-DDMLC_LOG_STACK_TRACE=0 -DDGL_EXPORTS" -DCMAKE_MAKE_PROGRAM=mingw32-make .. -G "MSYS Makefiles"
if errorlevel 1 exit 1 if errorlevel 1 exit 1
mingw32-make mingw32-make
if errorlevel 1 exit 1 if errorlevel 1 exit 1
......
...@@ -166,7 +166,7 @@ Then build the shared library and install the Python binding: ...@@ -166,7 +166,7 @@ Then build the shared library and install the Python binding:
md build md build
cd build cd build
cmake -DCMAKE_CXX_FLAGS="-DDMLC_LOG_STACK_TRACE=0 -DTVM_EXPORTS" -DCMAKE_MAKE_PROGRAM=mingw32-make .. -G "MSYS Makefiles" cmake -DCMAKE_CXX_FLAGS="-DDMLC_LOG_STACK_TRACE=0 -DDGL_EXPORTS" -DCMAKE_MAKE_PROGRAM=mingw32-make .. -G "MSYS Makefiles"
mingw32-make mingw32-make
cd ..\python cd ..\python
python setup.py install python setup.py install
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
namespace dgl { namespace dgl {
typedef uint64_t dgl_id_t; typedef uint64_t dgl_id_t;
typedef tvm::runtime::NDArray IdArray; typedef dgl::runtime::NDArray IdArray;
typedef tvm::runtime::NDArray DegreeArray; typedef dgl::runtime::NDArray DegreeArray;
typedef tvm::runtime::NDArray BoolArray; typedef dgl::runtime::NDArray BoolArray;
typedef tvm::runtime::NDArray IntArray; typedef dgl::runtime::NDArray IntArray;
class Graph; class Graph;
class GraphOp; class GraphOp;
...@@ -386,12 +386,12 @@ class Graph { ...@@ -386,12 +386,12 @@ class Graph {
struct Subgraph { struct Subgraph {
/*! \brief The graph. */ /*! \brief The graph. */
Graph graph; Graph graph;
/*! /*!
* \brief The induced vertex ids. * \brief The induced vertex ids.
* \note This is also a map from the new vertex id to the vertex id in the parent graph. * \note This is also a map from the new vertex id to the vertex id in the parent graph.
*/ */
IdArray induced_vertices; IdArray induced_vertices;
/*! /*!
* \brief The induced edge ids. * \brief The induced edge ids.
* \note This is also a map from the new edge id to the edge id in the parent graph. * \note This is also a map from the new edge id to the edge id in the parent graph.
*/ */
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file tvm/runtime/c_backend_api.h * \file dgl/runtime/c_backend_api.h
* \brief TVM runtime backend API. * \brief DGL runtime backend API.
* *
* The functions defined in this header are intended to be * The functions defined in this header are intended to be
* used by compiled tvm operators, usually user do not need to use these * used by compiled dgl operators, usually user do not need to use these
* function directly. * function directly.
*/ */
#ifndef DGL_RUNTIME_C_BACKEND_API_H_ #ifndef DGL_RUNTIME_C_BACKEND_API_H_
...@@ -20,16 +20,16 @@ extern "C" { ...@@ -20,16 +20,16 @@ extern "C" {
/*! /*!
* \brief Backend function for modules to get function * \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function). * from its environment mod_node (its imports and global function).
* The user do should not call TVMFuncFree on func. * The user do should not call DGLFuncFree on func.
* *
* \param mod_node The module handle. * \param mod_node The module handle.
* \param func_name The name of the function. * \param func_name The name of the function.
* \param out The result function. * \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, DGL_DLL int DGLBackendGetFuncFromEnv(void* mod_node,
const char* func_name, const char* func_name,
TVMFunctionHandle *out); DGLFunctionHandle *out);
/*! /*!
* \brief Backend function to register system-wide library symbol. * \brief Backend function to register system-wide library symbol.
* *
...@@ -37,7 +37,7 @@ TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, ...@@ -37,7 +37,7 @@ TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
* \param ptr The symbol address. * \param ptr The symbol address.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*! /*!
* \brief Backend function to allocate temporal workspace. * \brief Backend function to allocate temporal workspace.
...@@ -53,7 +53,7 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); ...@@ -53,7 +53,7 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
* certain backends such as OpenGL. * certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success * \return nullptr when error is thrown, a valid ptr if success
*/ */
TVM_DLL void* TVMBackendAllocWorkspace(int device_type, DGL_DLL void* DGLBackendAllocWorkspace(int device_type,
int device_id, int device_id,
uint64_t nbytes, uint64_t nbytes,
int dtype_code_hint, int dtype_code_hint,
...@@ -67,14 +67,14 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type, ...@@ -67,14 +67,14 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
* \param device_id The device id which the space will be allocated. * \param device_id The device id which the space will be allocated.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
* *
* \sa TVMBackendAllocWorkspace * \sa DGLBackendAllocWorkspace
*/ */
TVM_DLL int TVMBackendFreeWorkspace(int device_type, DGL_DLL int DGLBackendFreeWorkspace(int device_type,
int device_id, int device_id,
void* ptr); void* ptr);
/*! /*!
* \brief Environment for TVM parallel task. * \brief Environment for DGL parallel task.
*/ */
typedef struct { typedef struct {
/*! /*!
...@@ -83,7 +83,7 @@ typedef struct { ...@@ -83,7 +83,7 @@ typedef struct {
void* sync_handle; void* sync_handle;
/*! \brief total amount of task */ /*! \brief total amount of task */
int32_t num_task; int32_t num_task;
} TVMParallelGroupEnv; } DGLParallelGroupEnv;
/*! /*!
* \brief The callback function to execute a parallel lambda * \brief The callback function to execute a parallel lambda
...@@ -91,8 +91,8 @@ typedef struct { ...@@ -91,8 +91,8 @@ typedef struct {
* \param penv The parallel environment backs the execution. * \param penv The parallel environment backs the execution.
* \param cdata The supporting closure data. * \param cdata The supporting closure data.
*/ */
typedef int (*FTVMParallelLambda)( typedef int (*FDGLParallelLambda)(
int task_id, TVMParallelGroupEnv* penv, void* cdata); int task_id, DGLParallelGroupEnv* penv, void* cdata);
/*! /*!
* \brief Backend function for running parallel jobs. * \brief Backend function for running parallel jobs.
...@@ -104,7 +104,7 @@ typedef int (*FTVMParallelLambda)( ...@@ -104,7 +104,7 @@ typedef int (*FTVMParallelLambda)(
* *
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda,
void* cdata, void* cdata,
int num_task); int num_task);
...@@ -114,7 +114,7 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, ...@@ -114,7 +114,7 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
* \param penv The parallel environment backs the execution. * \param penv The parallel environment backs the execution.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
/*! /*!
...@@ -128,12 +128,12 @@ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); ...@@ -128,12 +128,12 @@ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
* \param nbytes Number of bytes in the closure data. * \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
TVM_DLL int TVMBackendRunOnce(void** handle, DGL_DLL int DGLBackendRunOnce(void** handle,
int (*f)(void*), int (*f)(void*),
void *cdata, void *cdata,
int nbytes); int nbytes);
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // DGL_EXTERN_C
#endif #endif
#endif // DGL_RUNTIME_C_BACKEND_API_H_ #endif // DGL_RUNTIME_C_BACKEND_API_H_
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file dgl/runtime/c_runtime_api.h * \file dgl/runtime/c_runtime_api.h
* \brief TVM runtime library. * \brief DGL runtime library.
* *
* The philosophy of TVM project is to customize the compilation * This runtime is adapted from TVM project
* stage to generate code that can used by other projects transparently.
* So this is a minimum runtime code gluing, and some limited
* memory management code to enable quick testing.
*
* The runtime API is independent from TVM compilation stack and can
* be linked via libtvm_runtime.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/ */
#ifndef DGL_RUNTIME_C_RUNTIME_API_H_ #ifndef DGL_RUNTIME_C_RUNTIME_API_H_
#define DGL_RUNTIME_C_RUNTIME_API_H_ #define DGL_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking // Macros to do weak linking
#ifdef _MSC_VER #ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany) #define DGL_WEAK __declspec(selectany)
#else #else
#define TVM_WEAK __attribute__((weak)) #define DGL_WEAK __attribute__((weak))
#endif #endif
#ifdef __EMSCRIPTEN__ #ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h> #include <emscripten/emscripten.h>
#define TVM_DLL EMSCRIPTEN_KEEPALIVE #define DGL_DLL EMSCRIPTEN_KEEPALIVE
#endif #endif
#ifndef TVM_DLL #ifndef DGL_DLL
#ifdef _WIN32 #ifdef _WIN32
#ifdef TVM_EXPORTS #ifdef DGL_EXPORTS
#define TVM_DLL __declspec(dllexport) #define DGL_DLL __declspec(dllexport)
#else #else
#define TVM_DLL __declspec(dllimport) #define DGL_DLL __declspec(dllimport)
#endif #endif
#else #else
#define TVM_DLL #define DGL_DLL
#endif #endif
#endif #endif
// TVM version // DGL version
#define TVM_VERSION "0.5.dev" #define DGL_VERSION "0.5.dev"
// TVM Runtime is DLPack compatible. // DGL Runtime is DLPack compatible.
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
#ifdef __cplusplus #ifdef __cplusplus
...@@ -56,9 +46,9 @@ extern "C" { ...@@ -56,9 +46,9 @@ extern "C" {
#include <stddef.h> #include <stddef.h>
/*! \brief type of array index. */ /*! \brief type of array index. */
typedef int64_t tvm_index_t; typedef int64_t dgl_index_t;
/*! \brief Extension device types in TVM */ /*! \brief Extension device types in DGL */
typedef enum { typedef enum {
kDLAOCL = 5, kDLAOCL = 5,
kDLSDAccel = 6, kDLSDAccel = 6,
...@@ -66,21 +56,21 @@ typedef enum { ...@@ -66,21 +56,21 @@ typedef enum {
// Extension DRAM type, used for quickly test extension device // Extension DRAM type, used for quickly test extension device
// The device api can differ depending on the xpu driver registered. // The device api can differ depending on the xpu driver registered.
kExtDev = 12, kExtDev = 12,
// AddExtraTVMType which is not in DLPack here // AddExtraDGLType which is not in DLPack here
} TVMDeviceExtType; } DGLDeviceExtType;
/*! /*!
* \brief The type code in TVMType * \brief The type code in DGLType
* \note TVMType is used in two places. * \note DGLType is used in two places.
*/ */
typedef enum { typedef enum {
// The type code of other types are compatible with DLPack. // The type code of other types are compatible with DLPack.
// The next few fields are extension types // The next few fields are extension types
// that is used by TVM API calls. // that is used by DGL API calls.
kHandle = 3U, kHandle = 3U,
kNull = 4U, kNull = 4U,
kTVMType = 5U, kDGLType = 5U,
kTVMContext = 6U, kDGLContext = 6U,
kArrayHandle = 7U, kArrayHandle = 7U,
kNodeHandle = 8U, kNodeHandle = 8U,
kModuleHandle = 9U, kModuleHandle = 9U,
...@@ -88,7 +78,7 @@ typedef enum { ...@@ -88,7 +78,7 @@ typedef enum {
kStr = 11U, kStr = 11U,
kBytes = 12U, kBytes = 12U,
kNDArrayContainer = 13U, kNDArrayContainer = 13U,
// Extension codes for other frameworks to integrate TVM PackedFunc. // Extension codes for other frameworks to integrate DGL PackedFunc.
// To make sure each framework's id do not conflict, use first and // To make sure each framework's id do not conflict, use first and
// last sections to mark ranges. // last sections to mark ranges.
// Open an issue at the repo if you need a section of code. // Open an issue at the repo if you need a section of code.
...@@ -98,32 +88,32 @@ typedef enum { ...@@ -98,32 +88,32 @@ typedef enum {
// The following section of code is used for non-reserved types. // The following section of code is used for non-reserved types.
kExtReserveEnd = 64U, kExtReserveEnd = 64U,
kExtEnd = 128U kExtEnd = 128U
} TVMTypeCode; } DGLTypeCode;
/*! /*!
* \brief The data type used in TVM Runtime. * \brief The data type used in DGL Runtime.
* *
* Examples * Examples
* - float: type_code = 2, bits = 32, lanes=1 * - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1 * - int8: type_code = 0, bits = 8, lanes=1
* *
* \note Arguments TVM API function always takes bits=64 and lanes=1 * \note Arguments DGL API function always takes bits=64 and lanes=1
*/ */
typedef DLDataType TVMType; typedef DLDataType DGLType;
/*! /*!
* \brief The Device information, abstract away common device types. * \brief The Device information, abstract away common device types.
*/ */
typedef DLContext TVMContext; typedef DLContext DGLContext;
/*! /*!
* \brief The tensor array stucture to TVM API. * \brief The tensor array stucture to DGL API.
*/ */
typedef DLTensor TVMArray; typedef DLTensor DGLArray;
/*! \brief the array handle */ /*! \brief the array handle */
typedef TVMArray* TVMArrayHandle; typedef DGLArray* DGLArrayHandle;
/*! /*!
* \brief Union type of values * \brief Union type of values
...@@ -134,9 +124,9 @@ typedef union { ...@@ -134,9 +124,9 @@ typedef union {
double v_float64; double v_float64;
void* v_handle; void* v_handle;
const char* v_str; const char* v_str;
TVMType v_type; DGLType v_type;
TVMContext v_ctx; DGLContext v_ctx;
} TVMValue; } DGLValue;
/*! /*!
* \brief Byte array type used to pass in byte array * \brief Byte array type used to pass in byte array
...@@ -145,37 +135,37 @@ typedef union { ...@@ -145,37 +135,37 @@ typedef union {
typedef struct { typedef struct {
const char* data; const char* data;
size_t size; size_t size;
} TVMByteArray; } DGLByteArray;
/*! \brief Handle to TVM runtime modules. */ /*! \brief Handle to DGL runtime modules. */
typedef void* TVMModuleHandle; typedef void* DGLModuleHandle;
/*! \brief Handle to packed function handle. */ /*! \brief Handle to packed function handle. */
typedef void* TVMFunctionHandle; typedef void* DGLFunctionHandle;
/*! \brief Handle to hold return value. */ /*! \brief Handle to hold return value. */
typedef void* TVMRetValueHandle; typedef void* DGLRetValueHandle;
/*! /*!
* \brief The stream that is specific to device * \brief The stream that is specific to device
* can be NULL, which indicates the default one. * can be NULL, which indicates the default one.
*/ */
typedef void* TVMStreamHandle; typedef void* DGLStreamHandle;
/*! /*!
* \brief Used for implementing C API function. * \brief Used for implementing C API function.
* Set last error message before return. * Set last error message before return.
* \param msg The error message to be set. * \param msg The error message to be set.
*/ */
TVM_DLL void TVMAPISetLastError(const char* msg); DGL_DLL void DGLAPISetLastError(const char* msg);
/*! /*!
* \brief return str message of the last error * \brief return str message of the last error
* all function in this file will return 0 when success * all function in this file will return 0 when success
* and -1 when an error occured, * and -1 when an error occured,
* TVMGetLastError can be called to retrieve the error * DGLGetLastError can be called to retrieve the error
* *
* this function is threadsafe and can be called by different thread * this function is threadsafe and can be called by different thread
* \return error info * \return error info
*/ */
TVM_DLL const char *TVMGetLastError(void); DGL_DLL const char *DGLGetLastError(void);
/*! /*!
* \brief Load module from file. * \brief Load module from file.
* \param file_name The file name to load the module from. * \param file_name The file name to load the module from.
...@@ -184,11 +174,11 @@ TVM_DLL const char *TVMGetLastError(void); ...@@ -184,11 +174,11 @@ TVM_DLL const char *TVMGetLastError(void);
* *
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
* \note The resulting module do not contain import relation. * \note The resulting module do not contain import relation.
* It can be reconstructed by TVMModImport. * It can be reconstructed by DGLModImport.
*/ */
TVM_DLL int TVMModLoadFromFile(const char* file_name, DGL_DLL int DGLModLoadFromFile(const char* file_name,
const char* format, const char* format,
TVMModuleHandle* out); DGLModuleHandle* out);
/*! /*!
* \brief Add dep to mod's dependency. * \brief Add dep to mod's dependency.
...@@ -198,8 +188,8 @@ TVM_DLL int TVMModLoadFromFile(const char* file_name, ...@@ -198,8 +188,8 @@ TVM_DLL int TVMModLoadFromFile(const char* file_name,
* \param dep The dependent module to be imported. * \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMModImport(TVMModuleHandle mod, DGL_DLL int DGLModImport(DGLModuleHandle mod,
TVMModuleHandle dep); DGLModuleHandle dep);
/*! /*!
* \brief Get function from the module. * \brief Get function from the module.
...@@ -209,10 +199,10 @@ TVM_DLL int TVMModImport(TVMModuleHandle mod, ...@@ -209,10 +199,10 @@ TVM_DLL int TVMModImport(TVMModuleHandle mod,
* \param out The result function, can be NULL if it is not available. * \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, DGL_DLL int DGLModGetFunction(DGLModuleHandle mod,
const char* func_name, const char* func_name,
int query_imports, int query_imports,
TVMFunctionHandle *out); DGLFunctionHandle *out);
/*! /*!
* \brief Free front-end extension type resource. * \brief Free front-end extension type resource.
...@@ -220,30 +210,30 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, ...@@ -220,30 +210,30 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
* \param type_code The type of of the extension type. * \param type_code The type of of the extension type.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMExtTypeFree(void* handle, int type_code); DGL_DLL int DGLExtTypeFree(void* handle, int type_code);
/*! /*!
* \brief Free the Module * \brief Free the Module
* \param mod The module to be freed. * \param mod The module to be freed.
* *
* \note This may not free up the module's resources. * \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module * If there is active DGLFunctionHandle uses the module
* Or if this module is imported by another active module. * Or if this module is imported by another active module.
* *
* The all functions remains valid until TVMFuncFree is called. * The all functions remains valid until DGLFuncFree is called.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMModFree(TVMModuleHandle mod); DGL_DLL int DGLModFree(DGLModuleHandle mod);
/*! /*!
* \brief Free the function when it is no longer needed. * \brief Free the function when it is no longer needed.
* \param func The function handle * \param func The function handle
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMFuncFree(TVMFunctionHandle func); DGL_DLL int DGLFuncFree(DGLFunctionHandle func);
/*! /*!
* \brief Call a Packed TVM Function. * \brief Call a Packed DGL Function.
* *
* \param func node handle of the function. * \param func node handle of the function.
* \param arg_values The arguments * \param arg_values The arguments
...@@ -254,34 +244,34 @@ TVM_DLL int TVMFuncFree(TVMFunctionHandle func); ...@@ -254,34 +244,34 @@ TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
* \param ret_type_code the type code of return value. * \param ret_type_code the type code of return value.
* *
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
* \note TVM calls always exchanges with type bits=64, lanes=1 * \note DGL calls always exchanges with type bits=64, lanes=1
* *
* \note API calls always exchanges with type bits=64, lanes=1 * \note API calls always exchanges with type bits=64, lanes=1
* If API call returns container handles (e.g. FunctionHandle) * If API call returns container handles (e.g. FunctionHandle)
* these handles should be managed by the front-end. * these handles should be managed by the front-end.
* The front-end need to call free function (e.g. TVMFuncFree) * The front-end need to call free function (e.g. DGLFuncFree)
* to free these handles. * to free these handles.
*/ */
TVM_DLL int TVMFuncCall(TVMFunctionHandle func, DGL_DLL int DGLFuncCall(DGLFunctionHandle func,
TVMValue* arg_values, DGLValue* arg_values,
int* type_codes, int* type_codes,
int num_args, int num_args,
TVMValue* ret_val, DGLValue* ret_val,
int* ret_type_code); int* ret_type_code);
/*! /*!
* \brief Set the return value of TVMPackedCFunc. * \brief Set the return value of DGLPackedCFunc.
* *
* This function is called by TVMPackedCFunc to set the return value. * This function is called by DGLPackedCFunc to set the return value.
* When this function is not called, the function returns null by default. * When this function is not called, the function returns null by default.
* *
* \param ret The return value handle, pass by ret in TVMPackedCFunc * \param ret The return value handle, pass by ret in DGLPackedCFunc
* \param value The value to be returned. * \param value The value to be returned.
* \param type_code The type of the value to be returned. * \param type_code The type of the value to be returned.
* \param num_ret Number of return values, for now only 1 is supported. * \param num_ret Number of return values, for now only 1 is supported.
*/ */
TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, DGL_DLL int DGLCFuncSetReturn(DGLRetValueHandle ret,
TVMValue* value, DGLValue* value,
int* type_code, int* type_code,
int num_ret); int num_ret);
...@@ -295,7 +285,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, ...@@ -295,7 +285,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
* *
* \return 0 when success, -1 when failure happens. * \return 0 when success, -1 when failure happens.
*/ */
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code); DGL_DLL int DGLCbArgToReturn(DGLValue* value, int code);
/*! /*!
* \brief C type of packed function. * \brief C type of packed function.
...@@ -305,37 +295,37 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code); ...@@ -305,37 +295,37 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
* \param num_args Number of arguments. * \param num_args Number of arguments.
* \param ret The return value handle. * \param ret The return value handle.
* \param resource_handle The handle additional resouce handle from fron-end. * \param resource_handle The handle additional resouce handle from fron-end.
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. * \return 0 if success, -1 if failure happens, set error via DGLAPISetLastError.
* \sa TVMCFuncSetReturn * \sa DGLCFuncSetReturn
*/ */
typedef int (*TVMPackedCFunc)( typedef int (*DGLPackedCFunc)(
TVMValue* args, DGLValue* args,
int* type_codes, int* type_codes,
int num_args, int num_args,
TVMRetValueHandle ret, DGLRetValueHandle ret,
void* resource_handle); void* resource_handle);
/*! /*!
* \brief C callback to free the resource handle in C packed function. * \brief C callback to free the resource handle in C packed function.
* \param resource_handle The handle additional resouce handle from fron-end. * \param resource_handle The handle additional resouce handle from fron-end.
*/ */
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle); typedef void (*DGLPackedCFuncFinalizer)(void* resource_handle);
/*! /*!
* \brief Signature for extension function declarer. * \brief Signature for extension function declarer.
* *
* TVM call this function to get the extension functions * DGL call this function to get the extension functions
* The declarer will call register_func to register function and their name. * The declarer will call register_func to register function and their name.
* *
* \param register_func_handle The register function * \param register_func_handle The register function
* \return 0 if success, -1 if failure happens * \return 0 if success, -1 if failure happens
*/ */
typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); typedef int (*DGLExtensionFuncDeclarer)(DGLFunctionHandle register_func_handle);
/*! /*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle. * \brief Wrap a DGLPackedCFunc to become a FunctionHandle.
* *
* The resource_handle will be managed by TVM API, until the function is no longer used. * The resource_handle will be managed by DGL API, until the function is no longer used.
* *
* \param func The packed C function. * \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL. * \param resource_handle The resource handle from front-end, can be NULL.
...@@ -343,10 +333,10 @@ typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); ...@@ -343,10 +333,10 @@ typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle);
* \param out the result function handle. * \param out the result function handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, DGL_DLL int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
void* resource_handle, void* resource_handle,
TVMPackedCFuncFinalizer fin, DGLPackedCFuncFinalizer fin,
TVMFunctionHandle *out); DGLFunctionHandle *out);
/*! /*!
* \brief Register the function to runtime's global table. * \brief Register the function to runtime's global table.
...@@ -357,8 +347,8 @@ TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -357,8 +347,8 @@ TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
* \param f The function to be registered. * \param f The function to be registered.
* \param override Whether allow override already registered function. * \param override Whether allow override already registered function.
*/ */
TVM_DLL int TVMFuncRegisterGlobal( DGL_DLL int DGLFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override); const char* name, DGLFunctionHandle f, int override);
/*! /*!
* \brief Get a global function. * \brief Get a global function.
...@@ -366,10 +356,10 @@ TVM_DLL int TVMFuncRegisterGlobal( ...@@ -366,10 +356,10 @@ TVM_DLL int TVMFuncRegisterGlobal(
* \param name The name of the function. * \param name The name of the function.
* \param out the result function pointer, NULL if it does not exist. * \param out the result function pointer, NULL if it does not exist.
* *
* \note The function handle of global function is managed by TVM runtime, * \note The function handle of global function is managed by DGL runtime,
* So TVMFuncFree is should not be called when it get deleted. * So DGLFuncFree is should not be called when it get deleted.
*/ */
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); DGL_DLL int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out);
/*! /*!
* \brief List all the globally registered function name * \brief List all the globally registered function name
...@@ -377,7 +367,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); ...@@ -377,7 +367,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
* \param out_array The array of function names. * \param out_array The array of function names.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMFuncListGlobalNames(int* out_size, DGL_DLL int DGLFuncListGlobalNames(int* out_size,
const char*** out_array); const char*** out_array);
// Array related apis for quick proptyping // Array related apis for quick proptyping
...@@ -395,21 +385,21 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, ...@@ -395,21 +385,21 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size,
* \param out The output handle. * \param out The output handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape,
int ndim, int ndim,
int dtype_code, int dtype_code,
int dtype_bits, int dtype_bits,
int dtype_lanes, int dtype_lanes,
int device_type, int device_type,
int device_id, int device_id,
TVMArrayHandle* out); DGLArrayHandle* out);
/*! /*!
* \brief Free the TVM Array. * \brief Free the DGL Array.
* \param handle The array handle to be freed. * \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayFree(TVMArrayHandle handle); DGL_DLL int DGLArrayFree(DGLArrayHandle handle);
/*! /*!
* \brief Copy array data from CPU byte array. * \brief Copy array data from CPU byte array.
...@@ -418,7 +408,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); ...@@ -418,7 +408,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
* \param nbytes The number of bytes to copy. * \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle,
void* data, void* data,
size_t nbytes); size_t nbytes);
...@@ -429,7 +419,7 @@ TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, ...@@ -429,7 +419,7 @@ TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle,
* \param nbytes The number of bytes to copy. * \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
void* data, void* data,
size_t nbytes); size_t nbytes);
...@@ -440,9 +430,9 @@ TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, ...@@ -440,9 +430,9 @@ TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle,
* \param stream The stream where the copy happens, can be NULL. * \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
TVMArrayHandle to, DGLArrayHandle to,
TVMStreamHandle stream); DGLStreamHandle stream);
/*! /*!
* \brief Produce an array from the DLManagedTensor that shares data memory * \brief Produce an array from the DLManagedTensor that shares data memory
...@@ -451,8 +441,8 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, ...@@ -451,8 +441,8 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
* \param out The output array handle. * \param out The output array handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out); DGLArrayHandle* out);
/*! /*!
* \brief Produce a DLMangedTensor from the array that shares data memory with * \brief Produce a DLMangedTensor from the array that shares data memory with
...@@ -461,14 +451,14 @@ TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, ...@@ -461,14 +451,14 @@ TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from,
* \param out The DLManagedTensor handle. * \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from,
DLManagedTensor** out); DLManagedTensor** out);
/*! /*!
* \brief Delete (free) a DLManagedTensor's data. * \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor. * \param dltensor Pointer to the DLManagedTensor.
*/ */
TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor); DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*! /*!
* \brief Create a new runtime stream. * \brief Create a new runtime stream.
...@@ -478,7 +468,7 @@ TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor); ...@@ -478,7 +468,7 @@ TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
* \param out The new stream handle * \param out The new stream handle
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out); DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out);
/*! /*!
* \brief Free a created stream handle. * \brief Free a created stream handle.
...@@ -488,7 +478,7 @@ TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out ...@@ -488,7 +478,7 @@ TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out
* \param stream The stream to be freed * \param stream The stream to be freed
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream); DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream);
/*! /*!
* \brief Set the runtime stream of current thread to be stream. * \brief Set the runtime stream of current thread to be stream.
...@@ -501,7 +491,7 @@ TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream ...@@ -501,7 +491,7 @@ TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream
* \param handle The stream handle. * \param handle The stream handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle); DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle);
/*! /*!
* \brief Wait until all computations on stream completes. * \brief Wait until all computations on stream completes.
...@@ -511,7 +501,7 @@ TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle) ...@@ -511,7 +501,7 @@ TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle)
* \param stream The stream to be synchronized. * \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream);
/*! /*!
* \brief Synchronize two streams of execution. * \brief Synchronize two streams of execution.
...@@ -522,12 +512,12 @@ TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle strea ...@@ -522,12 +512,12 @@ TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle strea
* \param dst The destination stream to synchronize. * \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
TVM_DLL int TVMStreamStreamSynchronize(int device_type, DGL_DLL int DGLStreamStreamSynchronize(int device_type,
int device_id, int device_id,
TVMStreamHandle src, DGLStreamHandle src,
TVMStreamHandle dst); DGLStreamHandle dst);
#ifdef __cplusplus #ifdef __cplusplus
} // TVM_EXTERN_C } // DGL_EXTERN_C
#endif #endif
#endif // DGL_RUNTIME_C_RUNTIME_API_H_ #endif // DGL_RUNTIME_C_RUNTIME_API_H_
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "packed_func.h" #include "packed_func.h"
#include "c_runtime_api.h" #include "c_runtime_api.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
/*! /*!
* \brief the query type into GetAttr * \brief the query type into GetAttr
...@@ -37,7 +37,7 @@ constexpr int kTempAllocaAlignment = 64; ...@@ -37,7 +37,7 @@ constexpr int kTempAllocaAlignment = 64;
constexpr int kMaxStackAlloca = 1024; constexpr int kMaxStackAlloca = 1024;
/*! /*!
* \brief TVM Runtime Device API, abstracts the device * \brief DGL Runtime Device API, abstracts the device
* specific interface for memory management. * specific interface for memory management.
*/ */
class DeviceAPI { class DeviceAPI {
...@@ -48,7 +48,7 @@ class DeviceAPI { ...@@ -48,7 +48,7 @@ class DeviceAPI {
* \brief Set the environment device id to ctx * \brief Set the environment device id to ctx
* \param ctx The context to be set. * \param ctx The context to be set.
*/ */
virtual void SetDevice(TVMContext ctx) = 0; virtual void SetDevice(DGLContext ctx) = 0;
/*! /*!
* \brief Get attribute of specified device. * \brief Get attribute of specified device.
* \param ctx The device context * \param ctx The device context
...@@ -56,7 +56,7 @@ class DeviceAPI { ...@@ -56,7 +56,7 @@ class DeviceAPI {
* \param rv The return value. * \param rv The return value.
* \sa DeviceAttrKind * \sa DeviceAttrKind
*/ */
virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0; virtual void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;
/*! /*!
* \brief Allocate a data space on device. * \brief Allocate a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
...@@ -66,16 +66,16 @@ class DeviceAPI { ...@@ -66,16 +66,16 @@ class DeviceAPI {
* as OpenGL, as nbytes & alignment are sufficient for most backends. * as OpenGL, as nbytes & alignment are sufficient for most backends.
* \return The allocated device pointer. * \return The allocated device pointer.
*/ */
virtual void* AllocDataSpace(TVMContext ctx, virtual void* AllocDataSpace(DGLContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) = 0; DGLType type_hint) = 0;
/*! /*!
* \brief Free a data space on device. * \brief Free a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
* \param ptr The data space. * \param ptr The data space.
*/ */
virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0; virtual void FreeDataSpace(DGLContext ctx, void* ptr) = 0;
/*! /*!
* \brief copy data from one place to another * \brief copy data from one place to another
* \param from The source array. * \param from The source array.
...@@ -94,16 +94,16 @@ class DeviceAPI { ...@@ -94,16 +94,16 @@ class DeviceAPI {
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t num_bytes, size_t num_bytes,
TVMContext ctx_from, DGLContext ctx_from,
TVMContext ctx_to, DGLContext ctx_to,
TVMType type_hint, DGLType type_hint,
TVMStreamHandle stream) = 0; DGLStreamHandle stream) = 0;
/*! /*!
* \brief Create a new stream of execution. * \brief Create a new stream of execution.
* *
* \param ctx The context of allocation. * \param ctx The context of allocation.
*/ */
TVM_DLL virtual TVMStreamHandle CreateStream(TVMContext ctx); DGL_DLL virtual DGLStreamHandle CreateStream(DGLContext ctx);
/*! /*!
* \brief Free a stream of execution * \brief Free a stream of execution
...@@ -111,20 +111,20 @@ class DeviceAPI { ...@@ -111,20 +111,20 @@ class DeviceAPI {
* \param ctx The context of the stream * \param ctx The context of the stream
* \param stream The pointer to be freed. * \param stream The pointer to be freed.
*/ */
TVM_DLL virtual void FreeStream(TVMContext ctx, TVMStreamHandle stream); DGL_DLL virtual void FreeStream(DGLContext ctx, DGLStreamHandle stream);
/*! /*!
* \brief Synchronize the stream * \brief Synchronize the stream
* \param ctx The context to perform operation. * \param ctx The context to perform operation.
* \param stream The stream to be sync. * \param stream The stream to be sync.
*/ */
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0; virtual void StreamSync(DGLContext ctx, DGLStreamHandle stream) = 0;
/*! /*!
* \brief Set the stream * \brief Set the stream
* \param ctx The context to set stream. * \param ctx The context to set stream.
* \param stream The stream to be set. * \param stream The stream to be set.
*/ */
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {} virtual void SetStream(DGLContext ctx, DGLStreamHandle stream) {}
/*! /*!
* \brief Synchronize 2 streams of execution. * \brief Synchronize 2 streams of execution.
* *
...@@ -137,9 +137,9 @@ class DeviceAPI { ...@@ -137,9 +137,9 @@ class DeviceAPI {
* \param event_src The source stream to synchronize. * \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize. * \param event_dst The destination stream to synchronize.
*/ */
TVM_DLL virtual void SyncStreamFromTo(TVMContext ctx, DGL_DLL virtual void SyncStreamFromTo(DGLContext ctx,
TVMStreamHandle event_src, DGLStreamHandle event_src,
TVMStreamHandle event_dst); DGLStreamHandle event_dst);
/*! /*!
* \brief Allocate temporal workspace for backend execution. * \brief Allocate temporal workspace for backend execution.
* *
...@@ -156,16 +156,16 @@ class DeviceAPI { ...@@ -156,16 +156,16 @@ class DeviceAPI {
* \param type_hint The type of elements. Only needed by certain backends such * \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends. * as OpenGL, as nbytes is sufficient for most backends.
*/ */
TVM_DLL virtual void* AllocWorkspace(TVMContext ctx, DGL_DLL virtual void* AllocWorkspace(DGLContext ctx,
size_t nbytes, size_t nbytes,
TVMType type_hint = {}); DGLType type_hint = {});
/*! /*!
* \brief Free temporal workspace in backend execution. * \brief Free temporal workspace in backend execution.
* *
* \param ctx The context of allocation. * \param ctx The context of allocation.
* \param ptr The pointer to be freed. * \param ptr The pointer to be freed.
*/ */
TVM_DLL virtual void FreeWorkspace(TVMContext ctx, void* ptr); DGL_DLL virtual void FreeWorkspace(DGLContext ctx, void* ptr);
/*! /*!
* \brief Get device API base don context. * \brief Get device API base don context.
...@@ -173,11 +173,11 @@ class DeviceAPI { ...@@ -173,11 +173,11 @@ class DeviceAPI {
* \param allow_missing Whether allow missing * \param allow_missing Whether allow missing
* \return The corresponding device API. * \return The corresponding device API.
*/ */
TVM_DLL static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false); DGL_DLL static DeviceAPI* Get(DGLContext ctx, bool allow_missing = false);
}; };
/*! \brief The device type bigger than this is RPC device */ /*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128; constexpr int kRPCSessMask = 128;
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
#endif // DGL_RUNTIME_DEVICE_API_H_ #endif // DGL_RUNTIME_DEVICE_API_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dgl/runtime/module.h * \file dgl/runtime/module.h
* \brief Runtime container of the functions generated by TVM, * \brief Runtime container of the functions generated by DGL,
* This is used to support dynamically link, load and save * This is used to support dynamically link, load and save
* functions from different convention under unified API. * functions from different convention under unified API.
*/ */
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <unordered_map> #include <unordered_map>
#include "c_runtime_api.h" #include "c_runtime_api.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
// The internal container of module. // The internal container of module.
...@@ -23,7 +23,7 @@ class ModuleNode; ...@@ -23,7 +23,7 @@ class ModuleNode;
class PackedFunc; class PackedFunc;
/*! /*!
* \brief Module container of TVM. * \brief Module container of DGL.
*/ */
class Module { class Module {
public: public:
...@@ -53,7 +53,7 @@ class Module { ...@@ -53,7 +53,7 @@ class Module {
* \note Cyclic dependency is not allowed among modules, * \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected. * An error will be thrown when cyclic dependency is detected.
*/ */
TVM_DLL void Import(Module other); DGL_DLL void Import(Module other);
/*! /*!
* \brief Load a module from file. * \brief Load a module from file.
* \param file_name The name of the host function module. * \param file_name The name of the host function module.
...@@ -61,7 +61,7 @@ class Module { ...@@ -61,7 +61,7 @@ class Module {
* \note This function won't load the import relationship. * \note This function won't load the import relationship.
* Re-create import relationship by calling Import. * Re-create import relationship by calling Import.
*/ */
TVM_DLL static Module LoadFromFile(const std::string& file_name, DGL_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = ""); const std::string& format = "");
private: private:
...@@ -112,13 +112,13 @@ class ModuleNode { ...@@ -112,13 +112,13 @@ class ModuleNode {
* but not necessarily host modules. * but not necessarily host modules.
* We can use this to do AOT loading of bundled device functions. * We can use this to do AOT loading of bundled device functions.
*/ */
TVM_DLL virtual void SaveToBinary(dmlc::Stream* stream); DGL_DLL virtual void SaveToBinary(dmlc::Stream* stream);
/*! /*!
* \brief Get the source code of module, when available. * \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default. * \param format Format of the source code, can be empty by default.
* \return Possible source code when available. * \return Possible source code when available.
*/ */
TVM_DLL virtual std::string GetSource(const std::string& format = ""); DGL_DLL virtual std::string GetSource(const std::string& format = "");
/*! /*!
* \brief Get a function from current environment * \brief Get a function from current environment
* The environment includes all the imports as well as Global functions. * The environment includes all the imports as well as Global functions.
...@@ -126,7 +126,7 @@ class ModuleNode { ...@@ -126,7 +126,7 @@ class ModuleNode {
* \param name name of the function. * \param name name of the function.
* \return The corresponding function. * \return The corresponding function.
*/ */
TVM_DLL const PackedFunc* GetFuncFromEnv(const std::string& name); DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The module it imports from */ /*! \return The module it imports from */
const std::vector<Module>& imports() const { const std::vector<Module>& imports() const {
return imports_; return imports_;
...@@ -146,19 +146,19 @@ class ModuleNode { ...@@ -146,19 +146,19 @@ class ModuleNode {
/*! \brief namespace for constant symbols */ /*! \brief namespace for constant symbols */
namespace symbol { namespace symbol {
/*! \brief Global variable to store module context. */ /*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx"; constexpr const char* dgl_module_ctx = "__dgl_module_ctx";
/*! \brief Global variable to store device module blob */ /*! \brief Global variable to store device module blob */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob"; constexpr const char* dgl_dev_mblob = "__dgl_dev_mblob";
/*! \brief Number of bytes of device module blob. */ /*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes"; constexpr const char* dgl_dev_mblob_nbytes = "__dgl_dev_mblob_nbytes";
/*! \brief global function to set device */ /*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device"; constexpr const char* dgl_set_device = "__dgl_set_device";
/*! \brief Auxiliary counter to global barrier. */ /*! \brief Auxiliary counter to global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state"; constexpr const char* dgl_global_barrier_state = "__dgl_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that uses global barrier. */ /*! \brief Prepare the global barrier before kernels that uses global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier"; constexpr const char* dgl_prepare_global_barrier = "__dgl_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */ /*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__"; constexpr const char* dgl_module_main = "__dgl_main__";
} // namespace symbol } // namespace symbol
// implementations of inline functions. // implementations of inline functions.
...@@ -171,7 +171,7 @@ inline const ModuleNode* Module::operator->() const { ...@@ -171,7 +171,7 @@ inline const ModuleNode* Module::operator->() const {
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
#include "packed_func.h" #include "packed_func.h"
#endif // DGL_RUNTIME_MODULE_H_ #endif // DGL_RUNTIME_MODULE_H_
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "serializer.h" #include "serializer.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
/*! /*!
* \brief Managed NDArray. * \brief Managed NDArray.
...@@ -94,7 +94,7 @@ class NDArray { ...@@ -94,7 +94,7 @@ class NDArray {
* \brief Copy data content from another array. * \brief Copy data content from another array.
* \param other The source array to be copied from. * \param other The source array to be copied from.
* \note The copy may happen asynchrously if it involves a GPU context. * \note The copy may happen asynchrously if it involves a GPU context.
* TVMSynchronize is necessary. * DGLSynchronize is necessary.
*/ */
inline void CopyFrom(DLTensor* other); inline void CopyFrom(DLTensor* other);
inline void CopyFrom(const NDArray& other); inline void CopyFrom(const NDArray& other);
...@@ -102,7 +102,7 @@ class NDArray { ...@@ -102,7 +102,7 @@ class NDArray {
* \brief Copy data content into another array. * \brief Copy data content into another array.
* \param other The source array to be copied from. * \param other The source array to be copied from.
* \note The copy may happen asynchrously if it involves a GPU context. * \note The copy may happen asynchrously if it involves a GPU context.
* TVMSynchronize is necessary. * DGLSynchronize is necessary.
*/ */
inline void CopyTo(DLTensor* other) const; inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const; inline void CopyTo(const NDArray& other) const;
...@@ -129,14 +129,14 @@ class NDArray { ...@@ -129,14 +129,14 @@ class NDArray {
* \param dtype The data type of the new array. * \param dtype The data type of the new array.
* \note The memory size of new array must be smaller than the current one. * \note The memory size of new array must be smaller than the current one.
*/ */
TVM_DLL NDArray CreateView( DGL_DLL NDArray CreateView(
std::vector<int64_t> shape, DLDataType dtype); std::vector<int64_t> shape, DLDataType dtype);
/*! /*!
* \brief Create a reference view of NDArray that * \brief Create a reference view of NDArray that
* represents as DLManagedTensor. * represents as DLManagedTensor.
* \return A DLManagedTensor * \return A DLManagedTensor
*/ */
TVM_DLL DLManagedTensor* ToDLPack() const; DGL_DLL DLManagedTensor* ToDLPack() const;
/*! /*!
* \brief Create an empty NDArray. * \brief Create an empty NDArray.
* \param shape The shape of the new array. * \param shape The shape of the new array.
...@@ -144,7 +144,7 @@ class NDArray { ...@@ -144,7 +144,7 @@ class NDArray {
* \param ctx The context of the Array. * \param ctx The context of the Array.
* \return The created Array * \return The created Array
*/ */
TVM_DLL static NDArray Empty(std::vector<int64_t> shape, DGL_DLL static NDArray Empty(std::vector<int64_t> shape,
DLDataType dtype, DLDataType dtype,
DLContext ctx); DLContext ctx);
/*! /*!
...@@ -158,15 +158,15 @@ class NDArray { ...@@ -158,15 +158,15 @@ class NDArray {
* \param tensor The DLPack tensor to copy from. * \param tensor The DLPack tensor to copy from.
* \return The created NDArray view. * \return The created NDArray view.
*/ */
TVM_DLL static NDArray FromDLPack(DLManagedTensor* tensor); DGL_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
/*! /*!
* \brief Function to copy data from one array to another. * \brief Function to copy data from one array to another.
* \param from The source array. * \param from The source array.
* \param to The target array. * \param to The target array.
* \param stream The stream used in copy. * \param stream The stream used in copy.
*/ */
TVM_DLL static void CopyFromTo( DGL_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); DLTensor* from, DLTensor* to, DGLStreamHandle stream = nullptr);
// internal namespace // internal namespace
struct Internal; struct Internal;
...@@ -175,8 +175,8 @@ class NDArray { ...@@ -175,8 +175,8 @@ class NDArray {
Container* data_{nullptr}; Container* data_{nullptr};
// enable internal functions // enable internal functions
friend struct Internal; friend struct Internal;
friend class TVMRetValue; friend class DGLRetValue;
friend class TVMArgsSetter; friend class DGLArgsSetter;
}; };
/*! /*!
...@@ -321,11 +321,11 @@ inline const DLTensor* NDArray::operator->() const { ...@@ -321,11 +321,11 @@ inline const DLTensor* NDArray::operator->() const {
} }
/*! \brief Magic number for NDArray file */ /*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDLTensor(dmlc::Stream* strm, inline bool SaveDLTensor(dmlc::Stream* strm,
DLTensor* tensor) { DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0; uint64_t header = kDGLNDArrayMagic, reserved = 0;
strm->Write(header); strm->Write(header);
strm->Write(reserved); strm->Write(reserved);
// Always save data as CPU context // Always save data as CPU context
...@@ -361,9 +361,9 @@ inline bool SaveDLTensor(dmlc::Stream* strm, ...@@ -361,9 +361,9 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
strm->Write(tensor->data, data_byte_size); strm->Write(tensor->data, data_byte_size);
} else { } else {
std::vector<uint8_t> bytes(data_byte_size); std::vector<uint8_t> bytes(data_byte_size);
CHECK_EQ(TVMArrayCopyToBytes( CHECK_EQ(DGLArrayCopyToBytes(
tensor, dmlc::BeginPtr(bytes), data_byte_size), 0) tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
<< TVMGetLastError(); << DGLGetLastError();
if (!DMLC_IO_NO_ENDIAN_SWAP) { if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
} }
...@@ -382,7 +382,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) { ...@@ -382,7 +382,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(strm->Read(&reserved)) CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic) CHECK(header == kDGLNDArrayMagic)
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
DLContext ctx; DLContext ctx;
int ndim; int ndim;
...@@ -421,5 +421,5 @@ inline bool NDArray::Load(dmlc::Stream* strm) { ...@@ -421,5 +421,5 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
#endif // DGL_RUNTIME_NDARRAY_H_ #endif // DGL_RUNTIME_NDARRAY_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dgl/runtime/packed_func.h * \file dgl/runtime/packed_func.h
* \brief Type-erased function used across TVM API. * \brief Type-erased function used across DGL API.
*/ */
#ifndef DGL_RUNTIME_PACKED_FUNC_H_ #ifndef DGL_RUNTIME_PACKED_FUNC_H_
#define DGL_RUNTIME_PACKED_FUNC_H_ #define DGL_RUNTIME_PACKED_FUNC_H_
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
#include "module.h" #include "module.h"
#include "ndarray.h" #include "ndarray.h"
// Whether use TVM runtime in header only mode. // Whether use DGL runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY #ifndef DGL_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0 #define DGL_RUNTIME_HEADER_ONLY 0
#endif #endif
namespace tvm { namespace dgl {
// Forward declare NodeRef and Node for extensions. // Forward declare NodeRef and Node for extensions.
// This header works fine without depend on NodeRef // This header works fine without depend on NodeRef
// as long as it is not used. // as long as it is not used.
...@@ -32,18 +32,18 @@ class NodeRef; ...@@ -32,18 +32,18 @@ class NodeRef;
namespace runtime { namespace runtime {
// forward declarations // forward declarations
class TVMArgs; class DGLArgs;
class TVMArgValue; class DGLArgValue;
class TVMRetValue; class DGLRetValue;
class TVMArgsSetter; class DGLArgsSetter;
/*! /*!
* \brief Packed function is a type-erased function. * \brief Packed function is a type-erased function.
* The arguments are passed by packed format. * The arguments are passed by packed format.
* *
* This is an useful unified interface to call generated functions, * This is an useful unified interface to call generated functions,
* It is the unified function function type of TVM. * It is the unified function function type of DGL.
* It corresponds to TVMFunctionHandle in C runtime API. * It corresponds to DGLFunctionHandle in C runtime API.
*/ */
class PackedFunc { class PackedFunc {
public: public:
...@@ -54,7 +54,7 @@ class PackedFunc { ...@@ -54,7 +54,7 @@ class PackedFunc {
* *
* \code * \code
* // Example code on how to implemented FType * // Example code on how to implemented FType
* void MyPackedFunc(TVMArgs args, TVMRetValue* rv) { * void MyPackedFunc(DGLArgs args, DGLRetValue* rv) {
* // automatically convert arguments to desired type. * // automatically convert arguments to desired type.
* int a0 = args[0]; * int a0 = args[0];
* float a1 = args[1]; * float a1 = args[1];
...@@ -65,7 +65,7 @@ class PackedFunc { ...@@ -65,7 +65,7 @@ class PackedFunc {
* } * }
* \endcode * \endcode
*/ */
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>; using FType = std::function<void (DGLArgs args, DGLRetValue* rv)>;
/*! \brief default constructor */ /*! \brief default constructor */
PackedFunc() {} PackedFunc() {}
/*! /*!
...@@ -88,13 +88,13 @@ class PackedFunc { ...@@ -88,13 +88,13 @@ class PackedFunc {
* \endcode * \endcode
*/ */
template<typename... Args> template<typename... Args>
inline TVMRetValue operator()(Args&& ...args) const; inline DGLRetValue operator()(Args&& ...args) const;
/*! /*!
* \brief Call the function in packed format. * \brief Call the function in packed format.
* \param args The arguments * \param args The arguments
* \param rv The return value. * \param rv The return value.
*/ */
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; inline void CallPacked(DGLArgs args, DGLRetValue* rv) const;
/*! \return the internal body function */ /*! \return the internal body function */
inline FType body() const; inline FType body() const;
/*! \return Whether the packed function is nullptr */ /*! \return Whether the packed function is nullptr */
...@@ -125,7 +125,7 @@ class TypedPackedFunc; ...@@ -125,7 +125,7 @@ class TypedPackedFunc;
* TypedPackedFunc enables compile time type checking. * TypedPackedFunc enables compile time type checking.
* TypedPackedFunc works with the runtime system: * TypedPackedFunc works with the runtime system:
* - It can be passed as an argument of PackedFunc. * - It can be passed as an argument of PackedFunc.
* - It can be assigned to TVMRetValue. * - It can be assigned to DGLRetValue.
* - It can be directly converted to a type-erased PackedFunc. * - It can be directly converted to a type-erased PackedFunc.
* *
* Developers should prefer TypedPackedFunc over PackedFunc in C++ code * Developers should prefer TypedPackedFunc over PackedFunc in C++ code
...@@ -161,7 +161,7 @@ class TypedPackedFunc<R(Args...)> { ...@@ -161,7 +161,7 @@ class TypedPackedFunc<R(Args...)> {
* *
* Example usage: * Example usage:
* \code * \code
* PackedFunc packed([](TVMArgs args, TVMRetValue *rv) { * PackedFunc packed([](DGLArgs args, DGLRetValue *rv) {
* int x = args[0]; * int x = args[0];
* *rv = x + 1; * *rv = x + 1;
* }); * });
...@@ -252,7 +252,7 @@ class TypedPackedFunc<R(Args...)> { ...@@ -252,7 +252,7 @@ class TypedPackedFunc<R(Args...)> {
} }
private: private:
friend class TVMRetValue; friend class DGLRetValue;
/*! \brief The internal packed function */ /*! \brief The internal packed function */
PackedFunc packed_; PackedFunc packed_;
/*! /*!
...@@ -266,10 +266,10 @@ class TypedPackedFunc<R(Args...)> { ...@@ -266,10 +266,10 @@ class TypedPackedFunc<R(Args...)> {
inline void AssignTypedLambda(FLambda flambda); inline void AssignTypedLambda(FLambda flambda);
}; };
/*! \brief Arguments into TVM functions. */ /*! \brief Arguments into DGL functions. */
class TVMArgs { class DGLArgs {
public: public:
const TVMValue* values; const DGLValue* values;
const int* type_codes; const int* type_codes;
int num_args; int num_args;
/*! /*!
...@@ -278,7 +278,7 @@ class TVMArgs { ...@@ -278,7 +278,7 @@ class TVMArgs {
* \param type_codes The argument type codes * \param type_codes The argument type codes
* \param num_args number of arguments. * \param num_args number of arguments.
*/ */
TVMArgs(const TVMValue* values, DGLArgs(const DGLValue* values,
const int* type_codes, const int* type_codes,
int num_args) int num_args)
: values(values), : values(values),
...@@ -291,7 +291,7 @@ class TVMArgs { ...@@ -291,7 +291,7 @@ class TVMArgs {
* \param i the index. * \param i the index.
* \return the ith argument. * \return the ith argument.
*/ */
inline TVMArgValue operator[](int i) const; inline DGLArgValue operator[](int i) const;
}; };
/*! /*!
...@@ -302,31 +302,31 @@ class TVMArgs { ...@@ -302,31 +302,31 @@ class TVMArgs {
inline const char* TypeCode2Str(int type_code); inline const char* TypeCode2Str(int type_code);
/*! /*!
* \brief convert a string to TVM type. * \brief convert a string to DGL type.
* \param s The string to be converted. * \param s The string to be converted.
* \return The corresponding tvm type. * \return The corresponding dgl type.
*/ */
inline TVMType String2TVMType(std::string s); inline DGLType String2DGLType(std::string s);
/*! /*!
* \brief convert a TVM type to string. * \brief convert a DGL type to string.
* \param t The type to be converted. * \param t The type to be converted.
* \return The corresponding tvm type in string. * \return The corresponding dgl type in string.
*/ */
inline std::string TVMType2String(TVMType t); inline std::string DGLType2String(DGLType t);
// macro to check type code. // macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \ #define DGL_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \ CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
/*! /*!
* \brief Type traits to mark if a class is tvm extension type. * \brief Type traits to mark if a class is dgl extension type.
* *
* To enable extension type in C++ must be register () ed via marco. * To enable extension type in C++ must be register () ed via marco.
* TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits. * DGL_REGISTER_EXT_TYPE(TypeName) after defining this with this traits.
* *
* Extension class can be passed and returned via PackedFunc in all tvm runtime. * Extension class can be passed and returned via PackedFunc in all dgl runtime.
* Internally extension class is stored as T*. * Internally extension class is stored as T*.
* *
* \tparam T the typename * \tparam T the typename
...@@ -357,18 +357,18 @@ class ExtTypeVTable { ...@@ -357,18 +357,18 @@ class ExtTypeVTable {
* \param type_code The type code * \param type_code The type code
* \return The registered vtable. * \return The registered vtable.
*/ */
TVM_DLL static ExtTypeVTable* Get(int type_code); DGL_DLL static ExtTypeVTable* Get(int type_code);
private: private:
// Internal registration function. // Internal registration function.
TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt); DGL_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
}; };
/*! /*!
* \brief Internal base class to * \brief Internal base class to
* handle conversion to POD values. * handle conversion to POD values.
*/ */
class TVMPODValue_ { class DGLPODValue_ {
public: public:
operator double() const { operator double() const {
// Allow automatic conversion from int to float // Allow automatic conversion from int to float
...@@ -377,31 +377,31 @@ class TVMPODValue_ { ...@@ -377,31 +377,31 @@ class TVMPODValue_ {
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
return static_cast<double>(value_.v_int64); return static_cast<double>(value_.v_int64);
} }
TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); DGL_CHECK_TYPE_CODE(type_code_, kDLFloat);
return value_.v_float64; return value_.v_float64;
} }
operator int64_t() const { operator int64_t() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64; return value_.v_int64;
} }
operator uint64_t() const { operator uint64_t() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64; return value_.v_int64;
} }
operator int() const { operator int() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
CHECK_LE(value_.v_int64, CHECK_LE(value_.v_int64,
std::numeric_limits<int>::max()); std::numeric_limits<int>::max());
return static_cast<int>(value_.v_int64); return static_cast<int>(value_.v_int64);
} }
operator bool() const { operator bool() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64 != 0; return value_.v_int64 != 0;
} }
operator void*() const { operator void*() const {
if (type_code_ == kNull) return nullptr; if (type_code_ == kNull) return nullptr;
if (type_code_ == kArrayHandle) return value_.v_handle; if (type_code_ == kArrayHandle) return value_.v_handle;
TVM_CHECK_TYPE_CODE(type_code_, kHandle); DGL_CHECK_TYPE_CODE(type_code_, kHandle);
return value_.v_handle; return value_.v_handle;
} }
operator DLTensor*() const { operator DLTensor*() const {
...@@ -418,11 +418,11 @@ class TVMPODValue_ { ...@@ -418,11 +418,11 @@ class TVMPODValue_ {
} }
operator NDArray() const { operator NDArray() const {
if (type_code_ == kNull) return NDArray(); if (type_code_ == kNull) return NDArray();
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); DGL_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle)); return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
} }
operator TVMContext() const { operator DGLContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); DGL_CHECK_TYPE_CODE(type_code_, kDGLContext);
return value_.v_ctx; return value_.v_ctx;
} }
template<typename TExtension> template<typename TExtension>
...@@ -444,69 +444,69 @@ class TVMPODValue_ { ...@@ -444,69 +444,69 @@ class TVMPODValue_ {
} }
protected: protected:
friend class TVMArgsSetter; friend class DGLArgsSetter;
friend class TVMRetValue; friend class DGLRetValue;
TVMPODValue_() : type_code_(kNull) {} DGLPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code) DGLPODValue_(DGLValue value, int type_code)
: value_(value), type_code_(type_code) {} : value_(value), type_code_(type_code) {}
/*! \brief The value */ /*! \brief The value */
TVMValue value_; DGLValue value_;
/*! \brief the type code */ /*! \brief the type code */
int type_code_; int type_code_;
}; };
/*! /*!
* \brief A single argument value to PackedFunc. * \brief A single argument value to PackedFunc.
* Containing both type_code and TVMValue * Containing both type_code and DGLValue
* *
* Provides utilities to do type cast into other types. * Provides utilities to do type cast into other types.
*/ */
class TVMArgValue : public TVMPODValue_ { class DGLArgValue : public DGLPODValue_ {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
TVMArgValue() {} DGLArgValue() {}
/*! /*!
* \brief constructor * \brief constructor
* \param value of the function * \param value of the function
* \param type_code The type code. * \param type_code The type code.
*/ */
TVMArgValue(TVMValue value, int type_code) DGLArgValue(DGLValue value, int type_code)
: TVMPODValue_(value, type_code) { : DGLPODValue_(value, type_code) {
} }
// reuse converter from parent // reuse converter from parent
using TVMPODValue_::operator double; using DGLPODValue_::operator double;
using TVMPODValue_::operator int64_t; using DGLPODValue_::operator int64_t;
using TVMPODValue_::operator uint64_t; using DGLPODValue_::operator uint64_t;
using TVMPODValue_::operator int; using DGLPODValue_::operator int;
using TVMPODValue_::operator bool; using DGLPODValue_::operator bool;
using TVMPODValue_::operator void*; using DGLPODValue_::operator void*;
using TVMPODValue_::operator DLTensor*; using DGLPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray; using DGLPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext; using DGLPODValue_::operator DGLContext;
// conversion operator. // conversion operator.
operator std::string() const { operator std::string() const {
if (type_code_ == kTVMType) { if (type_code_ == kDGLType) {
return TVMType2String(operator TVMType()); return DGLType2String(operator DGLType());
} else if (type_code_ == kBytes) { } else if (type_code_ == kBytes) {
TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle); DGLByteArray* arr = static_cast<DGLByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size); return std::string(arr->data, arr->size);
} else { } else {
TVM_CHECK_TYPE_CODE(type_code_, kStr); DGL_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str); return std::string(value_.v_str);
} }
} }
operator TVMType() const { operator DGLType() const {
if (type_code_ == kStr) { if (type_code_ == kStr) {
return String2TVMType(operator std::string()); return String2DGLType(operator std::string());
} }
TVM_CHECK_TYPE_CODE(type_code_, kTVMType); DGL_CHECK_TYPE_CODE(type_code_, kDGLType);
return value_.v_type; return value_.v_type;
} }
operator PackedFunc() const { operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc(); if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); DGL_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>(); return *ptr<PackedFunc>();
} }
template<typename FType> template<typename FType>
...@@ -514,10 +514,10 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -514,10 +514,10 @@ class TVMArgValue : public TVMPODValue_ {
return TypedPackedFunc<FType>(operator PackedFunc()); return TypedPackedFunc<FType>(operator PackedFunc());
} }
operator Module() const { operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); DGL_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>(); return *ptr<Module>();
} }
const TVMValue& value() const { const DGLValue& value() const {
return value_; return value_;
} }
// Deferred extension handler. // Deferred extension handler.
...@@ -537,63 +537,63 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -537,63 +537,63 @@ class TVMArgValue : public TVMPODValue_ {
/*! /*!
* \brief Return Value container, * \brief Return Value container,
* Unlike TVMArgValue, which only holds reference and do not delete * Unlike DGLArgValue, which only holds reference and do not delete
* the underlying container during destruction. * the underlying container during destruction.
* *
* TVMRetValue holds value and will manage the underlying containers * DGLRetValue holds value and will manage the underlying containers
* when it stores a complicated data type. * when it stores a complicated data type.
*/ */
class TVMRetValue : public TVMPODValue_ { class DGLRetValue : public DGLPODValue_ {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
TVMRetValue() {} DGLRetValue() {}
/*! /*!
* \brief move constructor from anoter return value. * \brief move constructor from anoter return value.
* \param other The other return value. * \param other The other return value.
*/ */
TVMRetValue(TVMRetValue&& other) DGLRetValue(DGLRetValue&& other)
: TVMPODValue_(other.value_, other.type_code_) { : DGLPODValue_(other.value_, other.type_code_) {
other.value_.v_handle = nullptr; other.value_.v_handle = nullptr;
other.type_code_ = kNull; other.type_code_ = kNull;
} }
/*! \brief destructor */ /*! \brief destructor */
~TVMRetValue() { ~DGLRetValue() {
this->Clear(); this->Clear();
} }
// reuse converter from parent // reuse converter from parent
using TVMPODValue_::operator double; using DGLPODValue_::operator double;
using TVMPODValue_::operator int64_t; using DGLPODValue_::operator int64_t;
using TVMPODValue_::operator uint64_t; using DGLPODValue_::operator uint64_t;
using TVMPODValue_::operator int; using DGLPODValue_::operator int;
using TVMPODValue_::operator bool; using DGLPODValue_::operator bool;
using TVMPODValue_::operator void*; using DGLPODValue_::operator void*;
using TVMPODValue_::operator DLTensor*; using DGLPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext; using DGLPODValue_::operator DGLContext;
using TVMPODValue_::operator NDArray; using DGLPODValue_::operator NDArray;
// Disable copy and assign from another value, but allow move. // Disable copy and assign from another value, but allow move.
TVMRetValue(const TVMRetValue& other) { DGLRetValue(const DGLRetValue& other) {
this->Assign(other); this->Assign(other);
} }
// conversion operators // conversion operators
operator std::string() const { operator std::string() const {
if (type_code_ == kTVMType) { if (type_code_ == kDGLType) {
return TVMType2String(operator TVMType()); return DGLType2String(operator DGLType());
} else if (type_code_ == kBytes) { } else if (type_code_ == kBytes) {
return *ptr<std::string>(); return *ptr<std::string>();
} }
TVM_CHECK_TYPE_CODE(type_code_, kStr); DGL_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>(); return *ptr<std::string>();
} }
operator TVMType() const { operator DGLType() const {
if (type_code_ == kStr) { if (type_code_ == kStr) {
return String2TVMType(operator std::string()); return String2DGLType(operator std::string());
} }
TVM_CHECK_TYPE_CODE(type_code_, kTVMType); DGL_CHECK_TYPE_CODE(type_code_, kDGLType);
return value_.v_type; return value_.v_type;
} }
operator PackedFunc() const { operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc(); if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); DGL_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>(); return *ptr<PackedFunc>();
} }
template<typename FType> template<typename FType>
...@@ -601,91 +601,91 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -601,91 +601,91 @@ class TVMRetValue : public TVMPODValue_ {
return TypedPackedFunc<FType>(operator PackedFunc()); return TypedPackedFunc<FType>(operator PackedFunc());
} }
operator Module() const { operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); DGL_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>(); return *ptr<Module>();
} }
// Assign operators // Assign operators
TVMRetValue& operator=(TVMRetValue&& other) { DGLRetValue& operator=(DGLRetValue&& other) {
this->Clear(); this->Clear();
value_ = other.value_; value_ = other.value_;
type_code_ = other.type_code_; type_code_ = other.type_code_;
other.type_code_ = kNull; other.type_code_ = kNull;
return *this; return *this;
} }
TVMRetValue& operator=(double value) { DGLRetValue& operator=(double value) {
this->SwitchToPOD(kDLFloat); this->SwitchToPOD(kDLFloat);
value_.v_float64 = value; value_.v_float64 = value;
return *this; return *this;
} }
TVMRetValue& operator=(std::nullptr_t value) { DGLRetValue& operator=(std::nullptr_t value) {
this->SwitchToPOD(kNull); this->SwitchToPOD(kNull);
value_.v_handle = value; value_.v_handle = value;
return *this; return *this;
} }
TVMRetValue& operator=(void* value) { DGLRetValue& operator=(void* value) {
this->SwitchToPOD(kHandle); this->SwitchToPOD(kHandle);
value_.v_handle = value; value_.v_handle = value;
return *this; return *this;
} }
TVMRetValue& operator=(int64_t value) { DGLRetValue& operator=(int64_t value) {
this->SwitchToPOD(kDLInt); this->SwitchToPOD(kDLInt);
value_.v_int64 = value; value_.v_int64 = value;
return *this; return *this;
} }
TVMRetValue& operator=(int value) { DGLRetValue& operator=(int value) {
this->SwitchToPOD(kDLInt); this->SwitchToPOD(kDLInt);
value_.v_int64 = value; value_.v_int64 = value;
return *this; return *this;
} }
TVMRetValue& operator=(TVMType t) { DGLRetValue& operator=(DGLType t) {
this->SwitchToPOD(kTVMType); this->SwitchToPOD(kDGLType);
value_.v_type = t; value_.v_type = t;
return *this; return *this;
} }
TVMRetValue& operator=(bool value) { DGLRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt); this->SwitchToPOD(kDLInt);
value_.v_int64 = value; value_.v_int64 = value;
return *this; return *this;
} }
TVMRetValue& operator=(std::string value) { DGLRetValue& operator=(std::string value) {
this->SwitchToClass(kStr, value); this->SwitchToClass(kStr, value);
return *this; return *this;
} }
TVMRetValue& operator=(TVMByteArray value) { DGLRetValue& operator=(DGLByteArray value) {
this->SwitchToClass(kBytes, std::string(value.data, value.size)); this->SwitchToClass(kBytes, std::string(value.data, value.size));
return *this; return *this;
} }
TVMRetValue& operator=(NDArray other) { DGLRetValue& operator=(NDArray other) {
this->Clear(); this->Clear();
type_code_ = kNDArrayContainer; type_code_ = kNDArrayContainer;
value_.v_handle = other.data_; value_.v_handle = other.data_;
other.data_ = nullptr; other.data_ = nullptr;
return *this; return *this;
} }
TVMRetValue& operator=(PackedFunc f) { DGLRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f); this->SwitchToClass(kFuncHandle, f);
return *this; return *this;
} }
template<typename FType> template<typename FType>
TVMRetValue& operator=(const TypedPackedFunc<FType>& f) { DGLRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed()); return operator=(f.packed());
} }
TVMRetValue& operator=(Module m) { DGLRetValue& operator=(Module m) {
this->SwitchToClass(kModuleHandle, m); this->SwitchToClass(kModuleHandle, m);
return *this; return *this;
} }
TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 DGLRetValue& operator=(const DGLRetValue& other) { // NOLINT(*0
this->Assign(other); this->Assign(other);
return *this; return *this;
} }
TVMRetValue& operator=(const TVMArgValue& other) { DGLRetValue& operator=(const DGLArgValue& other) {
this->Assign(other); this->Assign(other);
return *this; return *this;
} }
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type> extension_class_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) { DGLRetValue& operator=(const T& other) {
this->SwitchToClass<T>( this->SwitchToClass<T>(
extension_class_info<T>::code, other); extension_class_info<T>::code, other);
return *this; return *this;
...@@ -699,7 +699,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -699,7 +699,7 @@ class TVMRetValue : public TVMPODValue_ {
* \param ret_value The return value. * \param ret_value The return value.
* \param ret_type_code The return type code. * \param ret_type_code The return type code.
*/ */
void MoveToCHost(TVMValue* ret_value, void MoveToCHost(DGLValue* ret_value,
int* ret_type_code) { int* ret_type_code) {
// cannot move str; need specially handle. // cannot move str; need specially handle.
CHECK(type_code_ != kStr && type_code_ != kBytes); CHECK(type_code_ != kStr && type_code_ != kBytes);
...@@ -708,22 +708,22 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -708,22 +708,22 @@ class TVMRetValue : public TVMPODValue_ {
type_code_ = kNull; type_code_ = kNull;
} }
/*! \return The value field, if the data is POD */ /*! \return The value field, if the data is POD */
const TVMValue& value() const { const DGLValue& value() const {
CHECK(type_code_ != kNodeHandle && CHECK(type_code_ != kNodeHandle &&
type_code_ != kFuncHandle && type_code_ != kFuncHandle &&
type_code_ != kModuleHandle && type_code_ != kModuleHandle &&
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; type_code_ != kStr) << "DGLRetValue.value can only be used for POD data";
return value_; return value_;
} }
// NodeRef related extenstions: in tvm/packed_func_ext.h // NodeRef related extenstions: in dgl/packed_func_ext.h
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
inline operator T() const; inline operator T() const;
template<typename TNodeRef> template<typename TNodeRef>
inline TNodeRef AsNodeRef() const; inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other); inline DGLRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other); inline DGLRetValue& operator=(const std::shared_ptr<Node>& other);
private: private:
template<typename T> template<typename T>
...@@ -759,7 +759,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -759,7 +759,7 @@ class TVMRetValue : public TVMPODValue_ {
SwitchToPOD(other.type_code()); SwitchToPOD(other.type_code());
value_ = other.value_; value_ = other.value_;
} else { } else {
#if TVM_RUNTIME_HEADER_ONLY #if DGL_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type"; LOG(FATAL) << "Header only mode do not support ext type";
#else #else
this->Clear(); this->Clear();
...@@ -803,7 +803,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -803,7 +803,7 @@ class TVMRetValue : public TVMPODValue_ {
} }
} }
if (type_code_ > kExtBegin) { if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY #if DGL_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type"; LOG(FATAL) << "Header only mode do not support ext type";
#else #else
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle); (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
...@@ -825,8 +825,8 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -825,8 +825,8 @@ inline const char* TypeCode2Str(int type_code) {
case kNull: return "NULL"; case kNull: return "NULL";
case kNodeHandle: return "NodeHandle"; case kNodeHandle: return "NodeHandle";
case kArrayHandle: return "ArrayHandle"; case kArrayHandle: return "ArrayHandle";
case kTVMType: return "TVMType"; case kDGLType: return "DGLType";
case kTVMContext: return "TVMContext"; case kDGLContext: return "DGLContext";
case kFuncHandle: return "FunctionHandle"; case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle"; case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer"; case kNDArrayContainer: return "NDArrayContainer";
...@@ -836,7 +836,7 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -836,7 +836,7 @@ inline const char* TypeCode2Str(int type_code) {
} }
#ifndef _LIBCPP_SGX_NO_IOSTREAMS #ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) inline std::ostream& operator<<(std::ostream& os, DGLType t) { // NOLINT(*)
os << TypeCode2Str(t.code); os << TypeCode2Str(t.code);
if (t.code == kHandle) return os; if (t.code == kHandle) return os;
os << static_cast<int>(t.bits); os << static_cast<int>(t.bits);
...@@ -847,7 +847,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) ...@@ -847,7 +847,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
} }
#endif #endif
inline std::string TVMType2String(TVMType t) { inline std::string DGLType2String(DGLType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS #ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os; std::ostringstream os;
os << t; os << t;
...@@ -864,8 +864,8 @@ inline std::string TVMType2String(TVMType t) { ...@@ -864,8 +864,8 @@ inline std::string TVMType2String(TVMType t) {
#endif #endif
} }
inline TVMType String2TVMType(std::string s) { inline DGLType String2DGLType(std::string s) {
TVMType t; DGLType t;
t.bits = 32; t.lanes = 1; t.bits = 32; t.lanes = 1;
const char* scan; const char* scan;
if (s.substr(0, 3) == "int") { if (s.substr(0, 3) == "int") {
...@@ -891,19 +891,19 @@ inline TVMType String2TVMType(std::string s) { ...@@ -891,19 +891,19 @@ inline TVMType String2TVMType(std::string s) {
return t; return t;
} }
inline TVMArgValue TVMArgs::operator[](int i) const { inline DGLArgValue DGLArgs::operator[](int i) const {
CHECK_LT(i, num_args) CHECK_LT(i, num_args)
<< "not enough argument passed, " << "not enough argument passed, "
<< num_args << " passed" << num_args << " passed"
<< " but request arg[" << i << "]."; << " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]); return DGLArgValue(values[i], type_codes[i]);
} }
inline int TVMArgs::size() const { inline int DGLArgs::size() const {
return num_args; return num_args;
} }
inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { inline void PackedFunc::CallPacked(DGLArgs args, DGLRetValue* rv) const {
body_(args, rv); body_(args, rv);
} }
...@@ -939,9 +939,9 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*) ...@@ -939,9 +939,9 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
} // namespace detail } // namespace detail
/* \brief argument settter to PackedFunc */ /* \brief argument settter to PackedFunc */
class TVMArgsSetter { class DGLArgsSetter {
public: public:
TVMArgsSetter(TVMValue* values, int* type_codes) DGLArgsSetter(DGLValue* values, int* type_codes)
: values_(values), type_codes_(type_codes) {} : values_(values), type_codes_(type_codes) {}
// setters for POD types // setters for POD types
template<typename T, template<typename T,
...@@ -965,7 +965,7 @@ class TVMArgsSetter { ...@@ -965,7 +965,7 @@ class TVMArgsSetter {
values_[i].v_handle = value; values_[i].v_handle = value;
type_codes_[i] = kNull; type_codes_[i] = kNull;
} }
void operator()(size_t i, const TVMArgValue& value) const { void operator()(size_t i, const DGLArgValue& value) const {
values_[i] = value.value_; values_[i] = value.value_;
type_codes_[i] = value.type_code_; type_codes_[i] = value.type_code_;
} }
...@@ -977,13 +977,13 @@ class TVMArgsSetter { ...@@ -977,13 +977,13 @@ class TVMArgsSetter {
values_[i].v_handle = value; values_[i].v_handle = value;
type_codes_[i] = kArrayHandle; type_codes_[i] = kArrayHandle;
} }
void operator()(size_t i, TVMContext value) const { void operator()(size_t i, DGLContext value) const {
values_[i].v_ctx = value; values_[i].v_ctx = value;
type_codes_[i] = kTVMContext; type_codes_[i] = kDGLContext;
} }
void operator()(size_t i, TVMType value) const { void operator()(size_t i, DGLType value) const {
values_[i].v_type = value; values_[i].v_type = value;
type_codes_[i] = kTVMType; type_codes_[i] = kDGLType;
} }
void operator()(size_t i, const char* value) const { void operator()(size_t i, const char* value) const {
values_[i].v_str = value; values_[i].v_str = value;
...@@ -996,8 +996,8 @@ class TVMArgsSetter { ...@@ -996,8 +996,8 @@ class TVMArgsSetter {
values_[i].v_str = value.c_str(); values_[i].v_str = value.c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
} }
void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*) void operator()(size_t i, const DGLByteArray& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<TVMByteArray*>(&value); values_[i].v_handle = const_cast<DGLByteArray*>(&value);
type_codes_[i] = kBytes; type_codes_[i] = kBytes;
} }
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*) void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
...@@ -1016,7 +1016,7 @@ class TVMArgsSetter { ...@@ -1016,7 +1016,7 @@ class TVMArgsSetter {
values_[i].v_handle = value.data_; values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer; type_codes_[i] = kNDArrayContainer;
} }
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) void operator()(size_t i, const DGLRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) { if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str(); values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
...@@ -1031,26 +1031,26 @@ class TVMArgsSetter { ...@@ -1031,26 +1031,26 @@ class TVMArgsSetter {
typename = typename std::enable_if< typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type> extension_class_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const; inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h // NodeRef related extenstions: in dgl/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
private: private:
/*! \brief The values fields */ /*! \brief The values fields */
TVMValue* values_; DGLValue* values_;
/*! \brief The type code fields */ /*! \brief The type code fields */
int* type_codes_; int* type_codes_;
}; };
template<typename... Args> template<typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { inline DGLRetValue PackedFunc::operator()(Args&& ...args) const {
const int kNumArgs = sizeof...(Args); const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize]; DGLValue values[kArraySize];
int type_codes[kArraySize]; int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes), detail::for_each(DGLArgsSetter(values, type_codes),
std::forward<Args>(args)...); std::forward<Args>(args)...);
TVMRetValue rv; DGLRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv); body_(DGLArgs(values, type_codes, kNumArgs), &rv);
return rv; return rv;
} }
...@@ -1059,8 +1059,8 @@ template<typename R, int nleft, int index, typename F> ...@@ -1059,8 +1059,8 @@ template<typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher { struct unpack_call_dispatcher {
template<typename ...Args> template<typename ...Args>
static void run(const F& f, static void run(const F& f,
const TVMArgs& args_pack, const DGLArgs& args_pack,
TVMRetValue* rv, DGLRetValue* rv,
Args&&... unpacked_args) { Args&&... unpacked_args) {
unpack_call_dispatcher<R, nleft - 1, index + 1, F> unpack_call_dispatcher<R, nleft - 1, index + 1, F>
::run(f, args_pack, rv, ::run(f, args_pack, rv,
...@@ -1073,8 +1073,8 @@ template<typename R, int index, typename F> ...@@ -1073,8 +1073,8 @@ template<typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> { struct unpack_call_dispatcher<R, 0, index, F> {
template<typename ...Args> template<typename ...Args>
static void run(const F& f, static void run(const F& f,
const TVMArgs& args_pack, const DGLArgs& args_pack,
TVMRetValue* rv, DGLRetValue* rv,
Args&&... unpacked_args) { Args&&... unpacked_args) {
*rv = R(f(std::forward<Args>(unpacked_args)...)); *rv = R(f(std::forward<Args>(unpacked_args)...));
} }
...@@ -1084,15 +1084,15 @@ template<int index, typename F> ...@@ -1084,15 +1084,15 @@ template<int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> { struct unpack_call_dispatcher<void, 0, index, F> {
template<typename ...Args> template<typename ...Args>
static void run(const F& f, static void run(const F& f,
const TVMArgs& args_pack, const DGLArgs& args_pack,
TVMRetValue* rv, DGLRetValue* rv,
Args&&... unpacked_args) { Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...); f(std::forward<Args>(unpacked_args)...);
} }
}; };
template<typename R, int nargs, typename F> template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { inline void unpack_call(const F& f, const DGLArgs& args, DGLRetValue* rv) {
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv); unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
} }
...@@ -1125,7 +1125,7 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) ...@@ -1125,7 +1125,7 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
template<typename R, typename ...Args> template<typename R, typename ...Args>
template<typename FType> template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) { inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { packed_ = PackedFunc([flambda](const DGLArgs& args, DGLRetValue* rv) {
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv); detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
}); });
} }
...@@ -1139,14 +1139,14 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { ...@@ -1139,14 +1139,14 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// extension and node type handling // extension and node type handling
namespace detail { namespace detail {
template<typename T, typename TSrc, bool is_ext> template<typename T, typename TSrc, bool is_ext>
struct TVMValueCast { struct DGLValueCast {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
return self->template AsNodeRef<T>(); return self->template AsNodeRef<T>();
} }
}; };
template<typename T, typename TSrc> template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> { struct DGLValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
return self->template AsExtension<T>(); return self->template AsExtension<T>();
} }
...@@ -1154,21 +1154,21 @@ struct TVMValueCast<T, TSrc, true> { ...@@ -1154,21 +1154,21 @@ struct TVMValueCast<T, TSrc, true> {
} // namespace detail } // namespace detail
template<typename T, typename> template<typename T, typename>
inline TVMArgValue::operator T() const { inline DGLArgValue::operator T() const {
return detail:: return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0> DGLValueCast<T, DGLArgValue, extension_class_info<T>::code != 0>
::Apply(this); ::Apply(this);
} }
template<typename T, typename> template<typename T, typename>
inline TVMRetValue::operator T() const { inline DGLRetValue::operator T() const {
return detail:: return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0> DGLValueCast<T, DGLRetValue, extension_class_info<T>::code != 0>
::Apply(this); ::Apply(this);
} }
template<typename T, typename> template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const { inline void DGLArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0, static_assert(extension_class_info<T>::code != 0,
"Need to have extesion code"); "Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code; type_codes_[i] = extension_class_info<T>::code;
...@@ -1211,5 +1211,5 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import ...@@ -1211,5 +1211,5 @@ inline PackedFunc Module::GetFunction(const std::string& name, bool query_import
return pf; return pf;
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
#endif // DGL_RUNTIME_PACKED_FUNC_H_ #endif // DGL_RUNTIME_PACKED_FUNC_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dgl/runtime/registry.h * \file dgl/runtime/registry.h
* \brief This file defines the TVM global function registry. * \brief This file defines the DGL global function registry.
* *
* The registered functions will be made available to front-end * The registered functions will be made available to front-end
* as well as backend users. * as well as backend users.
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
* *
* Front-end can also pass callbacks as PackedFunc, or register * Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++. * then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end. * The goal is to mix the front-end language and the DGL back-end.
* *
* \code * \code
* // register the function as MyAPIFuncName * // register the function as MyAPIFuncName
* TVM_REGISTER_GLOBAL(MyAPIFuncName) * DGL_REGISTER_GLOBAL(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) { * .set_body([](DGLArgs args, DGLRetValue* rv) {
* // my code. * // my code.
* }); * });
* \endcode * \endcode
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include <vector> #include <vector>
#include "packed_func.h" #include "packed_func.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
/*! \brief Registry for global function */ /*! \brief Registry for global function */
...@@ -39,7 +39,7 @@ class Registry { ...@@ -39,7 +39,7 @@ class Registry {
* \brief set the body of the function to be f * \brief set the body of the function to be f
* \param f The body of the function. * \param f The body of the function.
*/ */
TVM_DLL Registry& set_body(PackedFunc f); // NOLINT(*) DGL_DLL Registry& set_body(PackedFunc f); // NOLINT(*)
/*! /*!
* \brief set the body of the function to be f * \brief set the body of the function to be f
* \param f The body of the function. * \param f The body of the function.
...@@ -52,7 +52,7 @@ class Registry { ...@@ -52,7 +52,7 @@ class Registry {
* *
* \code * \code
* *
* TVM_REGISTER_API("addone") * DGL_REGISTER_API("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; }); * .set_body_typed<int(int)>([](int x) { return x + 1; });
* *
* \endcode * \endcode
...@@ -71,25 +71,25 @@ class Registry { ...@@ -71,25 +71,25 @@ class Registry {
* \param override Whether allow oveeride existing function. * \param override Whether allow oveeride existing function.
* \return Reference to theregistry. * \return Reference to theregistry.
*/ */
TVM_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*) DGL_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
/*! /*!
* \brief Erase global function from registry, if exist. * \brief Erase global function from registry, if exist.
* \param name The name of the function. * \param name The name of the function.
* \return Whether function exist. * \return Whether function exist.
*/ */
TVM_DLL static bool Remove(const std::string& name); DGL_DLL static bool Remove(const std::string& name);
/*! /*!
* \brief Get the global function by name. * \brief Get the global function by name.
* \param name The name of the function. * \param name The name of the function.
* \return pointer to the registered function, * \return pointer to the registered function,
* nullptr if it does not exist. * nullptr if it does not exist.
*/ */
TVM_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*) DGL_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*)
/*! /*!
* \brief Get the names of currently registered global function. * \brief Get the names of currently registered global function.
* \return The names * \return The names
*/ */
TVM_DLL static std::vector<std::string> ListNames(); DGL_DLL static std::vector<std::string> ListNames();
// Internal class. // Internal class.
struct Manager; struct Manager;
...@@ -104,41 +104,41 @@ class Registry { ...@@ -104,41 +104,41 @@ class Registry {
/*! \brief helper macro to supress unused warning */ /*! \brief helper macro to supress unused warning */
#if defined(__GNUC__) #if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) #define DGL_ATTRIBUTE_UNUSED __attribute__((unused))
#else #else
#define TVM_ATTRIBUTE_UNUSED #define DGL_ATTRIBUTE_UNUSED
#endif #endif
#define TVM_STR_CONCAT_(__x, __y) __x##__y #define DGL_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) #define DGL_STR_CONCAT(__x, __y) DGL_STR_CONCAT_(__x, __y)
#define TVM_FUNC_REG_VAR_DEF \ #define DGL_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::Registry& __mk_ ## DGL
#define TVM_TYPE_REG_VAR_DEF \ #define DGL_TYPE_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::ExtTypeVTable* __mk_ ## DGLT
/*! /*!
* \brief Register a function globally. * \brief Register a function globally.
* \code * \code
* TVM_REGISTER_GLOBAL("MyPrint") * DGL_REGISTER_GLOBAL("MyPrint")
* .set_body([](TVMArgs args, TVMRetValue* rv) { * .set_body([](DGLArgs args, DGLRetValue* rv) {
* }); * });
* \endcode * \endcode
*/ */
#define TVM_REGISTER_GLOBAL(OpName) \ #define DGL_REGISTER_GLOBAL(OpName) \
TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ DGL_STR_CONCAT(DGL_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::Registry::Register(OpName) ::dgl::runtime::Registry::Register(OpName)
/*! /*!
* \brief Macro to register extension type. * \brief Macro to register extension type.
* This must be registered in a cc file * This must be registered in a cc file
* after the trait extension_class_info is defined. * after the trait extension_class_info is defined.
*/ */
#define TVM_REGISTER_EXT_TYPE(T) \ #define DGL_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \ DGL_STR_CONCAT(DGL_TYPE_REG_VAR_DEF, __COUNTER__) = \
::tvm::runtime::ExtTypeVTable::Register_<T>() ::dgl::runtime::ExtTypeVTable::Register_<T>()
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
#endif // DGL_RUNTIME_REGISTRY_H_ #endif // DGL_RUNTIME_REGISTRY_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h * \file dgl/runtime/serializer.h
* \brief Serializer extension to support TVM data types * \brief Serializer extension to support DGL data types
* Include this file to enable serialization of DLDataType, DLContext * Include this file to enable serialization of DLDataType, DLContext
*/ */
#ifndef DGL_RUNTIME_SERIALIZER_H_ #ifndef DGL_RUNTIME_SERIALIZER_H_
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
namespace threading { namespace threading {
...@@ -80,6 +80,6 @@ int MaxConcurrency(); ...@@ -80,6 +80,6 @@ int MaxConcurrency();
} // namespace threading } // namespace threading
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
#endif // DGL_RUNTIME_THREADING_BACKEND_H_ #endif // DGL_RUNTIME_THREADING_BACKEND_H_
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "c_runtime_api.h" #include "c_runtime_api.h"
namespace tvm { namespace dgl {
namespace runtime { namespace runtime {
/*! /*!
...@@ -18,18 +18,18 @@ namespace runtime { ...@@ -18,18 +18,18 @@ namespace runtime {
* \param bits The number of bits to be matched. * \param bits The number of bits to be matched.
* \param lanes The number of lanes sin the type. * \param lanes The number of lanes sin the type.
*/ */
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) { inline bool TypeMatch(DGLType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes; return t.code == code && t.bits == bits && t.lanes == lanes;
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace dgl
// Forward declare the intrinsic id we need // Forward declare the intrinsic id we need
// in structure fetch to enable stackvm in runtime // in structure fetch to enable stackvm in runtime
namespace tvm { namespace dgl {
namespace ir { namespace ir {
namespace intrinsic { namespace intrinsic {
/*! \brief The kind of structure field info used in intrinsic */ /*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int { enum DGLStructFieldKind : int {
// array head address // array head address
kArrAddr, kArrAddr,
kArrData, kArrData,
...@@ -43,11 +43,11 @@ enum TVMStructFieldKind : int { ...@@ -43,11 +43,11 @@ enum TVMStructFieldKind : int {
kArrDeviceId, kArrDeviceId,
kArrDeviceType, kArrDeviceType,
kArrKindBound_, kArrKindBound_,
// TVMValue field // DGLValue field
kTVMValueContent, kDGLValueContent,
kTVMValueKindBound_ kDGLValueKindBound_
}; };
} // namespace intrinsic } // namespace intrinsic
} // namespace ir } // namespace ir
} // namespace tvm } // namespace dgl
#endif // DGL_RUNTIME_UTIL_H_ #endif // DGL_RUNTIME_UTIL_H_
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace dgl { namespace dgl {
typedef tvm::runtime::NDArray IdArray; typedef dgl::runtime::NDArray IdArray;
namespace sched { namespace sched {
......
...@@ -9,16 +9,16 @@ from numbers import Number, Integral ...@@ -9,16 +9,16 @@ from numbers import Number, Integral
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext from ..runtime_ctypes import DGLType, DGLByteArray, DGLContext
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import DGLValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import DGLPackedCFunc, DGLCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p DGLRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle): def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed.""" """callback to free resources when it it not needed."""
...@@ -26,11 +26,11 @@ def _ctypes_free_resource(rhandle): ...@@ -26,11 +26,11 @@ def _ctypes_free_resource(rhandle):
ctypes.pythonapi.Py_DecRef(pyobj) ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive # Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource) DGL_FREE_PYOBJ = DGLCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ)) ctypes.pythonapi.Py_IncRef(ctypes.py_object(DGL_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc): def convert_to_dgl_func(pyfunc):
"""Convert a python function to TVM function """Convert a python function to DGL function
Parameters Parameters
---------- ----------
...@@ -39,8 +39,8 @@ def convert_to_tvm_func(pyfunc): ...@@ -39,8 +39,8 @@ def convert_to_tvm_func(pyfunc):
Returns Returns
------- -------
tvmfunc: tvm.nd.Function dglfunc: dgl.nd.Function
The converted tvm function. The converted dgl function.
""" """
local_pyfunc = pyfunc local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _): def cfun(args, type_codes, num_args, ret, _):
...@@ -52,36 +52,36 @@ def convert_to_tvm_func(pyfunc): ...@@ -52,36 +52,36 @@ def convert_to_tvm_func(pyfunc):
rv = local_pyfunc(*pyargs) rv = local_pyfunc(*pyargs)
except Exception: except Exception:
msg = traceback.format_exc() msg = traceback.format_exc()
_LIB.TVMAPISetLastError(c_str(msg)) _LIB.DGLAPISetLastError(c_str(msg))
return -1 return -1
if rv is not None: if rv is not None:
if isinstance(rv, tuple): if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value") raise ValueError("PackedFunction can only support one return value")
temp_args = [] temp_args = []
values, tcodes, _ = _make_tvm_args((rv,), temp_args) values, tcodes, _ = _make_dgl_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle): if not isinstance(ret, DGLRetValueHandle):
ret = TVMRetValueHandle(ret) ret = DGLRetValueHandle(ret)
check_call(_LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1))) check_call(_LIB.DGLCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)))
_ = temp_args _ = temp_args
_ = rv _ = rv
return 0 return 0
handle = FunctionHandle() handle = FunctionHandle()
f = TVMPackedCFunc(cfun) f = DGLPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f # NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed. # DGL_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f) pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj) ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc( check_call(_LIB.DGLFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle))) f, pyobj, DGL_FREE_PYOBJ, ctypes.byref(handle)))
return _CLASS_FUNCTION(handle, False) return _CLASS_FUNCTION(handle, False)
def _make_tvm_args(args, temp_args): def _make_dgl_args(args, temp_args):
"""Pack arguments into c args tvm call accept""" """Pack arguments into c args dgl call accept"""
num_args = len(args) num_args = len(args)
values = (TVMValue * num_args)() values = (DGLValue * num_args)()
type_codes = (ctypes.c_int * num_args)() type_codes = (ctypes.c_int * num_args)()
for i, arg in enumerate(args): for i, arg in enumerate(args):
if arg is None: if arg is None:
...@@ -91,23 +91,23 @@ def _make_tvm_args(args, temp_args): ...@@ -91,23 +91,23 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_CONTAINER type_codes[i] = (TypeCode.NDARRAY_CONTAINER
if not arg.is_view else TypeCode.ARRAY_HANDLE) if not arg.is_view else TypeCode.ARRAY_HANDLE)
elif isinstance(arg, _nd._TVM_COMPATS): elif isinstance(arg, _nd._DGL_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) values[i].v_handle = ctypes.c_void_p(arg._dgl_handle)
type_codes[i] = arg.__class__._tvm_tcode type_codes[i] = arg.__class__._dgl_tcode
elif isinstance(arg, Integral): elif isinstance(arg, Integral):
values[i].v_int64 = arg values[i].v_int64 = arg
type_codes[i] = TypeCode.INT type_codes[i] = TypeCode.INT
elif isinstance(arg, Number): elif isinstance(arg, Number):
values[i].v_float64 = arg values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType): elif isinstance(arg, DGLType):
values[i].v_str = c_str(str(arg)) values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext): elif isinstance(arg, DGLContext):
values[i].v_ctx = arg values[i].v_ctx = arg
type_codes[i] = TypeCode.TVM_CONTEXT type_codes[i] = TypeCode.DGL_CONTEXT
elif isinstance(arg, bytearray): elif isinstance(arg, bytearray):
arr = TVMByteArray() arr = DGLByteArray()
arr.data = ctypes.cast( arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg), (ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte)) ctypes.POINTER(ctypes.c_byte))
...@@ -129,7 +129,7 @@ def _make_tvm_args(args, temp_args): ...@@ -129,7 +129,7 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg values[i].v_handle = arg
type_codes[i] = TypeCode.HANDLE type_codes[i] = TypeCode.HANDLE
elif callable(arg): elif callable(arg):
arg = convert_to_tvm_func(arg) arg = convert_to_dgl_func(arg)
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg) temp_args.append(arg)
...@@ -158,7 +158,7 @@ class FunctionBase(object): ...@@ -158,7 +158,7 @@ class FunctionBase(object):
def __del__(self): def __del__(self):
if not self.is_global and _LIB is not None: if not self.is_global and _LIB is not None:
check_call(_LIB.TVMFuncFree(self.handle)) check_call(_LIB.DGLFuncFree(self.handle))
def __call__(self, *args): def __call__(self, *args):
"""Call the function with positional arguments """Call the function with positional arguments
...@@ -167,10 +167,10 @@ class FunctionBase(object): ...@@ -167,10 +167,10 @@ class FunctionBase(object):
The positional arguments to the function call. The positional arguments to the function call.
""" """
temp_args = [] temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args) values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = TVMValue() ret_val = DGLValue()
ret_tcode = ctypes.c_int() ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall( check_call(_LIB.DGLFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args), self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode))) ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args _ = temp_args
...@@ -181,10 +181,10 @@ class FunctionBase(object): ...@@ -181,10 +181,10 @@ class FunctionBase(object):
def __init_handle_by_constructor__(fconstructor, args): def __init_handle_by_constructor__(fconstructor, args):
"""Initialize handle by constructor""" """Initialize handle by constructor"""
temp_args = [] temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args) values, tcodes, num_args = _make_dgl_args(args, temp_args)
ret_val = TVMValue() ret_val = DGLValue()
ret_tcode = ctypes.c_int() ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall( check_call(_LIB.DGLFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args), fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode))) ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args _ = temp_args
......
...@@ -4,11 +4,11 @@ from __future__ import absolute_import ...@@ -4,11 +4,11 @@ from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle from ..runtime_ctypes import DGLArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) DGLPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor') _c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor') _c_str_used_dltensor = c_str('used_dltensor')
...@@ -28,10 +28,10 @@ def _from_dlpack(dltensor): ...@@ -28,10 +28,10 @@ def _from_dlpack(dltensor):
# set restype of PyCapsule calls. But weirdly, this does not # set restype of PyCapsule calls. But weirdly, this does not
# work out always. # work out always.
ptr = ctypes.cast(ptr, ctypes.c_void_p) ptr = ctypes.cast(ptr, ctypes.c_void_p)
handle = TVMArrayHandle() handle = DGLArrayHandle()
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) check_call(_LIB.DGLArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0)) ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, DGLPyCapsuleDestructor(0))
return _make_array(handle, False) return _make_array(handle, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once") raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
...@@ -44,10 +44,10 @@ def _dlpack_deleter(pycapsule): ...@@ -44,10 +44,10 @@ def _dlpack_deleter(pycapsule):
# set restype of PyCapsule calls. But weirdly, this does not # set restype of PyCapsule calls. But weirdly, this does not
# work out always. # work out always.
ptr = ctypes.cast(ptr, ctypes.c_void_p) ptr = ctypes.cast(ptr, ctypes.c_void_p)
_LIB.TVMDLManagedTensorCallDeleter(ptr) _LIB.DGLDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, TVMPyCapsuleDestructor(0)) ctypes.pythonapi.PyCapsule_SetDestructor(pycapsule, DGLPyCapsuleDestructor(0))
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter) _c_dlpack_deleter = DGLPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object): class NDArrayBase(object):
...@@ -59,18 +59,18 @@ class NDArrayBase(object): ...@@ -59,18 +59,18 @@ class NDArrayBase(object):
Parameters Parameters
---------- ----------
handle : TVMArrayHandle handle : DGLArrayHandle
the handle to the underlying C++ TVMArray the handle to the underlying C++ DGLArray
""" """
self.handle = handle self.handle = handle
self.is_view = is_view self.is_view = is_view
def __del__(self): def __del__(self):
if not self.is_view and _LIB: if not self.is_view and _LIB:
check_call(_LIB.TVMArrayFree(self.handle)) check_call(_LIB.DGLArrayFree(self.handle))
@property @property
def _tvm_handle(self): def _dgl_handle(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value return ctypes.cast(self.handle, ctypes.c_void_p).value
def to_dlpack(self): def to_dlpack(self):
...@@ -81,23 +81,23 @@ class NDArrayBase(object): ...@@ -81,23 +81,23 @@ class NDArrayBase(object):
dlpack : DLPack tensor view of the array data dlpack : DLPack tensor view of the array data
""" """
ptr = ctypes.c_void_p() ptr = ctypes.c_void_p()
check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(ptr))) check_call(_LIB.DGLArrayToDLPack(self.handle, ctypes.byref(ptr)))
return ctypes.pythonapi.PyCapsule_New(ptr, _c_str_dltensor, _c_dlpack_deleter) return ctypes.pythonapi.PyCapsule_New(ptr, _c_str_dltensor, _c_dlpack_deleter)
def _make_array(handle, is_view): def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle) handle = ctypes.cast(handle, DGLArrayHandle)
return _CLASS_NDARRAY(handle, is_view) return _CLASS_NDARRAY(handle, is_view)
_TVM_COMPATS = () _DGL_COMPATS = ()
def _reg_extension(cls, fcreate): def _reg_extension(cls, fcreate):
global _TVM_COMPATS global _DGL_COMPATS
_TVM_COMPATS += (cls,) _DGL_COMPATS += (cls,)
if fcreate: if fcreate:
fret = lambda x: fcreate(_return_handle(x)) fret = lambda x: fcreate(_return_handle(x))
RETURN_SWITCH[cls._tvm_tcode] = fret RETURN_SWITCH[cls._dgl_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode) C_TO_PY_ARG_SWITCH[cls._dgl_tcode] = _wrap_arg_func(fret, cls._dgl_tcode)
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
......
...@@ -4,26 +4,26 @@ from __future__ import absolute_import as _abs ...@@ -4,26 +4,26 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
from ..base import py_str, check_call, _LIB from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import TVMByteArray, TypeCode from ..runtime_ctypes import DGLByteArray, TypeCode
class TVMValue(ctypes.Union): class DGLValue(ctypes.Union):
"""TVMValue in C API""" """DGLValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64), _fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double), ("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p), ("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)] ("v_str", ctypes.c_char_p)]
TVMPackedCFunc = ctypes.CFUNCTYPE( DGLPackedCFunc = ctypes.CFUNCTYPE(
ctypes.c_int, ctypes.c_int,
ctypes.POINTER(TVMValue), ctypes.POINTER(DGLValue),
ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int),
ctypes.c_int, ctypes.c_int,
ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_void_p) ctypes.c_void_p)
TVMCFuncFinalizer = ctypes.CFUNCTYPE( DGLCFuncFinalizer = ctypes.CFUNCTYPE(
None, None,
ctypes.c_void_p) ctypes.c_void_p)
...@@ -40,7 +40,7 @@ def _return_bytes(x): ...@@ -40,7 +40,7 @@ def _return_bytes(x):
handle = x.v_handle handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p): if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle) handle = ctypes.c_void_p(handle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0] arr = ctypes.cast(handle, ctypes.POINTER(DGLByteArray))[0]
size = arr.size size = arr.size
res = bytearray(size) res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res) rptr = (ctypes.c_byte * size).from_buffer(res)
...@@ -51,7 +51,7 @@ def _return_bytes(x): ...@@ -51,7 +51,7 @@ def _return_bytes(x):
def _wrap_arg_func(return_f, type_code): def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code) tcode = ctypes.c_int(type_code)
def _wrap_func(x): def _wrap_func(x):
check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), tcode)) check_call(_LIB.DGLCbArgToReturn(ctypes.byref(x), tcode))
return return_f(x) return return_f(x)
return _wrap_func return _wrap_func
......
...@@ -5,14 +5,14 @@ from cpython cimport pycapsule ...@@ -5,14 +5,14 @@ from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
import ctypes import ctypes
cdef enum TVMTypeCode: cdef enum DGLTypeCode:
kInt = 0 kInt = 0
kUInt = 1 kUInt = 1
kFloat = 2 kFloat = 2
kHandle = 3 kHandle = 3
kNull = 4 kNull = 4
kTVMType = 5 kDGLType = 5
kTVMContext = 6 kDGLContext = 6
kArrayHandle = 7 kArrayHandle = 7
kNodeHandle = 8 kNodeHandle = 8
kModuleHandle = 9 kModuleHandle = 9
...@@ -22,7 +22,7 @@ cdef enum TVMTypeCode: ...@@ -22,7 +22,7 @@ cdef enum TVMTypeCode:
kNDArrayContainer = 13 kNDArrayContainer = 13
kExtBegin = 15 kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "dgl/runtime/c_runtime_api.h":
ctypedef struct DLDataType: ctypedef struct DLDataType:
uint8_t code uint8_t code
uint8_t bits uint8_t bits
...@@ -46,7 +46,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -46,7 +46,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* manager_ctx void* manager_ctx
void (*deleter)(DLManagedTensor* self) void (*deleter)(DLManagedTensor* self)
ctypedef struct TVMValue: ctypedef struct DGLValue:
int64_t v_int64 int64_t v_int64
double v_float64 double v_float64
void* v_handle void* v_handle
...@@ -54,65 +54,65 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -54,65 +54,65 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
DLDataType v_type DLDataType v_type
DLContext v_ctx DLContext v_ctx
ctypedef int64_t tvm_index_t ctypedef int64_t dgl_index_t
ctypedef DLTensor* DLTensorHandle ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle ctypedef void* DGLStreamHandle
ctypedef void* TVMRetValueHandle ctypedef void* DGLRetValueHandle
ctypedef void* TVMFunctionHandle ctypedef void* DGLFunctionHandle
ctypedef void* NodeHandle ctypedef void* NodeHandle
ctypedef int (*TVMPackedCFunc)( ctypedef int (*DGLPackedCFunc)(
TVMValue* args, DGLValue* args,
int* type_codes, int* type_codes,
int num_args, int num_args,
TVMRetValueHandle ret, DGLRetValueHandle ret,
void* resource_handle) void* resource_handle)
ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle) ctypedef void (*DGLPackedCFuncFinalizer)(void* resource_handle)
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "dgl/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg) void DGLAPISetLastError(const char* msg)
const char *TVMGetLastError() const char *DGLGetLastError()
int TVMFuncCall(TVMFunctionHandle func, int DGLFuncCall(DGLFunctionHandle func,
TVMValue* arg_values, DGLValue* arg_values,
int* type_codes, int* type_codes,
int num_args, int num_args,
TVMValue* ret_val, DGLValue* ret_val,
int* ret_type_code) int* ret_type_code)
int TVMFuncFree(TVMFunctionHandle func) int DGLFuncFree(DGLFunctionHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret, int DGLCFuncSetReturn(DGLRetValueHandle ret,
TVMValue* value, DGLValue* value,
int* type_code, int* type_code,
int num_ret) int num_ret)
int TVMFuncCreateFromCFunc(TVMPackedCFunc func, int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
void* resource_handle, void* resource_handle,
TVMPackedCFuncFinalizer fin, DGLPackedCFuncFinalizer fin,
TVMFunctionHandle *out) DGLFunctionHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code) int DGLCbArgToReturn(DGLValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape, int DGLArrayAlloc(dgl_index_t* shape,
tvm_index_t ndim, dgl_index_t ndim,
DLDataType dtype, DLDataType dtype,
DLContext ctx, DLContext ctx,
DLTensorHandle* out) DLTensorHandle* out)
int TVMArrayFree(DLTensorHandle handle) int DGLArrayFree(DLTensorHandle handle)
int TVMArrayCopyFromTo(DLTensorHandle src, int DGLArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to, DLTensorHandle to,
TVMStreamHandle stream) DGLStreamHandle stream)
int TVMArrayFromDLPack(DLManagedTensor* arr_from, int DGLArrayFromDLPack(DLManagedTensor* arr_from,
DLTensorHandle* out) DLTensorHandle* out)
int TVMArrayToDLPack(DLTensorHandle arr_from, int DGLArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out) DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef extern from "tvm/c_dsl_api.h": cdef extern from "dgl/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle) int DGLNodeFree(NodeHandle handle)
int TVMNodeTypeKey2Index(const char* type_key, int DGLNodeTypeKey2Index(const char* type_key,
int* out_index) int* out_index)
int TVMNodeGetTypeIndex(NodeHandle handle, int DGLNodeGetTypeIndex(NodeHandle handle,
int* out_index) int* out_index)
int TVMNodeGetAttr(NodeHandle handle, int DGLNodeGetAttr(NodeHandle handle,
const char* key, const char* key,
TVMValue* out_value, DGLValue* out_value,
int* out_type_code, int* out_type_code,
int* out_success) int* out_success)
...@@ -140,7 +140,7 @@ cdef inline c_str(pystr): ...@@ -140,7 +140,7 @@ cdef inline c_str(pystr):
cdef inline CALL(int ret): cdef inline CALL(int ret):
if ret != 0: if ret != 0:
raise DGLError(py_str(TVMGetLastError())) raise DGLError(py_str(DGLGetLastError()))
cdef inline object ctypes_handle(void* chandle): cdef inline object ctypes_handle(void* chandle):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment