"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a28f1f9af40242d5774cea4a082da1ae80b12fe1"
Unverified Commit cded5b80 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
...@@ -24,8 +24,8 @@ namespace aten { ...@@ -24,8 +24,8 @@ namespace aten {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
/*! \return A special array to represent null. */ /*! \return A special array to represent null. */
inline NDArray NullArray(const DLDataType& dtype = DLDataType{kDLInt, 64, 1}, inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DLContext& ctx = DLContext{kDLCPU, 0}) { const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
return NDArray::Empty({0}, dtype, ctx); return NDArray::Empty({0}, dtype, ctx);
} }
...@@ -44,7 +44,7 @@ inline bool IsNullArray(NDArray array) { ...@@ -44,7 +44,7 @@ inline bool IsNullArray(NDArray array) {
* \return id array * \return id array
*/ */
IdArray NewIdArray(int64_t length, IdArray NewIdArray(int64_t length,
DLContext ctx = DLContext{kDLCPU, 0}, DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64); uint8_t nbits = 64);
/*! /*!
...@@ -57,7 +57,7 @@ IdArray NewIdArray(int64_t length, ...@@ -57,7 +57,7 @@ IdArray NewIdArray(int64_t length,
template <typename T> template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec, IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64, uint8_t nbits = 64,
DLContext ctx = DLContext{kDLCPU, 0}); DGLContext ctx = DGLContext{kDGLCPU, 0});
/*! /*!
* \brief Return an array representing a 1D range. * \brief Return an array representing a 1D range.
...@@ -67,7 +67,7 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -67,7 +67,7 @@ IdArray VecToIdArray(const std::vector<T>& vec,
* \param ctx Device context * \param ctx Device context
* \return range array * \return range array
*/ */
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx); IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx);
/*! /*!
* \brief Return an array full of the given value * \brief Return an array full of the given value
...@@ -77,7 +77,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx); ...@@ -77,7 +77,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
* \param ctx Device context * \param ctx Device context
* \return the result array * \return the result array
*/ */
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx); IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx);
/*! /*!
* \brief Return an array full of the given value with the given type. * \brief Return an array full of the given value with the given type.
...@@ -87,7 +87,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx); ...@@ -87,7 +87,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
* \return the result array * \return the result array
*/ */
template <typename DType> template <typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx); NDArray Full(DType val, int64_t length, DGLContext ctx);
/*! \brief Create a deep copy of the given array */ /*! \brief Create a deep copy of the given array */
IdArray Clone(IdArray arr); IdArray Clone(IdArray arr);
...@@ -226,7 +226,7 @@ NDArray Concat(const std::vector<IdArray>& arrays); ...@@ -226,7 +226,7 @@ NDArray Concat(const std::vector<IdArray>& arrays);
/*!\brief Return whether the array is a valid 1D int array*/ /*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt; return arr->ndim == 1 && arr->dtype.code == kDGLInt;
} }
/*! /*!
...@@ -343,8 +343,8 @@ std::string ToDebugString(NDArray array); ...@@ -343,8 +343,8 @@ std::string ToDebugString(NDArray array);
template <typename T> template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec, IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits, uint8_t nbits,
DLContext ctx) { DGLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DLContext{kDLCPU, 0}, nbits); IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);
if (nbits == 32) { if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data)); std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
} else if (nbits == 64) { } else if (nbits == 64) {
...@@ -359,9 +359,9 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -359,9 +359,9 @@ IdArray VecToIdArray(const std::vector<T>& vec,
* \brief Get the context of the first array, and check if the non-null arrays' * \brief Get the context of the first array, and check if the non-null arrays'
* contexts are the same. * contexts are the same.
*/ */
inline DLContext GetContextOf(const std::vector<IdArray>& arrays) { inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
bool first = true; bool first = true;
DLContext result; DGLContext result;
for (auto& array : arrays) { for (auto& array : arrays) {
if (first) { if (first) {
first = false; first = false;
......
...@@ -122,11 +122,10 @@ struct COOMatrix { ...@@ -122,11 +122,10 @@ struct COOMatrix {
} }
/*! \brief Return a copy of this matrix on the give device context. */ /*! \brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DLContext &ctx) const { inline COOMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == row->ctx) if (ctx == row->ctx)
return *this; return *this;
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
col.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx), aten::IsNullArray(data) ? data : data.CopyTo(ctx),
row_sorted, col_sorted); row_sorted, col_sorted);
} }
...@@ -134,9 +133,9 @@ struct COOMatrix { ...@@ -134,9 +133,9 @@ struct COOMatrix {
/*! /*!
* \brief Pin the row, col and data (if not Null) of the matrix. * \brief Pin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
......
...@@ -115,21 +115,19 @@ struct CSRMatrix { ...@@ -115,21 +115,19 @@ struct CSRMatrix {
} }
/*! \brief Return a copy of this matrix on the give device context. */ /*! \brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DLContext &ctx) const { inline CSRMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == indptr->ctx) if (ctx == indptr->ctx)
return *this; return *this;
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx), return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx), indices.CopyTo(ctx),
indices.CopyTo(ctx), aten::IsNullArray(data) ? data : data.CopyTo(ctx), sorted);
aten::IsNullArray(data) ? data : data.CopyTo(ctx),
sorted);
} }
/*! /*!
* \brief Pin the indptr, indices and data (if not Null) of the matrix. * \brief Pin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
* }); * });
*/ */
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \ #define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \ if ((val) == kDGLCPU) { \
constexpr auto XPU = kDLCPU; \ constexpr auto XPU = kDGLCPU; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else { \ } else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \ LOG(FATAL) << "Operator " << (op) << " does not support " \
...@@ -43,11 +43,11 @@ ...@@ -43,11 +43,11 @@
*/ */
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \ #define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \ if ((val) == kDGLCPU) { \
constexpr auto XPU = kDLCPU; \ constexpr auto XPU = kDGLCPU; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if ((val) == kDLGPU) { \ } else if ((val) == kDGLCUDA) { \
constexpr auto XPU = kDLGPU; \ constexpr auto XPU = kDGLCUDA; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else { \ } else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \ LOG(FATAL) << "Operator " << (op) << " does not support " \
...@@ -69,7 +69,7 @@ ...@@ -69,7 +69,7 @@
* }); * });
*/ */
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \ #define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \ CHECK_EQ((val).code, kDGLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \ if ((val).bits == 32) { \
typedef int32_t IdType; \ typedef int32_t IdType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
...@@ -114,7 +114,7 @@ ...@@ -114,7 +114,7 @@
* }); * });
*/ */
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \ #define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \ CHECK_EQ((val).code, kDGLFloat) \
<< (val_name) << " must be float type"; \ << (val_name) << " must be float type"; \
if ((val).bits == 32) { \ if ((val).bits == 32) { \
typedef float FloatType; \ typedef float FloatType; \
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
} while (0) } while (0)
#define ATEN_FLOAT_BITS_SWITCH(val, bits, val_name, ...) do { \ #define ATEN_FLOAT_BITS_SWITCH(val, bits, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \ CHECK_EQ((val).code, kDGLFloat) \
<< (val_name) << " must be float type"; \ << (val_name) << " must be float type"; \
if ((val).bits == 16) { \ if ((val).bits == 16) { \
constexpr int bits = 16; \ constexpr int bits = 16; \
...@@ -154,16 +154,16 @@ ...@@ -154,16 +154,16 @@
* }); * });
*/ */
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \ #define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \ if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \ typedef int32_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \ } else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \ typedef int64_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 32) { \ } else if ((val).code == kDGLFloat && (val).bits == 32) { \
typedef float DType; \ typedef float DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 64) { \ } else if ((val).code == kDGLFloat && (val).bits == 64) { \
typedef double DType; \ typedef double DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else { \ } else { \
...@@ -205,10 +205,10 @@ ...@@ -205,10 +205,10 @@
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message. * Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
*/ */
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \ #define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \ if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \ typedef int32_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \ } else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \ typedef int64_t DType; \
{__VA_ARGS__} \ {__VA_ARGS__} \
} else { \ } else { \
...@@ -278,13 +278,13 @@ ...@@ -278,13 +278,13 @@
///////////////////////// Array checks ////////////////////////// ///////////////////////// Array checks //////////////////////////
#define IS_INT32(a) \ #define IS_INT32(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 32) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) \ #define IS_INT64(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 64) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) \ #define IS_FLOAT32(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 32) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) \ #define IS_FLOAT64(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 64) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 64)
#define CHECK_IF(cond, prop, value_name, dtype_name) \ #define CHECK_IF(cond, prop, value_name, dtype_name) \
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name) CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name)
......
...@@ -34,7 +34,7 @@ typedef NDArray TypeArray; ...@@ -34,7 +34,7 @@ typedef NDArray TypeArray;
namespace aten { namespace aten {
static const DLContext CPU{kDLCPU, 0}; static const DGLContext CPU{kDGLCPU, 0};
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -104,12 +104,12 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -104,12 +104,12 @@ class BaseHeteroGraph : public runtime::Object {
/*! /*!
* \brief Get the data type of node and edge IDs of this graph. * \brief Get the data type of node and edge IDs of this graph.
*/ */
virtual DLDataType DataType() const = 0; virtual DGLDataType DataType() const = 0;
/*! /*!
* \brief Get the device context of this graph. * \brief Get the device context of this graph.
*/ */
virtual DLContext Context() const = 0; virtual DGLContext Context() const = 0;
/*! /*!
* \brief Pin graph. * \brief Pin graph.
......
...@@ -89,8 +89,8 @@ class Graph: public GraphInterface { ...@@ -89,8 +89,8 @@ class Graph: public GraphInterface {
num_edges_ = 0; num_edges_ = 0;
} }
DLContext Context() const override { DGLContext Context() const override {
return DLContext{kDLCPU, 0}; return DGLContext{kDGLCPU, 0};
} }
uint8_t NumBits() const override { uint8_t NumBits() const override {
......
...@@ -137,7 +137,7 @@ class GraphInterface : public runtime::Object { ...@@ -137,7 +137,7 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Get the device context of this graph. * \brief Get the device context of this graph.
*/ */
virtual DLContext Context() const = 0; virtual DGLContext Context() const = 0;
/*! /*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64). * \brief Get the number of integer bits used to store node/edge ids (32 or 64).
......
...@@ -69,7 +69,7 @@ class CSR : public GraphInterface { ...@@ -69,7 +69,7 @@ class CSR : public GraphInterface {
LOG(FATAL) << "CSR graph does not allow mutation."; LOG(FATAL) << "CSR graph does not allow mutation.";
} }
DLContext Context() const override { DGLContext Context() const override {
return adj_.indptr->ctx; return adj_.indptr->ctx;
} }
...@@ -214,7 +214,7 @@ class CSR : public GraphInterface { ...@@ -214,7 +214,7 @@ class CSR : public GraphInterface {
* \param ctx The target context. * \param ctx The target context.
* \return The graph under another context. * \return The graph under another context.
*/ */
CSR CopyTo(const DLContext& ctx) const; CSR CopyTo(const DGLContext& ctx) const;
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
...@@ -288,7 +288,7 @@ class COO : public GraphInterface { ...@@ -288,7 +288,7 @@ class COO : public GraphInterface {
LOG(FATAL) << "COO graph does not allow mutation."; LOG(FATAL) << "COO graph does not allow mutation.";
} }
DLContext Context() const override { DGLContext Context() const override {
return adj_.row->ctx; return adj_.row->ctx;
} }
...@@ -472,7 +472,7 @@ class COO : public GraphInterface { ...@@ -472,7 +472,7 @@ class COO : public GraphInterface {
* \param ctx The target context. * \param ctx The target context.
* \return The graph under another context. * \return The graph under another context.
*/ */
COO CopyTo(const DLContext& ctx) const; COO CopyTo(const DGLContext& ctx) const;
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
...@@ -578,7 +578,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -578,7 +578,7 @@ class ImmutableGraph: public GraphInterface {
LOG(FATAL) << "Clear isn't supported in ImmutableGraph"; LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
} }
DLContext Context() const override { DGLContext Context() const override {
return AnyGraph()->Context(); return AnyGraph()->Context();
} }
...@@ -911,7 +911,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -911,7 +911,7 @@ class ImmutableGraph: public GraphInterface {
* \param ctx The target context. * \param ctx The target context.
* \return The graph under another context. * \return The graph under another context.
*/ */
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DLContext& ctx); static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext& ctx);
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
......
...@@ -145,7 +145,7 @@ class RandomEngine { ...@@ -145,7 +145,7 @@ class RandomEngine {
*/ */
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
IdArray Choice(IdxType num, FloatArray prob, bool replace = true) { IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, prob->ctx); IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);
Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace); Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
...@@ -178,9 +178,9 @@ class RandomEngine { ...@@ -178,9 +178,9 @@ class RandomEngine {
*/ */
template <typename IdxType> template <typename IdxType>
IdArray UniformChoice(IdxType num, IdxType population, bool replace = true) { IdArray UniformChoice(IdxType num, IdxType population, bool replace = true) {
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
// TODO(minjie): only CPU implementation right now // TODO(minjie): only CPU implementation right now
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0}); IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace); UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
...@@ -230,8 +230,8 @@ class RandomEngine { ...@@ -230,8 +230,8 @@ class RandomEngine {
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
IdArray BiasedChoice( IdArray BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) { IdxType num, const IdxType *split, FloatArray bias, bool replace = true) {
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0}); IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace); BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016-2022 by Contributors
* \file dgl/runtime/c_runtime_api.h * \file dgl/runtime/c_runtime_api.h
* \brief DGL runtime library. * \brief DGL runtime library.
* *
...@@ -35,10 +35,6 @@ ...@@ -35,10 +35,6 @@
// DGL version // DGL version
#define DGL_VERSION "0.9" #define DGL_VERSION "0.9"
// DGL Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
...@@ -48,28 +44,31 @@ extern "C" { ...@@ -48,28 +44,31 @@ extern "C" {
/*! \brief type of array index. */ /*! \brief type of array index. */
typedef int64_t dgl_index_t; typedef int64_t dgl_index_t;
/*! \brief Extension device types in DGL */ /*!
* \brief The device type in DGLContext.
*/
#ifdef __cplusplus
typedef enum : int32_t {
#else
typedef enum { typedef enum {
kDLAOCL = 5, #endif
kDLSDAccel = 6, /*! \brief CPU device */
kOpenGL = 11, kDGLCPU = 1,
// Extension DRAM type, used for quickly test extension device /*! \brief CUDA GPU device */
// The device api can differ depending on the xpu driver registered. kDGLCUDA = 2,
kExtDev = 12, // add more devices once supported
// AddExtraDGLType which is not in DLPack here } DGLDeviceType;
} DGLDeviceExtType;
/*! /*!
* \brief The type code in DGLType * \brief The object type code is used in DGL FFI to indicate the types of objects passed between C and Python.
* \note DGLType is used in two places.
*/ */
typedef enum { typedef enum {
// The type code of other types are compatible with DLPack. kInt = 0U,
// The next few fields are extension types kUInt = 1U,
// that is used by DGL API calls. kFloat = 2U,
kHandle = 3U, kHandle = 3U,
kNull = 4U, kNull = 4U,
kDGLType = 5U, kDGLDataType = 5U,
kDGLContext = 6U, kDGLContext = 6U,
kArrayHandle = 7U, kArrayHandle = 7U,
kObjectHandle = 8U, kObjectHandle = 8U,
...@@ -88,29 +87,112 @@ typedef enum { ...@@ -88,29 +87,112 @@ typedef enum {
// The following section of code is used for non-reserved types. // The following section of code is used for non-reserved types.
kExtReserveEnd = 64U, kExtReserveEnd = 64U,
kExtEnd = 128U kExtEnd = 128U
} DGLTypeCode; } DGLObjectTypeCode;
/*! /*!
* \brief The data type used in DGL Runtime. * \brief The type code options DGLDataType.
*/
typedef enum {
/*! \brief signed integer */
kDGLInt = 0U,
/*! \brief unsigned integer */
kDGLUInt = 1U,
/*! \brief IEEE floating point */
kDGLFloat = 2U,
/*! \brief bfloat16 */
kDGLBfloat = 4U,
// add more data types if we are going to support them
} DGLDataTypeCode;
/*!
* \brief The data type the tensor can hold. The data type is assumed to follow the
* native endian-ness. An explicit error message should be raised when attempting to
* export an array with non-native endianness
* *
* Examples * Examples
* - float: type_code = 2, bits = 32, lanes=1 * - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1 * - int8: type_code = 0, bits = 8, lanes=1
*
* \note Arguments DGL API function always takes bits=64 and lanes=1
*/ */
typedef DLDataType DGLType; typedef struct {
/*!
* \brief Type code of base types.
* We keep it uint8_t instead of DGLDataTypeCode for minimal memory
* footprint, but the value should be one of DGLDataTypeCode enum values.
* */
uint8_t code;
/*!
* \brief Number of bits, common choices are 8, 16, 32.
*/
uint8_t bits;
/*! \brief Number of lanes in the type, used for vector types. */
uint16_t lanes;
} DGLDataType;
/*! /*!
* \brief The Device information, abstract away common device types. * \brief The Device information, abstract away common device types.
*/ */
typedef DLContext DGLContext; typedef struct {
/*! \brief The device type used in the device. */
DGLDeviceType device_type;
/*!
* \brief The device index.
* For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
*/
int32_t device_id;
} DGLContext;
/*! /*!
* \brief The tensor array stucture to DGL API. * \brief The tensor array stucture to DGL API.
* The structure is heavily inspired by DLTensor from DLPack.
*/ */
typedef DLTensor DGLArray; typedef struct {
/*!
* \brief The data pointer points to the allocated data.
*
* Depending on the device context, it can be a CPU pointer, or a CUDA
* device pointer or acl_mem handle in OpenCL.
* This pointer is always aligned to 256 bytes as in CUDA. Use the
* `byte_offset` field to mark the beginning of the actual data (if the address
* is not 256 byte aligned).
*
* Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
* TVM, perhaps others) do not adhere to this 256 byte alignment requirement
* on CPU/CUDA/ROCm, and always use `byte_offset=0`. This is likely to be
* fixed in the future; at the moment it is recommended
* to not rely on the data pointer being correctly aligned.
*
* For a DGLArray, the size of memory required to store the contents of
* data can be calculated as follows:
*
* \code{.c}
* static inline size_t GetDataSize(const DGLArray* t) {
* size_t size = 1;
* for (int32_t i = 0; i < t->ndim; ++i) {
* size *= t->shape[i];
* }
* size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
* return size;
* }
* \endcode
*/
void* data;
/*! \brief The device of the tensor */
DGLContext ctx;
/*! \brief Number of dimensions */
int32_t ndim;
/*! \brief The data type of the pointer*/
DGLDataType dtype;
/*! \brief The shape of the tensor */
int64_t* shape;
/*!
* \brief strides of the tensor (in number of elements, not bytes)
* can be NULL, indicating tensor is compact and row-majored.
*/
int64_t* strides;
/*! \brief The offset in bytes to the beginning pointer to data */
uint64_t byte_offset;
} DGLArray;
/*! \brief the array handle */ /*! \brief the array handle */
typedef DGLArray* DGLArrayHandle; typedef DGLArray* DGLArrayHandle;
...@@ -124,7 +206,7 @@ typedef union { ...@@ -124,7 +206,7 @@ typedef union {
double v_float64; double v_float64;
void* v_handle; void* v_handle;
const char* v_str; const char* v_str;
DGLType v_type; DGLDataType v_type;
DGLContext v_ctx; DGLContext v_ctx;
} DGLValue; } DGLValue;
...@@ -455,32 +537,6 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle, ...@@ -455,32 +537,6 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from, DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
DGLArrayHandle to); DGLArrayHandle to);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment = 0);
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*! /*!
* \brief Create a new runtime stream. * \brief Create a new runtime stream.
* *
...@@ -557,12 +613,12 @@ DGL_DLL int DGLLoadTensorAdapter(const char *path); ...@@ -557,12 +613,12 @@ DGL_DLL int DGLLoadTensorAdapter(const char *path);
/*! /*!
* \brief Pin host memory. * \brief Pin host memory.
*/ */
int DGLArrayPinData(DGLArrayHandle handle, DLContext ctx); int DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx);
/*! /*!
* \brief Unpin host memory. * \brief Unpin host memory.
*/ */
int DGLArrayUnpinData(DGLArrayHandle handle, DLContext ctx); int DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx);
/*! /*!
* \brief Record the stream that's using this tensor. * \brief Record the stream that's using this tensor.
......
...@@ -75,7 +75,7 @@ class DeviceAPI { ...@@ -75,7 +75,7 @@ class DeviceAPI {
virtual void* AllocDataSpace(DGLContext ctx, virtual void* AllocDataSpace(DGLContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
DGLType type_hint) = 0; DGLDataType type_hint) = 0;
/*! /*!
* \brief Free a data space on device. * \brief Free a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
...@@ -101,7 +101,7 @@ class DeviceAPI { ...@@ -101,7 +101,7 @@ class DeviceAPI {
size_t num_bytes, size_t num_bytes,
DGLContext ctx_from, DGLContext ctx_from,
DGLContext ctx_to, DGLContext ctx_to,
DGLType type_hint) = 0; DGLDataType type_hint) = 0;
/*! /*!
* \brief Create a new stream of execution. * \brief Create a new stream of execution.
* *
...@@ -189,7 +189,7 @@ class DeviceAPI { ...@@ -189,7 +189,7 @@ class DeviceAPI {
*/ */
DGL_DLL virtual void* AllocWorkspace(DGLContext ctx, DGL_DLL virtual void* AllocWorkspace(DGLContext ctx,
size_t nbytes, size_t nbytes,
DGLType type_hint = {}); DGLDataType type_hint = {});
/*! /*!
* \brief Free temporal workspace in backend execution. * \brief Free temporal workspace in backend execution.
* *
...@@ -213,7 +213,7 @@ class DeviceAPI { ...@@ -213,7 +213,7 @@ class DeviceAPI {
* \param allow_missing Whether allow missing * \param allow_missing Whether allow missing
* \return The corresponding device API. * \return The corresponding device API.
*/ */
DGL_DLL static DeviceAPI* Get(DLDeviceType dev_type, bool allow_missing = false); DGL_DLL static DeviceAPI* Get(DGLDeviceType dev_type, bool allow_missing = false);
}; };
/*! \brief The device type bigger than this is RPC device */ /*! \brief The device type bigger than this is RPC device */
......
/*!
* Copyright (c) 2022 by Contributors
* \file include/dgl/runtime/dlpack_convert.h
* \brief Conversion between NDArray and DLPack.
*/
#ifndef DGL_RUNTIME_DLPACK_CONVERT_H_
#define DGL_RUNTIME_DLPACK_CONVERT_H_
#include "c_runtime_api.h"
#include "ndarray.h"
struct DLManagedTensor;
namespace dgl {
namespace runtime {
struct DLPackConvert {
/*!
* \brief Create a DGL NDArray from a DLPack tensor.
*
* This allows us to create a NDArray using the memory
* allocated by an external deep learning framework
* that is DLPack compatible.
*
* The memory is retained until the NDArray went out of scope.
* \param tensor The DLPack tensor to copy from.
* \return The created NDArray view.
*/
static NDArray FromDLPack(DLManagedTensor* tensor);
/*!
* \brief Deleter for NDArray converted from DLPack.
*
* This is used from data which is passed from external DLPack(DLManagedTensor)
* that are not allocated inside of DGL.
* This enables us to create NDArray from memory allocated by other
* frameworks that are DLPack compatible
*/
static void DLPackDeleter(NDArray::Container* ptr);
/*! \brief Convert a DGL NDArray to a DLPack tensor.
*
* \param from The DGL NDArray.
* \return A DLPack tensor.
*/
static DLManagedTensor* ToDLPack(const NDArray &from);
};
} // namespace runtime
} // namespace dgl
#ifdef __cplusplus
extern "C" {
#endif
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment = 0);
#ifdef __cplusplus
} // DGL_EXTERN_C
#endif
#endif // DGL_RUNTIME_DLPACK_CONVERT_H_
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017-2022 by Contributors
* \file dgl/runtime/ndarray.h * \file dgl/runtime/ndarray.h
* \brief Abstract device memory management API * \brief Abstract device memory management API
*/ */
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include <memory> #include <memory>
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "dlpack/dlpack.h"
#include "serializer.h" #include "serializer.h"
#include "shared_mem.h" #include "shared_mem.h"
...@@ -23,44 +22,49 @@ ...@@ -23,44 +22,49 @@
#endif #endif
// forward declaration // forward declaration
inline std::ostream& operator << (std::ostream& os, DGLType t); inline std::ostream& operator << (std::ostream& os, DGLDataType t);
namespace dgl { namespace dgl {
/*! /*!
* \brief Type traits that converts a C type to a DLDataType. * \brief Type traits that converts a C type to a DGLDataType.
* *
* Usage: * Usage:
* DLDataTypeTraits<int>::dtype == dtype * DGLDataTypeTraits<int>::dtype == dtype
*/ */
template<typename T> template<typename T>
struct DLDataTypeTraits { struct DGLDataTypeTraits {
static constexpr DLDataType dtype{0, 0, 0}; // dummy static constexpr DGLDataType dtype{0, 0, 0}; // dummy
}; };
#define GEN_DLDATATYPETRAITS_FOR(T, code, bits) \ #define GEN_DGLDATATYPETRAITS_FOR(T, code, bits) \
template<> \ template<> \
struct DLDataTypeTraits<T> { \ struct DGLDataTypeTraits<T> { \
static constexpr DLDataType dtype{code, bits, 1}; \ static constexpr DGLDataType dtype{code, bits, 1}; \
} }
GEN_DLDATATYPETRAITS_FOR(int8_t, kDLInt, 8); GEN_DGLDATATYPETRAITS_FOR(int8_t, kDGLInt, 8);
GEN_DLDATATYPETRAITS_FOR(int16_t, kDLInt, 16); GEN_DGLDATATYPETRAITS_FOR(int16_t, kDGLInt, 16);
GEN_DLDATATYPETRAITS_FOR(int32_t, kDLInt, 32); GEN_DGLDATATYPETRAITS_FOR(int32_t, kDGLInt, 32);
GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64); GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just // XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
// converting uints to signed DTypes. // converting uints to signed DTypes.
GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32); GEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);
GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64); GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
#ifdef USE_FP16 #ifdef USE_FP16
GEN_DLDATATYPETRAITS_FOR(__half, kDLFloat, 16); GEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16);
#endif #endif
#endif #endif
GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32); GEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 32);
GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64); GEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64);
#undef GEN_DLDATATYPETRAITS_FOR #undef GEN_DGLDATATYPETRAITS_FOR
namespace runtime { namespace runtime {
/*!
* \brief DLPack converter.
*/
struct DLPackConvert;
/*! /*!
* \brief Managed NDArray. * \brief Managed NDArray.
* The array is backed by reference counted blocks. * The array is backed by reference counted blocks.
...@@ -135,8 +139,8 @@ class NDArray { ...@@ -135,8 +139,8 @@ class NDArray {
* \note this number is approximate in multi-threaded setting. * \note this number is approximate in multi-threaded setting.
*/ */
inline int use_count() const; inline int use_count() const;
/*! \return Pointer to content of DLTensor */ /*! \return Pointer to content of DGLArray */
inline const DLTensor* operator->() const; inline const DGLArray* operator->() const;
/*! \return True if the ndarray is contiguous. */ /*! \return True if the ndarray is contiguous. */
bool IsContiguous() const; bool IsContiguous() const;
/*! \return the data pointer with type. */ /*! \return the data pointer with type. */
...@@ -152,9 +156,9 @@ class NDArray { ...@@ -152,9 +156,9 @@ class NDArray {
* \param other The source array to be copied from. * \param other The source array to be copied from.
* \note The copy runs on the dgl internal stream if it involves a GPU context. * \note The copy runs on the dgl internal stream if it involves a GPU context.
*/ */
inline void CopyFrom(DLTensor* other); inline void CopyFrom(DGLArray* other);
inline void CopyFrom(const NDArray& other); inline void CopyFrom(const NDArray& other);
inline void CopyTo(DLTensor *other) const; inline void CopyTo(DGLArray *other) const;
inline void CopyTo(const NDArray &other) const; inline void CopyTo(const NDArray &other) const;
/*! /*!
...@@ -162,7 +166,7 @@ class NDArray { ...@@ -162,7 +166,7 @@ class NDArray {
* \param ctx The target context. * \param ctx The target context.
* \return The array under another context. * \return The array under another context.
*/ */
inline NDArray CopyTo(const DLContext &ctx) const; inline NDArray CopyTo(const DGLContext &ctx) const;
/*! /*!
* \brief Return a new array with a copy of the content. * \brief Return a new array with a copy of the content.
*/ */
...@@ -171,9 +175,9 @@ class NDArray { ...@@ -171,9 +175,9 @@ class NDArray {
* \brief In-place method to pin the current array by calling PinContainer * \brief In-place method to pin the current array by calling PinContainer
* on the underlying NDArray:Container. * on the underlying NDArray:Container.
* \note This is an in-place method. Behavior depends on the current context, * \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
*/ */
inline void PinMemory_(); inline void PinMemory_();
/*! /*!
...@@ -212,13 +216,7 @@ class NDArray { ...@@ -212,13 +216,7 @@ class NDArray {
* \note The memory size of new array must be smaller than the current one. * \note The memory size of new array must be smaller than the current one.
*/ */
DGL_DLL NDArray CreateView( DGL_DLL NDArray CreateView(
std::vector<int64_t> shape, DLDataType dtype, int64_t offset = 0); std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);
/*!
* \brief Create a reference view of NDArray that
* represents as DLManagedTensor.
* \return A DLManagedTensor
*/
DGL_DLL DLManagedTensor* ToDLPack() const;
/*! /*!
* \brief Create an empty NDArray. * \brief Create an empty NDArray.
* \param shape The shape of the new array. * \param shape The shape of the new array.
...@@ -227,8 +225,8 @@ class NDArray { ...@@ -227,8 +225,8 @@ class NDArray {
* \return The created Array * \return The created Array
*/ */
DGL_DLL static NDArray Empty(std::vector<int64_t> shape, DGL_DLL static NDArray Empty(std::vector<int64_t> shape,
DLDataType dtype, DGLDataType dtype,
DLContext ctx); DGLContext ctx);
/*! /*!
* \brief Create an empty NDArray with shared memory. * \brief Create an empty NDArray with shared memory.
* \param name The name of shared memory. * \param name The name of shared memory.
...@@ -240,8 +238,8 @@ class NDArray { ...@@ -240,8 +238,8 @@ class NDArray {
*/ */
DGL_DLL static NDArray EmptyShared(const std::string &name, DGL_DLL static NDArray EmptyShared(const std::string &name,
std::vector<int64_t> shape, std::vector<int64_t> shape,
DLDataType dtype, DGLDataType dtype,
DLContext ctx, DGLContext ctx,
bool is_create); bool is_create);
/*! /*!
* \brief Get the size of the array in the number of bytes. * \brief Get the size of the array in the number of bytes.
...@@ -253,26 +251,19 @@ class NDArray { ...@@ -253,26 +251,19 @@ class NDArray {
*/ */
int64_t NumElements() const; int64_t NumElements() const;
/*!
* \brief Create a NDArray backed by a dlpack tensor.
*
* This allows us to create a NDArray using the memory
* allocated by an external deep learning framework
* that is DLPack compatible.
*
* The memory is retained until the NDArray went out of scope.
* \param tensor The DLPack tensor to copy from.
* \return The created NDArray view.
*/
DGL_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
/*! /*!
* \brief Create a NDArray by copying from std::vector. * \brief Create a NDArray by copying from std::vector.
* \tparam T Type of vector data. Determines the dtype of returned array. * \tparam T Type of vector data. Determines the dtype of returned array.
*/ */
template<typename T> template<typename T>
DGL_DLL static NDArray FromVector( DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0}); const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0});
/*!
* \brief Create a NDArray from a raw pointer.
*/
DGL_DLL static NDArray CreateFromRaw(const std::vector<int64_t>& shape,
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free);
/*! /*!
* \brief Create a std::vector from a 1D NDArray. * \brief Create a std::vector from a 1D NDArray.
...@@ -292,23 +283,23 @@ class NDArray { ...@@ -292,23 +283,23 @@ class NDArray {
* \param (optional) stream The stream used in copy. * \param (optional) stream The stream used in copy.
*/ */
DGL_DLL static void CopyFromTo( DGL_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to); DGLArray* from, DGLArray* to);
DGL_DLL static void CopyFromTo( DGL_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, DGLStreamHandle stream); DGLArray* from, DGLArray* to, DGLStreamHandle stream);
/*! /*!
* \brief Function to pin the DLTensor of a Container. * \brief Function to pin the DGLArray of a Container.
* \param ptr The container to be pinned. * \param ptr The container to be pinned.
* \note Data of the given array will be pinned inplace. * \note Data of the given array will be pinned inplace.
* Behavior depends on the current context, * Behavior depends on the current context,
* kDLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDLGPU: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
*/ */
DGL_DLL static void PinContainer(Container* ptr); DGL_DLL static void PinContainer(Container* ptr);
/*! /*!
* \brief Function to unpin the DLTensor of a Container. * \brief Function to unpin the DGLArray of a Container.
* \param ptr The container to be unpinned. * \param ptr The container to be unpinned.
* \note Data of the given array will be unpinned inplace. * \note Data of the given array will be unpinned inplace.
* Behavior depends on the current context, * Behavior depends on the current context,
...@@ -318,7 +309,7 @@ class NDArray { ...@@ -318,7 +309,7 @@ class NDArray {
DGL_DLL static void UnpinContainer(Container* ptr); DGL_DLL static void UnpinContainer(Container* ptr);
/*! /*!
* \brief Function check if the DLTensor of a Container is pinned. * \brief Function check if the DGLArray of a Container is pinned.
* \param ptr The container to be checked. * \param ptr The container to be checked.
* \return true if pinned. * \return true if pinned.
*/ */
...@@ -332,45 +323,57 @@ class NDArray { ...@@ -332,45 +323,57 @@ class NDArray {
DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream); DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream);
// internal namespace // internal namespace
struct Internal; struct Internal {
// Default deleter for the container
static void DefaultDeleter(NDArray::Container* ptr);
// Local create function which allocates tensor metadata
// but does not allocate space for the data.
static NDArray Create(std::vector<int64_t> shape,
DGLDataType dtype, DGLContext ctx);
// Implementation of API function
static DGLArray* MoveAsDGLArray(NDArray arr);
};
private: private:
/*! \brief Internal Data content */ /*! \brief Internal Data content */
Container* data_{nullptr}; Container* data_{nullptr};
// enable internal functions // enable internal functions
friend struct Internal; friend struct Internal;
friend struct DLPackConvert;
friend class DGLRetValue; friend class DGLRetValue;
friend class DGLArgsSetter; friend class DGLArgsSetter;
}; };
/*! /*!
* \brief Save a DLTensor to stream * \brief Save a DGLArray to stream
* \param strm The outpu stream * \param strm The outpu stream
* \param tensor The tensor to be saved. * \param tensor The tensor to be saved.
*/ */
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); inline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor);
/*! /*!
* \brief Reference counted Container object used to back NDArray. * \brief Reference counted Container object used to back NDArray.
* *
* This object is DLTensor compatible: * This object is DGLArray compatible:
* the pointer to the NDArrayContainer can be directly * the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor* * interpreted as a DGLArray*
* *
* \note: do not use this function directly, use NDArray. * \note: do not use this function directly, use NDArray.
*/ */
struct NDArray::Container { struct NDArray::Container {
public: public:
// NOTE: the first part of this structure is the same as /*! NOTE: the first part of this structure is the same as
// DLManagedTensor, note that, however, the deleter * DLManagedTensor, note that, however, the deleter
// is only called when the reference counter goes to 0 * is only called when the reference counter goes to 0
*/
/*! /*!
* \brief The corresponding dl_tensor field. * \brief Tensor structure.
* \note it is important that the first field is DLTensor * \note it is important that the first field is DGLArray
* So that this data structure is DLTensor compatible. * So that this data structure is DGLArray compatible.
* The head ptr of this struct can be viewed as DLTensor*. * The head ptr of this struct can be viewed as DGLArray*.
*/ */
DLTensor dl_tensor; DGLArray dl_tensor;
/*! /*!
* \brief addtional context, reserved for recycling * \brief addtional context, reserved for recycling
* \note We can attach additional content here * \note We can attach additional content here
...@@ -411,6 +414,7 @@ struct NDArray::Container { ...@@ -411,6 +414,7 @@ struct NDArray::Container {
} }
private: private:
friend struct DLPackConvert;
friend class NDArray; friend class NDArray;
friend class RPCWrappedFunc; friend class RPCWrappedFunc;
/*! /*!
...@@ -450,7 +454,7 @@ inline void NDArray::reset() { ...@@ -450,7 +454,7 @@ inline void NDArray::reset() {
} }
} }
inline void NDArray::CopyFrom(DLTensor* other) { inline void NDArray::CopyFrom(DGLArray* other) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CopyFromTo(other, &(data_->dl_tensor)); CopyFromTo(other, &(data_->dl_tensor));
} }
...@@ -460,7 +464,7 @@ inline void NDArray::CopyFrom(const NDArray& other) { ...@@ -460,7 +464,7 @@ inline void NDArray::CopyFrom(const NDArray& other) {
CopyFrom(&(other.data_->dl_tensor)); CopyFrom(&(other.data_->dl_tensor));
} }
inline void NDArray::CopyTo(DLTensor *other) const { inline void NDArray::CopyTo(DGLArray *other) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other); CopyFromTo(&(data_->dl_tensor), other);
} }
...@@ -470,9 +474,9 @@ inline void NDArray::CopyTo(const NDArray &other) const { ...@@ -470,9 +474,9 @@ inline void NDArray::CopyTo(const NDArray &other) const {
CopyTo(&(other.data_->dl_tensor)); CopyTo(&(other.data_->dl_tensor));
} }
inline NDArray NDArray::CopyTo(const DLContext &ctx) const { inline NDArray NDArray::CopyTo(const DGLContext &ctx) const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
const DLTensor* dptr = operator->(); const DGLArray* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx); dptr->dtype, ctx);
this->CopyTo(ret); this->CopyTo(ret);
...@@ -481,7 +485,7 @@ inline NDArray NDArray::CopyTo(const DLContext &ctx) const { ...@@ -481,7 +485,7 @@ inline NDArray NDArray::CopyTo(const DLContext &ctx) const {
inline NDArray NDArray::Clone() const { inline NDArray NDArray::Clone() const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
const DLTensor* dptr = operator->(); const DGLArray* dptr = operator->();
return this->CopyTo(dptr->ctx); return this->CopyTo(dptr->ctx);
} }
...@@ -510,15 +514,15 @@ inline int NDArray::use_count() const { ...@@ -510,15 +514,15 @@ inline int NDArray::use_count() const {
return data_->ref_counter_.load(std::memory_order_relaxed); return data_->ref_counter_.load(std::memory_order_relaxed);
} }
inline const DLTensor* NDArray::operator->() const { inline const DGLArray* NDArray::operator->() const {
return &(data_->dl_tensor); return &(data_->dl_tensor);
} }
/*! \brief Magic number for NDArray file */ /*! \brief Magic number for NDArray file */
constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F; constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDLTensor(dmlc::Stream* strm, inline bool SaveDGLArray(dmlc::Stream* strm,
DLTensor* tensor) { DGLArray* tensor) {
uint64_t header = kDGLNDArrayMagic, reserved = 0; uint64_t header = kDGLNDArrayMagic, reserved = 0;
strm->Write(header); strm->Write(header);
strm->Write(reserved); strm->Write(reserved);
...@@ -531,8 +535,8 @@ inline bool SaveDLTensor(dmlc::Stream* strm, ...@@ -531,8 +535,8 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
// //
// We can always do array.CopyTo(target_ctx) to get a corresponding // We can always do array.CopyTo(target_ctx) to get a corresponding
// array in the target context. // array in the target context.
DLContext cpu_ctx; DGLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU; cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
strm->Write(cpu_ctx); strm->Write(cpu_ctx);
strm->Write(tensor->ndim); strm->Write(tensor->ndim);
...@@ -548,7 +552,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, ...@@ -548,7 +552,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
strm->Write(data_byte_size); strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP && if (DMLC_IO_NO_ENDIAN_SWAP &&
tensor->ctx.device_type == kDLCPU && tensor->ctx.device_type == kDGLCPU &&
tensor->strides == nullptr && tensor->strides == nullptr &&
tensor->byte_offset == 0) { tensor->byte_offset == 0) {
// quick path // quick path
...@@ -573,16 +577,16 @@ inline bool SaveDLTensor(dmlc::Stream* strm, ...@@ -573,16 +577,16 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
*/ */
inline const char* TypeCode2Str(int type_code) { inline const char* TypeCode2Str(int type_code) {
switch (type_code) { switch (type_code) {
case kDLInt: return "int"; case kDGLInt: return "int";
case kDLUInt: return "uint"; case kDGLUInt: return "uint";
case kDLFloat: return "float"; case kDGLFloat: return "float";
case kStr: return "str"; case kStr: return "str";
case kBytes: return "bytes"; case kBytes: return "bytes";
case kHandle: return "handle"; case kHandle: return "handle";
case kNull: return "NULL"; case kNull: return "NULL";
case kObjectHandle: return "ObjectHandle"; case kObjectHandle: return "ObjectHandle";
case kArrayHandle: return "ArrayHandle"; case kArrayHandle: return "ArrayHandle";
case kDGLType: return "DGLType"; case kDGLDataType: return "DGLDataType";
case kDGLContext: return "DGLContext"; case kDGLContext: return "DGLContext";
case kFuncHandle: return "FunctionHandle"; case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle"; case kModuleHandle: return "ModuleHandle";
...@@ -597,17 +601,11 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -597,17 +601,11 @@ inline const char* TypeCode2Str(int type_code) {
* \param device_type The device type code. * \param device_type The device type code.
* \return The name of the device. * \return The name of the device.
*/ */
inline const char* DeviceTypeCode2Str(DLDeviceType device_type) { inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
switch (device_type) { switch (device_type) {
case kDLCPU: return "cpu"; case kDGLCPU: return "cpu";
case kDLGPU: return "cuda"; case kDGLCUDA: return "cuda";
case kDLCPUPinned: return "cpu_pinned"; default: LOG(FATAL) << "Unsupported device type code="
case kDLOpenCL: return "opencl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
default: LOG(FATAL) << "Unknown device type code="
<< static_cast<int>(device_type); return ""; << static_cast<int>(device_type); return "";
} }
} }
...@@ -617,16 +615,16 @@ inline const char* DeviceTypeCode2Str(DLDeviceType device_type) { ...@@ -617,16 +615,16 @@ inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
* \param s The string to be converted. * \param s The string to be converted.
* \return The corresponding dgl type. * \return The corresponding dgl type.
*/ */
inline DGLType String2DGLType(std::string s) { inline DGLDataType String2DGLDataType(std::string s) {
DGLType t; DGLDataType t;
t.bits = 32; t.lanes = 1; t.bits = 32; t.lanes = 1;
const char* scan; const char* scan;
if (s.substr(0, 3) == "int") { if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3; t.code = kDGLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") { } else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4; t.code = kDGLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") { } else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5; t.code = kDGLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") { } else if (s.substr(0, 6) == "handle") {
t.code = kHandle; t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default. t.bits = 64; // handle uses 64 bit by default.
...@@ -649,7 +647,7 @@ inline DGLType String2DGLType(std::string s) { ...@@ -649,7 +647,7 @@ inline DGLType String2DGLType(std::string s) {
* \param t The type to be converted. * \param t The type to be converted.
* \return The corresponding dgl type in string. * \return The corresponding dgl type in string.
*/ */
inline std::string DGLType2String(DGLType t) { inline std::string DGLDataType2String(DGLDataType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS #ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os; std::ostringstream os;
os << t; os << t;
...@@ -728,20 +726,20 @@ dgl::runtime::NDArray operator != (int64_t lhs, const dgl::runtime::NDArray& a2) ...@@ -728,20 +726,20 @@ dgl::runtime::NDArray operator != (int64_t lhs, const dgl::runtime::NDArray& a2)
std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array); std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array);
///////////////// Operator overloading for DLDataType ///////////////// ///////////////// Operator overloading for DGLDataType /////////////////
/*! \brief Check whether two data types are the same.*/ /*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) { inline bool operator == (const DGLDataType& ty1, const DGLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes; return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
} }
/*! \brief Check whether two data types are different.*/ /*! \brief Check whether two data types are different.*/
inline bool operator != (const DLDataType& ty1, const DLDataType& ty2) { inline bool operator != (const DGLDataType& ty1, const DGLDataType& ty2) {
return !(ty1 == ty2); return !(ty1 == ty2);
} }
#ifndef _LIBCPP_SGX_NO_IOSTREAMS #ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator << (std::ostream& os, DGLType t) { inline std::ostream& operator << (std::ostream& os, DGLDataType t) {
os << dgl::runtime::TypeCode2Str(t.code); os << dgl::runtime::TypeCode2Str(t.code);
if (t.code == kHandle) return os; if (t.code == kHandle) return os;
os << static_cast<int>(t.bits); os << static_cast<int>(t.bits);
...@@ -752,20 +750,20 @@ inline std::ostream& operator << (std::ostream& os, DGLType t) { ...@@ -752,20 +750,20 @@ inline std::ostream& operator << (std::ostream& os, DGLType t) {
} }
#endif #endif
///////////////// Operator overloading for DLContext ///////////////// ///////////////// Operator overloading for DGLContext /////////////////
/*! \brief Check whether two device contexts are the same.*/ /*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) { inline bool operator == (const DGLContext& ctx1, const DGLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id; return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
} }
/*! \brief Check whether two device contexts are different.*/ /*! \brief Check whether two device contexts are different.*/
inline bool operator != (const DLContext& ctx1, const DLContext& ctx2) { inline bool operator != (const DGLContext& ctx1, const DGLContext& ctx2) {
return !(ctx1 == ctx2); return !(ctx1 == ctx2);
} }
#ifndef _LIBCPP_SGX_NO_IOSTREAMS #ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) { inline std::ostream& operator << (std::ostream& os, const DGLContext& ctx) {
return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << ":" << ctx.device_id; return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << ":" << ctx.device_id;
} }
#endif #endif
......
...@@ -350,28 +350,28 @@ class DGLPODValue_ { ...@@ -350,28 +350,28 @@ class DGLPODValue_ {
// Allow automatic conversion from int to float // Allow automatic conversion from int to float
// This avoids errors when user pass in int from // This avoids errors when user pass in int from
// the frontend while the API expects a float. // the frontend while the API expects a float.
if (type_code_ == kDLInt) { if (type_code_ == kDGLInt) {
return static_cast<double>(value_.v_int64); return static_cast<double>(value_.v_int64);
} }
DGL_CHECK_TYPE_CODE(type_code_, kDLFloat); DGL_CHECK_TYPE_CODE(type_code_, kDGLFloat);
return value_.v_float64; return value_.v_float64;
} }
operator int64_t() const { operator int64_t() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
return value_.v_int64; return value_.v_int64;
} }
operator uint64_t() const { operator uint64_t() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
return value_.v_int64; return value_.v_int64;
} }
operator int() const { operator int() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
CHECK_LE(value_.v_int64, CHECK_LE(value_.v_int64,
std::numeric_limits<int>::max()); std::numeric_limits<int>::max());
return static_cast<int>(value_.v_int64); return static_cast<int>(value_.v_int64);
} }
operator bool() const { operator bool() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt); DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
return value_.v_int64 != 0; return value_.v_int64 != 0;
} }
operator void*() const { operator void*() const {
...@@ -380,14 +380,14 @@ class DGLPODValue_ { ...@@ -380,14 +380,14 @@ class DGLPODValue_ {
DGL_CHECK_TYPE_CODE(type_code_, kHandle); DGL_CHECK_TYPE_CODE(type_code_, kHandle);
return value_.v_handle; return value_.v_handle;
} }
operator DLTensor*() const { operator DGLArray*() const {
if (type_code_ == kArrayHandle || if (type_code_ == kArrayHandle ||
type_code_ == kNDArrayContainer) { type_code_ == kNDArrayContainer) {
return static_cast<DLTensor*>(value_.v_handle); return static_cast<DGLArray*>(value_.v_handle);
} else { } else {
if (type_code_ == kNull) return nullptr; if (type_code_ == kNull) return nullptr;
LOG(FATAL) << "Expected " LOG(FATAL) << "Expected "
<< "DLTensor* or NDArray but get " << "DGLArray* or NDArray but get "
<< TypeCode2Str(type_code_); << TypeCode2Str(type_code_);
return nullptr; return nullptr;
} }
...@@ -457,14 +457,14 @@ class DGLArgValue : public DGLPODValue_ { ...@@ -457,14 +457,14 @@ class DGLArgValue : public DGLPODValue_ {
using DGLPODValue_::operator int; using DGLPODValue_::operator int;
using DGLPODValue_::operator bool; using DGLPODValue_::operator bool;
using DGLPODValue_::operator void*; using DGLPODValue_::operator void*;
using DGLPODValue_::operator DLTensor*; using DGLPODValue_::operator DGLArray*;
using DGLPODValue_::operator NDArray; using DGLPODValue_::operator NDArray;
using DGLPODValue_::operator DGLContext; using DGLPODValue_::operator DGLContext;
// conversion operator. // conversion operator.
operator std::string() const { operator std::string() const {
if (type_code_ == kDGLType) { if (type_code_ == kDGLDataType) {
return DGLType2String(operator DGLType()); return DGLDataType2String(operator DGLDataType());
} else if (type_code_ == kBytes) { } else if (type_code_ == kBytes) {
DGLByteArray* arr = static_cast<DGLByteArray*>(value_.v_handle); DGLByteArray* arr = static_cast<DGLByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size); return std::string(arr->data, arr->size);
...@@ -473,11 +473,11 @@ class DGLArgValue : public DGLPODValue_ { ...@@ -473,11 +473,11 @@ class DGLArgValue : public DGLPODValue_ {
return std::string(value_.v_str); return std::string(value_.v_str);
} }
} }
operator DGLType() const { operator DGLDataType() const {
if (type_code_ == kStr) { if (type_code_ == kStr) {
return String2DGLType(operator std::string()); return String2DGLDataType(operator std::string());
} }
DGL_CHECK_TYPE_CODE(type_code_, kDGLType); DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);
return value_.v_type; return value_.v_type;
} }
operator PackedFunc() const { operator PackedFunc() const {
...@@ -549,7 +549,7 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -549,7 +549,7 @@ class DGLRetValue : public DGLPODValue_ {
using DGLPODValue_::operator int; using DGLPODValue_::operator int;
using DGLPODValue_::operator bool; using DGLPODValue_::operator bool;
using DGLPODValue_::operator void*; using DGLPODValue_::operator void*;
using DGLPODValue_::operator DLTensor*; using DGLPODValue_::operator DGLArray*;
using DGLPODValue_::operator DGLContext; using DGLPODValue_::operator DGLContext;
using DGLPODValue_::operator NDArray; using DGLPODValue_::operator NDArray;
// Disable copy and assign from another value, but allow move. // Disable copy and assign from another value, but allow move.
...@@ -558,19 +558,19 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -558,19 +558,19 @@ class DGLRetValue : public DGLPODValue_ {
} }
// conversion operators // conversion operators
operator std::string() const { operator std::string() const {
if (type_code_ == kDGLType) { if (type_code_ == kDGLDataType) {
return DGLType2String(operator DGLType()); return DGLDataType2String(operator DGLDataType());
} else if (type_code_ == kBytes) { } else if (type_code_ == kBytes) {
return *ptr<std::string>(); return *ptr<std::string>();
} }
DGL_CHECK_TYPE_CODE(type_code_, kStr); DGL_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>(); return *ptr<std::string>();
} }
operator DGLType() const { operator DGLDataType() const {
if (type_code_ == kStr) { if (type_code_ == kStr) {
return String2DGLType(operator std::string()); return String2DGLDataType(operator std::string());
} }
DGL_CHECK_TYPE_CODE(type_code_, kDGLType); DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);
return value_.v_type; return value_.v_type;
} }
operator PackedFunc() const { operator PackedFunc() const {
...@@ -595,7 +595,7 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -595,7 +595,7 @@ class DGLRetValue : public DGLPODValue_ {
return *this; return *this;
} }
DGLRetValue& operator=(double value) { DGLRetValue& operator=(double value) {
this->SwitchToPOD(kDLFloat); this->SwitchToPOD(kDGLFloat);
value_.v_float64 = value; value_.v_float64 = value;
return *this; return *this;
} }
...@@ -610,17 +610,17 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -610,17 +610,17 @@ class DGLRetValue : public DGLPODValue_ {
return *this; return *this;
} }
DGLRetValue& operator=(int64_t value) { DGLRetValue& operator=(int64_t value) {
this->SwitchToPOD(kDLInt); this->SwitchToPOD(kDGLInt);
value_.v_int64 = value; value_.v_int64 = value;
return *this; return *this;
} }
DGLRetValue& operator=(int value) { DGLRetValue& operator=(int value) {
this->SwitchToPOD(kDLInt); this->SwitchToPOD(kDGLInt);
value_.v_int64 = value; value_.v_int64 = value;
return *this; return *this;
} }
DGLRetValue& operator=(DGLType t) { DGLRetValue& operator=(DGLDataType t) {
this->SwitchToPOD(kDGLType); this->SwitchToPOD(kDGLDataType);
value_.v_type = t; value_.v_type = t;
return *this; return *this;
} }
...@@ -630,7 +630,7 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -630,7 +630,7 @@ class DGLRetValue : public DGLPODValue_ {
return *this; return *this;
} }
DGLRetValue& operator=(bool value) { DGLRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt); this->SwitchToPOD(kDGLInt);
value_.v_int64 = value; value_.v_int64 = value;
return *this; return *this;
} }
...@@ -859,17 +859,17 @@ class DGLArgsSetter { ...@@ -859,17 +859,17 @@ class DGLArgsSetter {
std::is_integral<T>::value>::type> std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const { void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value); values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kDLInt; type_codes_[i] = kDGLInt;
} }
void operator()(size_t i, uint64_t value) const { void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value); values_[i].v_int64 = static_cast<int64_t>(value);
CHECK_LE(value, CHECK_LE(value,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max())); static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
type_codes_[i] = kDLInt; type_codes_[i] = kDGLInt;
} }
void operator()(size_t i, double value) const { void operator()(size_t i, double value) const {
values_[i].v_float64 = value; values_[i].v_float64 = value;
type_codes_[i] = kDLFloat; type_codes_[i] = kDGLFloat;
} }
void operator()(size_t i, std::nullptr_t value) const { void operator()(size_t i, std::nullptr_t value) const {
values_[i].v_handle = value; values_[i].v_handle = value;
...@@ -883,7 +883,7 @@ class DGLArgsSetter { ...@@ -883,7 +883,7 @@ class DGLArgsSetter {
values_[i].v_handle = value; values_[i].v_handle = value;
type_codes_[i] = kHandle; type_codes_[i] = kHandle;
} }
void operator()(size_t i, DLTensor* value) const { void operator()(size_t i, DGLArray* value) const {
values_[i].v_handle = value; values_[i].v_handle = value;
type_codes_[i] = kArrayHandle; type_codes_[i] = kArrayHandle;
} }
...@@ -891,9 +891,9 @@ class DGLArgsSetter { ...@@ -891,9 +891,9 @@ class DGLArgsSetter {
values_[i].v_ctx = value; values_[i].v_ctx = value;
type_codes_[i] = kDGLContext; type_codes_[i] = kDGLContext;
} }
void operator()(size_t i, DGLType value) const { void operator()(size_t i, DGLDataType value) const {
values_[i].v_type = value; values_[i].v_type = value;
type_codes_[i] = kDGLType; type_codes_[i] = kDGLDataType;
} }
void operator()(size_t i, const char* value) const { void operator()(size_t i, const char* value) const {
values_[i].v_str = value; values_[i].v_str = value;
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h * \file dgl/runtime/serializer.h
* \brief Serializer extension to support DGL data types * \brief Serializer extension to support DGL data types
* Include this file to enable serialization of DLDataType, DLContext * Include this file to enable serialization of DGLDataType, DGLContext
*/ */
#ifndef DGL_RUNTIME_SERIALIZER_H_ #ifndef DGL_RUNTIME_SERIALIZER_H_
#define DGL_RUNTIME_SERIALIZER_H_ #define DGL_RUNTIME_SERIALIZER_H_
...@@ -16,13 +16,13 @@ namespace dmlc { ...@@ -16,13 +16,13 @@ namespace dmlc {
namespace serializer { namespace serializer {
template <> template <>
struct Handler<DLDataType> { struct Handler<DGLDataType> {
inline static void Write(Stream *strm, const DLDataType &dtype) { inline static void Write(Stream *strm, const DGLDataType &dtype) {
Handler<uint8_t>::Write(strm, dtype.code); Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits); Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes); Handler<uint16_t>::Write(strm, dtype.lanes);
} }
inline static bool Read(Stream *strm, DLDataType *dtype) { inline static bool Read(Stream *strm, DGLDataType *dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false; if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false; if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false; if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
...@@ -31,16 +31,16 @@ struct Handler<DLDataType> { ...@@ -31,16 +31,16 @@ struct Handler<DLDataType> {
}; };
template <> template <>
struct Handler<DLContext> { struct Handler<DGLContext> {
inline static void Write(Stream *strm, const DLContext &ctx) { inline static void Write(Stream *strm, const DGLContext &ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type); int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type); Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id); Handler<int32_t>::Write(strm, ctx.device_id);
} }
inline static bool Read(Stream *strm, DLContext *ctx) { inline static bool Read(Stream *strm, DGLContext *ctx) {
int32_t device_type = 0; int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false; if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type); ctx->device_type = static_cast<DGLDeviceType>(device_type);
if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false; if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false;
return true; return true;
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h * \file dgl/runtime/serializer.h
* \brief Serializer extension to support DGL data types * \brief Serializer extension to support DGL data types
* Include this file to enable serialization of DLDataType, DLContext * Include this file to enable serialization of DGLDataType, DGLContext
*/ */
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ #ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ #define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
......
...@@ -18,7 +18,7 @@ namespace runtime { ...@@ -18,7 +18,7 @@ namespace runtime {
* \param bits The number of bits to be matched. * \param bits The number of bits to be matched.
* \param lanes The number of lanes sin the type. * \param lanes The number of lanes sin the type.
*/ */
inline bool TypeMatch(DGLType t, int code, int bits, int lanes = 1) { inline bool TypeMatch(DGLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes; return t.code == code && t.bits == bits && t.lanes == lanes;
} }
} // namespace runtime } // namespace runtime
......
...@@ -10,7 +10,7 @@ from numbers import Number, Integral ...@@ -10,7 +10,7 @@ from numbers import Number, Integral
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLType, DGLByteArray, DGLContext from ..runtime_ctypes import DGLDataType, DGLByteArray, DGLContext
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import DGLValue, TypeCode from .types import DGLValue, TypeCode
...@@ -115,7 +115,7 @@ def _make_dgl_args(args, temp_args): ...@@ -115,7 +115,7 @@ def _make_dgl_args(args, temp_args):
elif isinstance(arg, Number): elif isinstance(arg, Number):
values[i].v_float64 = arg values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, DGLType): elif isinstance(arg, DGLDataType):
values[i].v_str = c_str(str(arg)) values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, DGLContext): elif isinstance(arg, DGLContext):
......
...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
from ..base import py_str, check_call, _LIB from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLType, DGLContext from ..runtime_ctypes import DGLByteArray, TypeCode, DGLDataType, DGLContext
class DGLValue(ctypes.Union): class DGLValue(ctypes.Union):
"""DGLValue in C API""" """DGLValue in C API"""
...@@ -12,7 +12,7 @@ class DGLValue(ctypes.Union): ...@@ -12,7 +12,7 @@ class DGLValue(ctypes.Union):
("v_float64", ctypes.c_double), ("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p), ("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p), ("v_str", ctypes.c_char_p),
("v_type", DGLType), ("v_type", DGLDataType),
("v_ctx", DGLContext)] ("v_ctx", DGLContext)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment