Unverified Commit f1b19a6b authored by Minjie Wang and committed by GitHub

[CUDA] Many CUDA operators; Prepare for DGLGraph on CUDA (#1660)

* add cuda utils; change g.to; add g.device

* split array.h into several headers

* cuda index select

* file

* three cuda kernels

* add cuda elementwise arith and several others

* cuda CSRIsNonZero

* fix lint

* lint

* lint

* fix bug in changing ctx to property

* address comments

* remove unused codes

* address comments
parent 42b0c38f
/*!
 * Copyright (c) 2020 by Contributors
 * \file dgl/array.h
 * \brief Common array operations required by DGL.
*
* Note that this is not meant for a full support of array library such as ATen.
* Only a limited set of operators required by DGL are implemented.
*/
#ifndef DGL_ARRAY_H_
#define DGL_ARRAY_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <algorithm>
#include <vector>
#include <tuple>
#include <utility>
#include <string>
#include "./runtime/ndarray.h"
#include "./runtime/object.h"
namespace dgl {
typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t;
/*! \brief Type for DGL format code, whose binary representation indicates
 * which sparse formats are in use and which are not.
*
* Suppose the binary representation is xyz, then
* - x indicates whether csc is in use (1 for true and 0 for false).
* - y indicates whether csr is in use.
* - z indicates whether coo is in use.
*/
typedef uint8_t dgl_format_code_t;
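// Illustrative sketch (hypothetical helpers, not part of this header): with
// the xyz bit layout described above, testing whether a format is in use is
// a simple mask check.
inline bool FormatCodeHasCSC(dgl_format_code_t code) { return (code & 0x4) != 0; }  // x bit
inline bool FormatCodeHasCSR(dgl_format_code_t code) { return (code & 0x2) != 0; }  // y bit
inline bool FormatCodeHasCOO(dgl_format_code_t code) { return (code & 0x1) != 0; }  // z bit
// For example, code 0b011 means CSR and COO are in use while CSC is not.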
using dgl::runtime::NDArray;
typedef NDArray IdArray;
typedef NDArray DegreeArray;
typedef NDArray BoolArray;
typedef NDArray IntArray;
typedef NDArray FloatArray;
typedef NDArray TypeArray;
/*!
* \brief Sparse format.
*/
enum class SparseFormat {
kAny = 0,
kCOO = 1,
kCSR = 2,
kCSC = 3,
  kAuto = 4  // kAuto is a placeholder indicating that the format will be materialized later.
};
// Parse sparse format from string.
inline SparseFormat ParseSparseFormat(const std::string& name) {
if (name == "coo")
return SparseFormat::kCOO;
else if (name == "csr")
return SparseFormat::kCSR;
else if (name == "csc")
return SparseFormat::kCSC;
else if (name == "any")
return SparseFormat::kAny;
else if (name == "auto")
return SparseFormat::kAuto;
else
LOG(FATAL) << "Sparse format not recognized";
return SparseFormat::kAny;
}
// Create string from sparse format.
inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
if (sparse_format == SparseFormat::kCOO)
return std::string("coo");
else if (sparse_format == SparseFormat::kCSR)
return std::string("csr");
else if (sparse_format == SparseFormat::kCSC)
return std::string("csc");
else if (sparse_format == SparseFormat::kAny)
return std::string("any");
else
return std::string("auto");
}
// Sparse matrix object that is exposed to python API.
struct SparseMatrix : public runtime::Object {
// Sparse format.
int32_t format = 0;
// Shape of this matrix.
int64_t num_rows = 0, num_cols = 0;
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row, col, data}.
std::vector<IdArray> indices;
// Boolean flags.
// TODO(minjie): We might revisit this later to provide a more general solution. Currently,
// we only consider aten::COOMatrix and aten::CSRMatrix.
std::vector<bool> flags;
SparseMatrix() {}
SparseMatrix(int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<IdArray>& idx,
const std::vector<bool>& flg)
: format(fmt), num_rows(nrows), num_cols(ncols), indices(idx), flags(flg) {}
static constexpr const char* _type_key = "aten.SparseMatrix";
DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);
};
// Define SparseMatrixRef
DGL_DEFINE_OBJECT_REF(SparseMatrixRef, SparseMatrix);
namespace aten {
//////////////////////////////////////////////////////////////////////
// ID array
//////////////////////////////////////////////////////////////////////
/*! \return A special array to represent null. */
inline NDArray NullArray() {
return NDArray::Empty({0}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
}
/*!
* \return Whether the input array is a null array.
*/
inline bool IsNullArray(NDArray array) {
return array->shape[0] == 0;
}
/*!
* \brief Create a new id array with given length
* \param length The array length
* \param ctx The array context
* \param nbits The number of integer bits
* \return id array
*/
IdArray NewIdArray(int64_t length,
DLContext ctx = DLContext{kDLCPU, 0},
uint8_t nbits = 64);
/*!
* \brief Create a new id array using the given vector data
* \param vec The vector data
* \param nbits The integer bits of the returned array
* \param ctx The array context
* \return the id array
*/
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64,
DLContext ctx = DLContext{kDLCPU, 0});
/*!
* \brief Return an array representing a 1D range.
* \param low Lower bound (inclusive).
* \param high Higher bound (exclusive).
* \param nbits result array's bits (32 or 64)
* \param ctx Device context
* \return range array
*/
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
/*!
* \brief Return an array full of the given value
* \param val The value to fill.
* \param length Number of elements.
* \param nbits result array's bits (32 or 64)
* \param ctx Device context
* \return the result array
*/
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
/*! \brief Create a deep copy of the given array */
IdArray Clone(IdArray arr);
/*! \brief Convert the idarray to the given bit width */
IdArray AsNumBits(IdArray arr, uint8_t bits);
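// Usage sketch (illustrative only; this example function is not part of the
// original header): the factories above cover the common ways to build an IdArray.
inline void IdArrayFactoryExample() {
  IdArray a = NewIdArray(5);                                  // uninitialized int64 array on CPU
  IdArray b = VecToIdArray(std::vector<int64_t>({1, 2, 3}));  // [1, 2, 3]
  IdArray c = Range(0, 10, 64, DLContext{kDLCPU, 0});         // [0, 1, ..., 9]
  IdArray d = Full(7, 4, 32, DLContext{kDLCPU, 0});           // [7, 7, 7, 7] as int32
  IdArray e = AsNumBits(c, 32);                               // convert c to 32-bit ids
  (void)a; (void)b; (void)d; (void)e;
}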
/*! \brief Arithmetic functions */
IdArray Add(IdArray lhs, IdArray rhs);
IdArray Sub(IdArray lhs, IdArray rhs);
IdArray Mul(IdArray lhs, IdArray rhs);
IdArray Div(IdArray lhs, IdArray rhs);
IdArray Add(IdArray lhs, dgl_id_t rhs);
IdArray Sub(IdArray lhs, dgl_id_t rhs);
IdArray Mul(IdArray lhs, dgl_id_t rhs);
IdArray Div(IdArray lhs, dgl_id_t rhs);
IdArray Add(dgl_id_t lhs, IdArray rhs);
IdArray Sub(dgl_id_t lhs, IdArray rhs);
IdArray Mul(dgl_id_t lhs, IdArray rhs);
IdArray Div(dgl_id_t lhs, IdArray rhs);
BoolArray LT(IdArray lhs, dgl_id_t rhs);
/*! \brief Stack two arrays (of len L) into a 2*L length array */
IdArray HStack(IdArray arr1, IdArray arr2);
/*!
* \brief Return the data under the index. In numpy notation, A[I]
* \tparam ValueType The type of return value.
*/
template<typename ValueType>
ValueType IndexSelect(NDArray array, uint64_t index);
NDArray IndexSelect(NDArray array, IdArray index);
/*!
* \brief Permute the elements of an array according to given indices.
*
* Equivalent to:
*
* <code>
* result = np.zeros_like(array)
* result[indices] = array
* </code>
*/
NDArray Scatter(NDArray array, IdArray indices);
/*!
* \brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats)
* \param array A 1D vector
* \param repeats A 1D integer vector for number of times to repeat for each element in
* \c array. Must have the same shape as \c array.
*/
NDArray Repeat(NDArray array, IdArray repeats);
/*!
* \brief Relabel the given ids to consecutive ids.
*
 * Relabeling is done in place. The mapping is created from the union
 * of the given arrays.
*
* \param arrays The id arrays to relabel.
* \return mapping array M from new id to old id.
*/
IdArray Relabel_(const std::vector<IdArray>& arrays);
/*! \brief Return whether the array is a valid 1D int array */
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
/*!
* \brief Packs a tensor containing padded sequences of variable length.
*
* Similar to \c pack_padded_sequence in PyTorch, except that
*
* 1. The length for each sequence (before padding) is inferred as the number
* of elements before the first occurrence of \c pad_value.
* 2. It does not sort the sequences by length.
* 3. Along with the tensor containing the packed sequence, it returns both the
* length, as well as the offsets to the packed tensor, of each sequence.
*
* \param array The tensor containing sequences padded to the same length
* \param pad_value The padding value
* \return A triplet of packed tensor, the length tensor, and the offset tensor
*
* \note Example: consider the following array with padding value -1:
*
* <code>
* [[1, 2, -1, -1],
* [3, 4, 5, -1]]
* </code>
*
* The packed tensor would be [1, 2, 3, 4, 5].
*
* The length tensor would be [2, 3], i.e. the length of each sequence before padding.
*
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for each
* sequence (before padding)
*/
template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
/*!
* \brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays
* by concatenation.
*
* If a 2D array is given, then the function is equivalent to:
*
* <code>
* def ConcatSlices(array, lengths):
* slices = [array[i, :l] for i, l in enumerate(lengths)]
* packed = np.concatenate(slices)
* offsets = np.cumsum([0] + lengths[:-1])
* return packed, offsets
* </code>
*
* If a 1D array is given, then the function is equivalent to
*
* <code>
* def ConcatSlices(array, lengths):
* slices = [array[:l] for l in lengths]
* packed = np.concatenate(slices)
* offsets = np.cumsum([0] + lengths[:-1])
* return packed, offsets
* </code>
*
* \param array A 1D or 2D tensor for slicing
* \param lengths A 1D tensor indicating the number of elements to slice
* \return The tensor with packed slices along with the offsets.
*/
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
//////////////////////////////////////////////////////////////////////
// Sparse matrix
//////////////////////////////////////////////////////////////////////
/*!
* \brief Plain CSR matrix
*
* The column indices are 0-based and are not necessarily sorted. The data array stores
* integer ids for reading edge features.
*
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
struct CSRMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief CSR index arrays */
IdArray indptr, indices;
  /*! \brief data index array. When it is null, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the column indices per row are sorted */
bool sorted = false;
/*! \brief default constructor */
CSRMatrix() = default;
/*! \brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,
IdArray darr = NullArray(), bool sorted_flag = false)
: num_rows(nrows),
num_cols(ncols),
indptr(parr),
indices(iarr),
data(darr),
sorted(sorted_flag) {
CHECK_EQ(indptr->dtype.bits, indices->dtype.bits)
<< "The indptr and indices arrays must have the same data type.";
}
/*! \brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
indptr(spmat.indices[0]),
indices(spmat.indices[1]),
data(spmat.indices[2]),
sorted(spmat.flags[0]) {
CHECK_EQ(indptr->dtype.bits, indices->dtype.bits)
<< "The indptr and indices arrays must have the same data type.";
}
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCSR), num_rows,
num_cols, {indptr, indices, data}, {sorted});
}
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&indptr)) << "Invalid indptr";
CHECK(fs->Read(&indices)) << "Invalid indices";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&sorted)) << "Invalid sorted";
return true;
}
void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(indptr);
fs->Write(indices);
fs->Write(data);
fs->Write(sorted);
}
};
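// Construction sketch (illustrative only; CSRExample is a hypothetical helper,
// not part of this header). It builds the 4x4 matrix reused by the CSR
// examples below:
//   indptr  = [0, 2, 3, 3, 5]
//   indices = [1, 0, 2, 3, 1]
inline CSRMatrix CSRExample() {
  IdArray indptr = VecToIdArray(std::vector<int64_t>({0, 2, 3, 3, 5}));
  IdArray indices = VecToIdArray(std::vector<int64_t>({1, 0, 2, 3, 1}));
  return CSRMatrix(4, 4, indptr, indices);  // data defaults to NullArray()
}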
/*!
* \brief Plain COO structure
*
* The data array stores integer ids for reading edge features.
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/
constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;
// TODO(BarclayII): Graph queries on COO formats should support the case where
// data is ordered by rows/columns instead of by EID.
struct COOMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief COO index arrays */
IdArray row, col;
  /*! \brief data index array. When it is null, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the row indices are sorted */
bool row_sorted = false;
/*! \brief whether the column indices per row are sorted */
bool col_sorted = false;
/*! \brief default constructor */
COOMatrix() = default;
/*! \brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false)
: num_rows(nrows),
num_cols(ncols),
row(rarr),
col(carr),
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {
CHECK_EQ(row->dtype.bits, col->dtype.bits)
<< "The row and col arrays must have the same data type.";
}
/*! \brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
row(spmat.indices[0]),
col(spmat.indices[1]),
data(spmat.indices[2]),
row_sorted(spmat.flags[0]),
col_sorted(spmat.flags[1]) {
CHECK_EQ(row->dtype.bits, col->dtype.bits)
<< "The row and col arrays must have the same data type.";
}
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
}
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCooMatrixMagic)
<< "Invalid COOMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&row)) << "Invalid row";
CHECK(fs->Read(&col)) << "Invalid col";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&row_sorted)) << "Invalid row_sorted";
CHECK(fs->Read(&col_sorted)) << "Invalid col_sorted";
return true;
}
void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCooMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(row);
fs->Write(col);
fs->Write(data);
fs->Write(row_sorted);
fs->Write(col_sorted);
}
};
///////////////////////// CSR routines //////////////////////////
/*! \brief Return true if the value (row, col) is non-zero */
bool CSRIsNonZero(CSRMatrix , int64_t row, int64_t col);
/*!
* \brief Batched implementation of CSRIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
runtime::NDArray CSRIsNonZero(CSRMatrix, runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */
int64_t CSRGetRowNNZ(CSRMatrix , int64_t row);
runtime::NDArray CSRGetRowNNZ(CSRMatrix , runtime::NDArray row);
/*! \brief Return the column index array of the given row */
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix , int64_t row);
/*! \brief Return the data array of the given row */
runtime::NDArray CSRGetRowData(CSRMatrix , int64_t row);
/*! \brief Whether the CSR matrix contains data */
inline bool CSRHasData(CSRMatrix csr) {
return !IsNullArray(csr.data);
}
/*! \brief Get data. The return type is an ndarray due to possible duplicate entries. */
runtime::NDArray CSRGetData(CSRMatrix , int64_t row, int64_t col);
/*!
* \brief Batched implementation of CSRGetData.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/*!
* \brief Get the data and the row,col indices for each returned entries.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix , runtime::NDArray rows, runtime::NDArray cols);
/*! \brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr);
/*!
* \brief Convert CSR matrix to COO matrix.
* \param csr Input csr matrix
* \param data_as_order If true, the data array in the input csr matrix contains the order
* by which the resulting COO tuples are stored. In this case, the
* data array of the resulting COO matrix will be empty because it
* is essentially a consecutive range.
* \return a coo matrix
*/
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
/*!
* \brief Slice rows of the given matrix and return.
* \param csr CSR matrix
* \param start Start row id (inclusive)
* \param end End row id (exclusive)
*
* Examples:
* num_rows = 4
* num_cols = 4
* indptr = [0, 2, 3, 3, 5]
* indices = [1, 0, 2, 3, 1]
*
* After CSRSliceRows(csr, 1, 3)
*
* num_rows = 2
* num_cols = 4
* indptr = [0, 1, 1]
* indices = [2]
*/
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
/*!
* \brief Get the submatrix specified by the row and col ids.
*
* In numpy notation, given matrix M, row index array I, col index array J
* This function returns the submatrix M[I, J].
*
* \param csr The input csr matrix
* \param rows The row index to select
* \param cols The col index to select
* \return submatrix
*/
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
/*! \return True if the matrix has duplicate entries */
bool CSRHasDuplicate(CSRMatrix csr);
/*!
* \brief Sort the column index at each row in the ascending order.
*
* Examples:
* num_rows = 4
* num_cols = 4
* indptr = [0, 2, 3, 3, 5]
* indices = [1, 0, 2, 3, 1]
*
* After CSRSort_(&csr)
*
* indptr = [0, 2, 3, 3, 5]
* indices = [0, 1, 1, 2, 3]
*/
void CSRSort_(CSRMatrix* csr);
/*!
 * \brief Reorder the rows and columns according to the new row and column order.
* \param csr The input csr matrix.
* \param new_row_ids the new row Ids (the index is the old row Id)
* \param new_col_ids the new column Ids (the index is the old col Id).
*/
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*!
* \brief Remove entries from CSR matrix by entry indices (data indices)
* \return A new CSR matrix as well as a mapping from the new CSR entries to the old CSR
* entries.
*/
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [1, 3]
* COOMatrix sampled = CSRRowWiseSampling(csr, rows, 2, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [1, 3, 3]
* // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4]
*
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \return A COOMatrix storing the picked row, col and data indices.
*/
COOMatrix CSRRowWiseSampling(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
 * If a row has fewer non-zero values than k,
 * all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [0, 1, 3]
* FloatArray weight = ... ; // [1., 0., -1., 10., 20.]
* COOMatrix sampled = CSRRowWiseTopk(csr, rows, 1, weight);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [0, 1, 3]
* // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1]
*
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
 * \param ascending If true, elements are sorted in ascending order, which is
 *        equivalent to finding the K smallest values. Otherwise, find the K largest values.
 * \return A COOMatrix storing the picked row and col indices. Its data field stores
 *         the index of the picked elements in the value array.
*/
COOMatrix CSRRowWiseTopk(
CSRMatrix mat,
IdArray rows,
int64_t k,
FloatArray weight,
bool ascending = false);
///////////////////////// COO routines //////////////////////////
/*! \brief Return true if the value (row, col) is non-zero */
bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
/*!
* \brief Batched implementation of COOIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row);
/*! \brief Return the data array of the given row */
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix , int64_t row);
/*! \brief Whether the COO matrix contains data */
inline bool COOHasData(COOMatrix coo) {
  return !IsNullArray(coo.data);
}
/*! \brief Get data. The return type is an ndarray due to possible duplicate entries. */
runtime::NDArray COOGetData(COOMatrix , int64_t row, int64_t col);
/*!
* \brief Get the data and the row,col indices for each returned entries.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix , runtime::NDArray rows, runtime::NDArray cols);
/*! \brief Return a transposed COO matrix */
COOMatrix COOTranspose(COOMatrix coo);
/*!
* \brief Convert COO matrix to CSR matrix.
*
 * If the input COO matrix does not have a data array, the data array of
 * the resulting CSR matrix stores a shuffle index for how the entries
 * are reordered in CSR. The i^th entry in the result CSR corresponds
 * to the CSR.data[i]^th entry in the input COO.
*/
CSRMatrix COOToCSR(COOMatrix coo);
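// Conversion sketch (illustrative only, reusing the hypothetical CSRExample
// helper defined after CSRMatrix above): a CSR -> COO -> CSR round trip.
inline void FormatConversionExample() {
  CSRMatrix csr = CSRExample();
  // With data_as_order = false, the entry ids are carried into the COO data
  // array (an assumption based on the CSRToCOO doc above).
  COOMatrix coo = CSRToCOO(csr, false);
  // coo.row == [0, 0, 1, 3, 3], coo.col == [1, 0, 2, 3, 1]
  CSRMatrix back = COOToCSR(coo);
  (void)back;
}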
/*!
* \brief Slice rows of the given matrix and return.
* \param coo COO matrix
* \param start Start row id (inclusive)
* \param end End row id (exclusive)
*/
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
/*!
* \brief Get the submatrix specified by the row and col ids.
*
* In numpy notation, given matrix M, row index array I, col index array J
* This function returns the submatrix M[I, J].
*
* \param coo The input coo matrix
* \param rows The row index to select
* \param cols The col index to select
* \return submatrix
*/
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
/*! \return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo);
/*!
* \brief Deduplicate the entries of a sorted COO matrix, replacing the data with the
* number of occurrences of the row-col coordinates.
*/
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
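// Worked example (illustrative only): coalescing a sorted COO matrix with a
// duplicated (0, 1) coordinate.
//   coo.row = [0, 0, 0, 1]
//   coo.col = [0, 1, 1, 2]
// The coalesced matrix in the returned pair has
//   row  = [0, 0, 1]
//   col  = [0, 1, 2]
//   data = [1, 2, 1]   // number of occurrences of each coordinate
// The accompanying IdArray relates the original entries to the coalesced ones;
// the exact convention is left to the implementation and not assumed here.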
/*!
* \brief Sort the indices of a COO matrix.
*
* The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix
* stores the shuffled index which could be used to fetch edge data.
*
* \param mat The input coo matrix
* \param sort_column True if column index should be sorted too.
* \return COO matrix with index sorted.
*/
COOMatrix COOSort(COOMatrix mat, bool sort_column = false);
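// Worked example (illustrative only), with sort_column = true:
//   coo.row = [1, 0, 3, 0], coo.col = [1, 1, 2, 0]
// After COOSort(coo, true):
//   row  = [0, 0, 1, 3]
//   col  = [0, 1, 1, 2]
//   data = [3, 1, 0, 2]   // shuffle index into the original entries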
/*!
* \brief Remove entries from COO matrix by entry indices (data indices)
* \return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries.
*/
COOMatrix COORemove(COOMatrix coo, IdArray entries);
/*!
 * \brief Reorder the rows and columns according to the new row and column order.
 * \param coo The input coo matrix.
* \param new_row_ids the new row Ids (the index is the old row Id)
* \param new_col_ids the new column Ids (the index is the old col Id).
*/
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // coo.num_rows = 4;
* // coo.num_cols = 4;
* // coo.rows = [0, 0, 1, 3, 3]
* // coo.cols = [0, 1, 1, 2, 3]
* // coo.data = [2, 3, 0, 1, 4]
* COOMatrix coo = ...;
* IdArray rows = ... ; // [1, 3]
* COOMatrix sampled = COORowWiseSampling(coo, rows, 2, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [1, 3, 3]
* // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4]
*
* \param mat Input coo matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
 * \return A COOMatrix storing the picked row and col indices. Its data field stores
 *         the index of the picked elements in the value array.
*/
COOMatrix COORowWiseSampling(
COOMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
 * If a row has fewer non-zero values than k,
 * all the values are picked.
*
* Examples:
*
* // coo.num_rows = 4;
* // coo.num_cols = 4;
* // coo.rows = [0, 0, 1, 3, 3]
* // coo.cols = [0, 1, 1, 2, 3]
* // coo.data = [2, 3, 0, 1, 4]
* COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 1, 3]
* FloatArray weight = ... ; // [1., 0., -1., 10., 20.]
* COOMatrix sampled = COORowWiseTopk(coo, rows, 1, weight);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [0, 1, 3]
* // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1]
*
* \param mat Input COO matrix.
* \param rows Rows to sample from.
* \param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
 * \param ascending If true, elements are sorted in ascending order, which is
 *        equivalent to finding the K smallest values. Otherwise, find the K largest values.
 * \return A COOMatrix storing the picked row and col indices. Its data field stores
 *         the index of the picked elements in the value array.
*/
COOMatrix COORowWiseTopk(
COOMatrix mat,
IdArray rows,
int64_t k,
NDArray weight,
bool ascending = false);
// inline implementations
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits,
DLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DLContext{kDLCPU, 0}, nbits);
if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
} else if (nbits == 64) {
std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(ret->data));
} else {
LOG(FATAL) << "Only int32 or int64 is supported.";
}
return ret.CopyTo(ctx);
}
///////////////////////// Dispatchers //////////////////////////
/*
* Dispatch according to device:
*
 * ATEN_XPU_SWITCH(array->ctx.device_type, XPU, op, {
* // Now XPU is a placeholder for array->ctx.device_type
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
/*
* Dispatch according to device:
*
 * XXX(minjie): temporary macro that also allows CUDA operators
 *
 * ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, {
* // Now XPU is a placeholder for array->ctx.device_type
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDLGPU) { \
constexpr auto XPU = kDLGPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
#else // DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH
#endif // DGL_USE_CUDA
/*
* Dispatch according to integral type (either int32 or int64):
*
* ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
* // Now IdType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
 *   IdType *data = static_cast<IdType *>(array->data);
* });
*/
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t IdType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
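// Dispatch sketch (illustrative only): how an operator entry point can combine
// the device switch with the id-type switch. impl::IndexSelect is a
// hypothetical per-device, per-dtype backend, not declared in this file.
namespace impl {
template <DLDeviceType XPU, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index);  // hypothetical backend
}  // namespace impl
inline NDArray IndexSelectExampleDispatch(NDArray array, IdArray index) {
  NDArray ret;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
    ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
      ret = impl::IndexSelect<XPU, IdType>(array, index);
    });
  });
  return ret;
}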
/*
* Dispatch according to bits (either int32 or int64):
*
* ATEN_ID_BITS_SWITCH(bits, IdType, {
 *   // Now IdType is the integer type corresponding to the given bit width.
 *   // For instance, one can do this for a CPU array:
 *   IdType *data = static_cast<IdType *>(array->data);
* });
*/
#define ATEN_ID_BITS_SWITCH(bits, IdType, ...) \
do { \
CHECK((bits) == 32 || (bits) == 64) << "bits must be 32 or 64"; \
if ((bits) == 32) { \
typedef int32_t IdType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef int64_t IdType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
/*
* Dispatch according to float type (either float32 or float64):
*
* ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, {
* // Now FloatType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
* FloatType *data = static_cast<FloatType *>(array->data);
* });
*/
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
} \
} while (0)
/*
* Dispatch according to data type (int32, int64, float32 or float64):
*
* ATEN_DTYPE_SWITCH(array->dtype, DType, {
* // Now DType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 32) { \
typedef float DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 64) { \
typedef double DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be int32, int64, float32 or float64"; \
} \
} while (0)
/*
* Dispatch according to integral type of CSR graphs.
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
*/
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
} \
} while (0)
// Macro to dispatch according to device context and index type.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
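// The composite switches above collapse the two-level dispatch into a single
// macro. A hypothetical sketch (impl::CSRIsNonZero is not declared in this file):
//
// bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
//   bool ret = false;
//   ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRIsNonZero", {
//     ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
//   });
//   return ret;
// }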
///////////////////////// Array checks //////////////////////////
#define IS_INT32(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 64)
#define CHECK_IF(cond, prop, value_name, dtype_name) \
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name)
#define CHECK_INT32(value, value_name) \
CHECK_IF(IS_INT32(value), "dtype", value_name, "int32")
#define CHECK_INT64(value, value_name) \
CHECK_IF(IS_INT64(value), "dtype", value_name, "int64")
#define CHECK_INT(value, value_name) \
CHECK_IF(IS_INT32(value) || IS_INT64(value), "dtype", value_name, "int32 or int64")
#define CHECK_FLOAT32(value, value_name) \
CHECK_IF(IS_FLOAT32(value), "dtype", value_name, "float32")
#define CHECK_FLOAT64(value, value_name) \
CHECK_IF(IS_FLOAT64(value), "dtype", value_name, "float64")
#define CHECK_FLOAT(value, value_name) \
CHECK_IF(IS_FLOAT32(value) || IS_FLOAT64(value), "dtype", value_name, "float32 or float64")
#define CHECK_NDIM(value, _ndim, value_name) \
CHECK_IF((value)->ndim == (_ndim), "ndim", value_name, _ndim)
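// Usage sketch (illustrative only; this example function is not part of the
// original header): validating operator arguments with the checks above.
inline void ValidateRowsExample(NDArray rows) {
  CHECK_INT(rows, "rows");      // dtype must be int32 or int64
  CHECK_NDIM(rows, 1, "rows");  // must be a 1D array
}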
} // namespace aten
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::CSRMatrix, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::COOMatrix, true);
} // namespace dmlc
#include "./aten/types.h"
#include "./aten/array_ops.h"
#include "./aten/macro.h"
#include "./aten/spmat.h"
#include "./aten/csr.h"
#include "./aten/coo.h"
#endif // DGL_ARRAY_H_
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/array_ops.h
* \brief Common array operations required by DGL.
*
* Note that this is not meant for a full support of array library such as ATen.
* Only a limited set of operators required by DGL are implemented.
*/
#ifndef DGL_ATEN_ARRAY_OPS_H_
#define DGL_ATEN_ARRAY_OPS_H_
#include <algorithm>
#include <utility>
#include <vector>
#include <tuple>
#include "./types.h"
namespace dgl {
namespace aten {
//////////////////////////////////////////////////////////////////////
// ID array
//////////////////////////////////////////////////////////////////////
/*! \return A special array to represent null. */
inline NDArray NullArray() {
return NDArray::Empty({0}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
}
/*!
* \return Whether the input array is a null array.
*/
inline bool IsNullArray(NDArray array) {
return array->shape[0] == 0;
}
/*!
* \brief Create a new id array with given length
* \param length The array length
* \param ctx The array context
* \param nbits The number of integer bits
* \return id array
*/
IdArray NewIdArray(int64_t length,
DLContext ctx = DLContext{kDLCPU, 0},
uint8_t nbits = 64);
/*!
* \brief Create a new id array using the given vector data
* \param vec The vector data
* \param nbits The integer bits of the returned array
* \param ctx The array context
* \return the id array
*/
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64,
DLContext ctx = DLContext{kDLCPU, 0});
/*!
* \brief Return an array representing a 1D range.
* \param low Lower bound (inclusive).
* \param high Higher bound (exclusive).
* \param nbits result array's bits (32 or 64)
* \param ctx Device context
* \return range array
*/
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
/*!
* \brief Return an array full of the given value
* \param val The value to fill.
* \param length Number of elements.
* \param nbits result array's bits (32 or 64)
* \param ctx Device context
* \return the result array
*/
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
/*! \brief Create a deep copy of the given array */
IdArray Clone(IdArray arr);
/*! \brief Convert the idarray to the given bit width */
IdArray AsNumBits(IdArray arr, uint8_t bits);
/*! \brief Arithmetic functions */
IdArray Add(IdArray lhs, IdArray rhs);
IdArray Sub(IdArray lhs, IdArray rhs);
IdArray Mul(IdArray lhs, IdArray rhs);
IdArray Div(IdArray lhs, IdArray rhs);
IdArray Add(IdArray lhs, int64_t rhs);
IdArray Sub(IdArray lhs, int64_t rhs);
IdArray Mul(IdArray lhs, int64_t rhs);
IdArray Div(IdArray lhs, int64_t rhs);
IdArray Add(int64_t lhs, IdArray rhs);
IdArray Sub(int64_t lhs, IdArray rhs);
IdArray Mul(int64_t lhs, IdArray rhs);
IdArray Div(int64_t lhs, IdArray rhs);
IdArray Neg(IdArray array);
// XXX(minjie): currently using integer array for bool type
IdArray GT(IdArray lhs, IdArray rhs);
IdArray LT(IdArray lhs, IdArray rhs);
IdArray GE(IdArray lhs, IdArray rhs);
IdArray LE(IdArray lhs, IdArray rhs);
IdArray EQ(IdArray lhs, IdArray rhs);
IdArray NE(IdArray lhs, IdArray rhs);
IdArray GT(IdArray lhs, int64_t rhs);
IdArray LT(IdArray lhs, int64_t rhs);
IdArray GE(IdArray lhs, int64_t rhs);
IdArray LE(IdArray lhs, int64_t rhs);
IdArray EQ(IdArray lhs, int64_t rhs);
IdArray NE(IdArray lhs, int64_t rhs);
IdArray GT(int64_t lhs, IdArray rhs);
IdArray LT(int64_t lhs, IdArray rhs);
IdArray GE(int64_t lhs, IdArray rhs);
IdArray LE(int64_t lhs, IdArray rhs);
IdArray EQ(int64_t lhs, IdArray rhs);
IdArray NE(int64_t lhs, IdArray rhs);
/*! \brief Stack two arrays (of len L) into a 2*L length array */
IdArray HStack(IdArray arr1, IdArray arr2);
/*!
* \brief Return the data under the index. In numpy notation, A[I]
* \tparam ValueType The type of return value.
*/
template<typename ValueType>
ValueType IndexSelect(NDArray array, uint64_t index);
NDArray IndexSelect(NDArray array, IdArray index);
/*!
* \brief Permute the elements of an array according to given indices.
*
* Equivalent to:
*
* <code>
* result = np.zeros_like(array)
* result[indices] = array
* </code>
*/
NDArray Scatter(NDArray array, IdArray indices);
/*!
* \brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats)
* \param array A 1D vector
* \param repeats A 1D integer vector for number of times to repeat for each element in
* \c array. Must have the same shape as \c array.
*/
NDArray Repeat(NDArray array, IdArray repeats);
/*!
* \brief Relabel the given ids to consecutive ids.
*
 * Relabeling is done in place. The mapping is created from the union
 * of the given arrays.
*
* Example:
*
 * Given two IdArrays [2, 3, 10, 0, 2] and [4, 10, 5], one possible return
 * mapping is [2, 3, 10, 4, 0, 5], meaning that new ID 0 maps to old ID 2,
 * new ID 1 maps to old ID 3, and so forth.
*
* \param arrays The id arrays to relabel.
* \return mapping array M from new id to old id.
*/
IdArray Relabel_(const std::vector<IdArray>& arrays);
/*! \brief Return whether the array is a valid 1D int array */
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
/*!
* \brief Packs a tensor containing padded sequences of variable length.
*
* Similar to \c pack_padded_sequence in PyTorch, except that
*
* 1. The length for each sequence (before padding) is inferred as the number
* of elements before the first occurrence of \c pad_value.
* 2. It does not sort the sequences by length.
* 3. Along with the tensor containing the packed sequence, it returns both the
* length, as well as the offsets to the packed tensor, of each sequence.
*
* \param array The tensor containing sequences padded to the same length
* \param pad_value The padding value
* \return A triplet of packed tensor, the length tensor, and the offset tensor
*
* \note Example: consider the following array with padding value -1:
*
* <code>
* [[1, 2, -1, -1],
* [3, 4, 5, -1]]
* </code>
*
* The packed tensor would be [1, 2, 3, 4, 5].
*
* The length tensor would be [2, 3], i.e. the length of each sequence before padding.
*
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for each
* sequence (before padding)
*/
template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
/*!
* \brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays
* by concatenation.
*
* If a 2D array is given, then the function is equivalent to:
*
* <code>
* def ConcatSlices(array, lengths):
* slices = [array[i, :l] for i, l in enumerate(lengths)]
* packed = np.concatenate(slices)
* offsets = np.cumsum([0] + lengths[:-1])
* return packed, offsets
* </code>
*
* If a 1D array is given, then the function is equivalent to
*
* <code>
* def ConcatSlices(array, lengths):
* slices = [array[:l] for l in lengths]
* packed = np.concatenate(slices)
* offsets = np.cumsum([0] + lengths[:-1])
* return packed, offsets
* </code>
*
* \param array A 1D or 2D tensor for slicing
* \param lengths A 1D tensor indicating the number of elements to slice
* \return The tensor with packed slices along with the offsets.
*/
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
// inline implementations
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits,
DLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DLContext{kDLCPU, 0}, nbits);
if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
} else if (nbits == 64) {
std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(ret->data));
} else {
LOG(FATAL) << "Only int32 or int64 is supported.";
}
return ret.CopyTo(ctx);
}
} // namespace aten
} // namespace dgl
#endif // DGL_ATEN_ARRAY_OPS_H_
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/coo.h
* \brief Common COO operations required by DGL.
*/
#ifndef DGL_ATEN_COO_H_
#define DGL_ATEN_COO_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <vector>
#include <utility>
#include "./types.h"
#include "./array_ops.h"
#include "./spmat.h"
#include "./macro.h"
namespace dgl {
namespace aten {
struct CSRMatrix;
/*!
* \brief Plain COO structure
*
* The data array stores integer ids for reading edge features.
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/
constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;
// TODO(BarclayII): Graph queries on COO formats should support the case where
// data is ordered by rows/columns instead of by EID.
struct COOMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief COO index arrays */
IdArray row, col;
  /*! \brief data index array. When it is null, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the row indices are sorted */
bool row_sorted = false;
/*! \brief whether the column indices per row are sorted */
bool col_sorted = false;
/*! \brief default constructor */
COOMatrix() = default;
/*! \brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false)
: num_rows(nrows),
num_cols(ncols),
row(rarr),
col(carr),
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {
CheckValidity();
}
/*! \brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
row(spmat.indices[0]),
col(spmat.indices[1]),
data(spmat.indices[2]),
row_sorted(spmat.flags[0]),
col_sorted(spmat.flags[1]) {
CheckValidity();
}
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
}
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCooMatrixMagic)
<< "Invalid COOMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&row)) << "Invalid row";
CHECK(fs->Read(&col)) << "Invalid col";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&row_sorted)) << "Invalid row_sorted";
CHECK(fs->Read(&col_sorted)) << "Invalid col_sorted";
CheckValidity();
return true;
}
void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCooMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(row);
fs->Write(col);
fs->Write(data);
fs->Write(row_sorted);
fs->Write(col_sorted);
}
inline void CheckValidity() const {
CHECK_SAME_DTYPE(row, col);
CHECK_SAME_CONTEXT(row, col);
if (!aten::IsNullArray(data)) {
CHECK_SAME_DTYPE(row, data);
CHECK_SAME_CONTEXT(row, data);
}
CHECK_NO_OVERFLOW(row->dtype, num_rows);
CHECK_NO_OVERFLOW(row->dtype, num_cols);
}
};
///////////////////////// COO routines //////////////////////////
/*! \brief Return true if the value (row, col) is non-zero */
bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
/*!
* \brief Batched implementation of COOIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row);
/*! \brief Return the data array of the given row */
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix , int64_t row);
/*! \brief Whether the COO matrix contains data */
inline bool COOHasData(COOMatrix coo) {
  return !IsNullArray(coo.data);
}
/*! \brief Get data. The return type is an ndarray due to possible duplicate entries. */
runtime::NDArray COOGetData(COOMatrix , int64_t row, int64_t col);
/*!
* \brief Get the data and the row,col indices for each returned entries.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix , runtime::NDArray rows, runtime::NDArray cols);
/*! \brief Return a transposed COO matrix */
COOMatrix COOTranspose(COOMatrix coo);
/*!
* \brief Convert COO matrix to CSR matrix.
*
 * If the input COO matrix does not have a data array, the data array of
 * the resulting CSR matrix stores a shuffle index for how the entries
 * are reordered in CSR. The i^th entry in the result CSR corresponds
 * to the CSR.data[i]^th entry in the input COO.
*/
CSRMatrix COOToCSR(COOMatrix coo);
/*!
* \brief Slice rows of the given matrix and return.
* \param coo COO matrix
* \param start Start row id (inclusive)
* \param end End row id (exclusive)
*/
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
/*!
* \brief Get the submatrix specified by the row and col ids.
*
* In numpy notation, given matrix M, row index array I, col index array J
* This function returns the submatrix M[I, J].
*
* \param coo The input coo matrix
* \param rows The row index to select
* \param cols The col index to select
* \return submatrix
*/
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
/*! \return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo);
/*!
* \brief Deduplicate the entries of a sorted COO matrix, replacing the data with the
* number of occurrences of the row-col coordinates.
*/
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
/*!
* \brief Sort the indices of a COO matrix.
*
* The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix
* stores the shuffled index which could be used to fetch edge data.
*
* \param mat The input coo matrix
* \param sort_column True if column index should be sorted too.
* \return COO matrix with index sorted.
*/
COOMatrix COOSort(COOMatrix mat, bool sort_column = false);
/*!
* \brief Remove entries from COO matrix by entry indices (data indices)
* \return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries.
*/
COOMatrix COORemove(COOMatrix coo, IdArray entries);
/*!
 * \brief Reorder the rows and columns according to the new row and column order.
 * \param coo The input coo matrix.
* \param new_row_ids the new row Ids (the index is the old row Id)
* \param new_col_ids the new column Ids (the index is the old col Id).
*/
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // coo.num_rows = 4;
* // coo.num_cols = 4;
* // coo.rows = [0, 0, 1, 3, 3]
* // coo.cols = [0, 1, 1, 2, 3]
* // coo.data = [2, 3, 0, 1, 4]
* COOMatrix coo = ...;
* IdArray rows = ... ; // [1, 3]
* COOMatrix sampled = COORowWiseSampling(coo, rows, 2, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [1, 3, 3]
* // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4]
*
* \param mat Input coo matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
 * \return A COOMatrix storing the picked row and col indices. Its data field stores
 *         the index of the picked elements in the value array.
*/
COOMatrix COORowWiseSampling(
COOMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
 * If a row has fewer non-zero values than k,
 * all the values are picked.
*
* Examples:
*
* // coo.num_rows = 4;
* // coo.num_cols = 4;
* // coo.rows = [0, 0, 1, 3, 3]
* // coo.cols = [0, 1, 1, 2, 3]
* // coo.data = [2, 3, 0, 1, 4]
* COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 1, 3]
* FloatArray weight = ... ; // [1., 0., -1., 10., 20.]
* COOMatrix sampled = COORowWiseTopk(coo, rows, 1, weight);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [0, 1, 3]
* // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1]
*
* \param mat Input COO matrix.
* \param rows Rows to sample from.
* \param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
 * \param ascending If true, elements are sorted in ascending order, which is
 *        equivalent to finding the K smallest values. Otherwise, find the K largest values.
 * \return A COOMatrix storing the picked row and col indices. Its data field stores
 *         the index of the picked elements in the value array.
*/
COOMatrix COORowWiseTopk(
COOMatrix mat,
IdArray rows,
int64_t k,
NDArray weight,
bool ascending = false);
} // namespace aten
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::COOMatrix, true);
} // namespace dmlc
#endif // DGL_ATEN_COO_H_
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/csr.h
* \brief Common CSR operations required by DGL.
*/
#ifndef DGL_ATEN_CSR_H_
#define DGL_ATEN_CSR_H_
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <vector>
#include "./types.h"
#include "./array_ops.h"
#include "./spmat.h"
#include "./macro.h"
namespace dgl {
namespace aten {
struct COOMatrix;
/*!
* \brief Plain CSR matrix
*
* The column indices are 0-based and are not necessarily sorted. The data array stores
* integer ids for reading edge features.
*
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
struct CSRMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief CSR index arrays */
IdArray indptr, indices;
  /*! \brief data index array. When it is null, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the column indices per row are sorted */
bool sorted = false;
/*! \brief default constructor */
CSRMatrix() = default;
/*! \brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,
IdArray darr = NullArray(), bool sorted_flag = false)
: num_rows(nrows),
num_cols(ncols),
indptr(parr),
indices(iarr),
data(darr),
sorted(sorted_flag) {
CheckValidity();
}
/*! \brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
indptr(spmat.indices[0]),
indices(spmat.indices[1]),
data(spmat.indices[2]),
sorted(spmat.flags[0]) {
CheckValidity();
}
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCSR), num_rows,
num_cols, {indptr, indices, data}, {sorted});
}
bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&indptr)) << "Invalid indptr";
CHECK(fs->Read(&indices)) << "Invalid indices";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&sorted)) << "Invalid sorted";
CheckValidity();
return true;
}
void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(indptr);
fs->Write(indices);
fs->Write(data);
fs->Write(sorted);
}
inline void CheckValidity() const {
CHECK_SAME_DTYPE(indptr, indices);
CHECK_SAME_CONTEXT(indptr, indices);
if (!aten::IsNullArray(data)) {
CHECK_SAME_DTYPE(indptr, data);
CHECK_SAME_CONTEXT(indptr, data);
}
CHECK_NO_OVERFLOW(indptr->dtype, num_rows);
CHECK_NO_OVERFLOW(indptr->dtype, num_cols);
}
};
///////////////////////// CSR routines //////////////////////////
/*! \brief Return true if the value (row, col) is non-zero */
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
/*!
 * \brief Batched implementation of CSRIsNonZero.
 * \note This operator allows broadcasting (i.e., either row or col can be of length 1).
 */
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
/*! \brief Return the column index array of the given row */
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
/*! \brief Return the data array of the given row */
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
/*! \brief Whether the CSR matrix contains data */
inline bool CSRHasData(CSRMatrix csr) {
  return !IsNullArray(csr.data);
}
/*! \brief Get data. The return type is an ndarray due to possible duplicate entries. */
runtime::NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col);
/*!
 * \brief Batched implementation of CSRGetData.
 * \note This operator allows broadcasting (i.e., either row or col can be of length 1).
 */
runtime::NDArray CSRGetData(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
/*!
 * \brief Get the data and the row and col indices for each returned entry.
 * \note This operator allows broadcasting (i.e., either row or col can be of length 1).
 */
std::vector<runtime::NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
/*! \brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr);
/*!
* \brief Convert CSR matrix to COO matrix.
* \param csr Input csr matrix
* \param data_as_order If true, the data array in the input csr matrix contains the order
* by which the resulting COO tuples are stored. In this case, the
* data array of the resulting COO matrix will be empty because it
* is essentially a consecutive range.
* \return a coo matrix
*/
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
/*!
* \brief Slice rows of the given matrix and return.
* \param csr CSR matrix
* \param start Start row id (inclusive)
* \param end End row id (exclusive)
*
* Examples:
* num_rows = 4
* num_cols = 4
* indptr = [0, 2, 3, 3, 5]
* indices = [1, 0, 2, 3, 1]
*
* After CSRSliceRows(csr, 1, 3)
*
* num_rows = 2
* num_cols = 4
* indptr = [0, 1, 1]
* indices = [2]
*/
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
/*!
* \brief Get the submatrix specified by the row and col ids.
*
 * In numpy notation, given matrix M, row index array I, and col index array J,
 * this function returns the submatrix M[I, J].
*
* \param csr The input csr matrix
* \param rows The row index to select
* \param cols The col index to select
* \return submatrix
*/
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
/*! \return True if the matrix has duplicate entries */
bool CSRHasDuplicate(CSRMatrix csr);
/*!
 * \brief Sort the column indices of each row in ascending order.
*
* Examples:
* num_rows = 4
* num_cols = 4
* indptr = [0, 2, 3, 3, 5]
* indices = [1, 0, 2, 3, 1]
*
* After CSRSort_(&csr)
*
* indptr = [0, 2, 3, 3, 5]
* indices = [0, 1, 1, 2, 3]
*/
void CSRSort_(CSRMatrix* csr);
/*!
 * \brief Reorder the rows and columns according to the new row and column order.
 * \param csr The input csr matrix.
 * \param new_row_ids The new row ids (the index is the old row id).
 * \param new_col_ids The new column ids (the index is the old col id).
*/
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*!
* \brief Remove entries from CSR matrix by entry indices (data indices)
* \return A new CSR matrix as well as a mapping from the new CSR entries to the old CSR
* entries.
*/
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [1, 3]
* COOMatrix sampled = CSRRowWiseSampling(csr, rows, 2, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [1, 3, 3]
* // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4]
*
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \return A COOMatrix storing the picked row, col and data indices.
*/
COOMatrix CSRRowWiseSampling(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than k,
* all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [0, 1, 3]
* FloatArray weight = ... ; // [1., 0., -1., 10., 20.]
* COOMatrix sampled = CSRRowWiseTopk(csr, rows, 1, weight);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [0, 1, 3]
* // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1]
*
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
 * \param ascending If true, elements are sorted in ascending order, equivalent to finding
 *                  the K smallest values. Otherwise, find the K largest values.
 * \return A COOMatrix storing the picked row and col indices. Its data field stores
 *         the index of the picked elements in the value array.
*/
COOMatrix CSRRowWiseTopk(
CSRMatrix mat,
IdArray rows,
int64_t k,
FloatArray weight,
bool ascending = false);
} // namespace aten
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::aten::CSRMatrix, true);
} // namespace dmlc
#endif // DGL_ATEN_CSR_H_
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/macro.h
* \brief Common macros for aten package.
*/
#ifndef DGL_ATEN_MACRO_H_
#define DGL_ATEN_MACRO_H_
///////////////////////// Dispatchers //////////////////////////
/*
 * Dispatch according to device:
 *
 * ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "OpName", {
 *   // Now XPU is a placeholder for array->ctx.device_type
 *   DeviceSpecificImplementation<XPU>(...);
 * });
 */
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
/*
 * Dispatch according to device (CPU or CUDA):
 *
 * XXX(minjie): temporary macro that allows CUDA operators.
 *
 * ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "OpName", {
 *   // Now XPU is a placeholder for array->ctx.device_type
 *   DeviceSpecificImplementation<XPU>(...);
 * });
 */
#ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDLGPU) { \
constexpr auto XPU = kDLGPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
#else // DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH
#endif // DGL_USE_CUDA
/*
* Dispatch according to integral type (either int32 or int64):
*
 * ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
 *   // Now IdType is the type corresponding to the data type in the array.
 *   // For instance, one can do this for a CPU array:
 *   IdType *data = static_cast<IdType *>(array->data);
 * });
*/
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t IdType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
/*
* Dispatch according to bits (either int32 or int64):
*
 * ATEN_ID_BITS_SWITCH(bits, IdType, {
 *   // Now IdType is the integer type with the given number of bits.
 *   // For instance, one can do this for a CPU array:
 *   IdType *data = static_cast<IdType *>(array->data);
 * });
*/
#define ATEN_ID_BITS_SWITCH(bits, IdType, ...) \
do { \
CHECK((bits) == 32 || (bits) == 64) << "bits must be 32 or 64"; \
if ((bits) == 32) { \
typedef int32_t IdType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef int64_t IdType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
/*
* Dispatch according to float type (either float32 or float64):
*
* ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, {
* // Now FloatType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
* FloatType *data = static_cast<FloatType *>(array->data);
* });
*/
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
} \
} while (0)
/*
* Dispatch according to data type (int32, int64, float32 or float64):
*
* ATEN_DTYPE_SWITCH(array->dtype, DType, {
* // Now DType is the type corresponding to data type in array.
* // For instance, one can do this for a CPU array:
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 32) { \
typedef float DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 64) { \
typedef double DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be int32, int64, float32 or float64"; \
} \
} while (0)
/*
* Dispatch according to integral type of CSR graphs.
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
*/
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
} \
} while (0)
// Macro to dispatch according to device context and index type.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
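/*
 * Usage sketch (hedged; mirrors how operator bodies in this package combine
 * the device and id-type dispatches):
 *
 *   int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
 *     int64_t ret = 0;
 *     ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowNNZ", {
 *       ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
 *     });
 *     return ret;
 *   }
 */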
// Macros to dispatch according to device context and index type (allowing CUDA).
#ifdef DGL_USE_CUDA
#define ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH_CUDA((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
#else // DGL_USE_CUDA
#define ATEN_CSR_SWITCH_CUDA ATEN_CSR_SWITCH
#define ATEN_COO_SWITCH_CUDA ATEN_COO_SWITCH
#endif // DGL_USE_CUDA
///////////////////////// Array checks //////////////////////////
#define IS_INT32(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 64)
#define CHECK_IF(cond, prop, value_name, dtype_name) \
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name)
#define CHECK_INT32(value, value_name) \
CHECK_IF(IS_INT32(value), "dtype", value_name, "int32")
#define CHECK_INT64(value, value_name) \
CHECK_IF(IS_INT64(value), "dtype", value_name, "int64")
#define CHECK_INT(value, value_name) \
CHECK_IF(IS_INT32(value) || IS_INT64(value), "dtype", value_name, "int32 or int64")
#define CHECK_FLOAT32(value, value_name) \
CHECK_IF(IS_FLOAT32(value), "dtype", value_name, "float32")
#define CHECK_FLOAT64(value, value_name) \
CHECK_IF(IS_FLOAT64(value), "dtype", value_name, "float64")
#define CHECK_FLOAT(value, value_name) \
CHECK_IF(IS_FLOAT32(value) || IS_FLOAT64(value), "dtype", value_name, "float32 or float64")
#define CHECK_NDIM(value, _ndim, value_name) \
CHECK_IF((value)->ndim == (_ndim), "ndim", value_name, _ndim)
#define CHECK_SAME_DTYPE(VAR1, VAR2) \
CHECK(VAR1->dtype == VAR2->dtype) \
<< "Expected " << (#VAR2) << " to be the same type as " << (#VAR1) << "(" \
<< (VAR1)->dtype << ")" \
<< ". But got " << (VAR2)->dtype << ".";
#define CHECK_SAME_CONTEXT(VAR1, VAR2) \
CHECK(VAR1->ctx == VAR2->ctx) \
<< "Expected " << (#VAR2) << " to have the same device context as " << (#VAR1) << "(" \
<< (VAR1)->ctx << ")" \
<< ". But got " << (VAR2)->ctx << ".";
#define CHECK_NO_OVERFLOW(dtype, val) \
do { \
if (sizeof(val) == 8 && (dtype).bits == 32) \
CHECK_LE((val), 0x7FFFFFFFL) << "int32 overflow for argument " << (#val) << "."; \
} while (0);
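// For example, CHECK_NO_OVERFLOW(indptr->dtype, num_rows) in
// CSRMatrix::CheckValidity fails when num_rows does not fit into the
// 32-bit id type used by indptr.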
#endif // DGL_ATEN_MACRO_H_
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/spmat.h
* \brief Sparse matrix definitions
*/
#ifndef DGL_ATEN_SPMAT_H_
#define DGL_ATEN_SPMAT_H_
#include <string>
#include <vector>
#include "./types.h"
#include "../runtime/object.h"
namespace dgl {
/*!
* \brief Sparse format.
*/
enum class SparseFormat {
kAny = 0,
kCOO = 1,
kCSR = 2,
kCSC = 3,
kAuto = 4 // kAuto is a placeholder that indicates it would be materialized later.
};
// Parse sparse format from string.
inline SparseFormat ParseSparseFormat(const std::string& name) {
if (name == "coo")
return SparseFormat::kCOO;
else if (name == "csr")
return SparseFormat::kCSR;
else if (name == "csc")
return SparseFormat::kCSC;
else if (name == "any")
return SparseFormat::kAny;
else if (name == "auto")
return SparseFormat::kAuto;
else
LOG(FATAL) << "Sparse format not recognized";
return SparseFormat::kAny;
}
// Create string from sparse format.
inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
if (sparse_format == SparseFormat::kCOO)
return std::string("coo");
else if (sparse_format == SparseFormat::kCSR)
return std::string("csr");
else if (sparse_format == SparseFormat::kCSC)
return std::string("csc");
else if (sparse_format == SparseFormat::kAny)
return std::string("any");
else
return std::string("auto");
}
// Sparse matrix object that is exposed to python API.
struct SparseMatrix : public runtime::Object {
// Sparse format.
int32_t format = 0;
// Shape of this matrix.
int64_t num_rows = 0, num_cols = 0;
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row, col, data}.
std::vector<IdArray> indices;
// Boolean flags.
// TODO(minjie): We might revisit this later to provide a more general solution. Currently,
// we only consider aten::COOMatrix and aten::CSRMatrix.
std::vector<bool> flags;
SparseMatrix() {}
SparseMatrix(int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<IdArray>& idx,
const std::vector<bool>& flg)
: format(fmt), num_rows(nrows), num_cols(ncols), indices(idx), flags(flg) {}
static constexpr const char* _type_key = "aten.SparseMatrix";
DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);
};
// Define SparseMatrixRef
DGL_DEFINE_OBJECT_REF(SparseMatrixRef, SparseMatrix);
} // namespace dgl
#endif // DGL_ATEN_SPMAT_H_
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/types.h
* \brief Array and ID types
*/
#ifndef DGL_ATEN_TYPES_H_
#define DGL_ATEN_TYPES_H_
#include <cstdint>
#include "../runtime/ndarray.h"
namespace dgl {
typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t;
/*! \brief Type for DGL format code, whose binary representation indicates
 * which sparse formats are in use and which are not.
 *
 * Suppose the binary representation is xyz, then
 * - x indicates whether csc is in use (1 for true and 0 for false).
 * - y indicates whether csr is in use.
 * - z indicates whether coo is in use.
 */
typedef uint8_t dgl_format_code_t;
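// Decoding sketch (hedged; follows the xyz bit layout documented above):
//   dgl_format_code_t code = 0x5;    // binary 101: csc and coo in use
//   bool has_csc = (code >> 2) & 1;  // x bit
//   bool has_csr = (code >> 1) & 1;  // y bit
//   bool has_coo = code & 1;         // z bit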
using dgl::runtime::NDArray;
typedef NDArray IdArray;
typedef NDArray DegreeArray;
typedef NDArray BoolArray;
typedef NDArray IntArray;
typedef NDArray FloatArray;
typedef NDArray TypeArray;
} // namespace dgl
#endif // DGL_ATEN_TYPES_H_
@@ -127,6 +127,14 @@ class NDArray {
inline const DLTensor* operator->() const;
/*! \return True if the ndarray is contiguous. */
bool IsContiguous() const;
/*! \return the data pointer with type. */
template <typename T>
inline T* Ptr() const {
if (!defined())
return nullptr;
else
return static_cast<T*>(operator->()->data);
}
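  // Usage sketch (hedged; 'arr' is a defined int64 NDArray):
  //   int64_t* raw = arr.Ptr<int64_t>();  // nullptr if arr is undefined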
/*!
* \brief Copy data content from another array.
* \param other The source array to be copied from.
@@ -575,6 +583,50 @@ namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::runtime::NDArray, true);
} // namespace dmlc
///////////////// Operator overloading for NDArray /////////////////
dgl::runtime::NDArray operator + (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator + (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator + (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator * (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& array);
dgl::runtime::NDArray operator > (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator < (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator >= (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator <= (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator == (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator != (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator > (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator < (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator >= (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator <= (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator == (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator != (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator > (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator < (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator >= (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator <= (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator == (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator != (int64_t lhs, const dgl::runtime::NDArray& a2);
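// Usage sketch (hedged; 'a' and 'b' are IdArrays with the same shape, dtype
// and device context):
//   dgl::runtime::NDArray sum = a + b;   // element-wise add
//   dgl::runtime::NDArray mask = a < 5;  // element-wise compare with a scalar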
///////////////// Operator overloading for DLDataType /////////////////
/*! \brief Check whether two data types are the same.*/
@@ -228,11 +228,49 @@ def context(input):
pass
def device_type(ctx):
"""Return a str representing device type"""
"""Return a str representing device type.
Parameters
----------
ctx : Device context object.
Device context.
Returns
-------
str
"""
pass
def device_id(ctx):
"""Return device index"""
"""Return device index.
For CPU, the index does not matter. For GPU, the index means which GPU
device on the machine.
Parameters
----------
ctx : Device context object.
Device context.
Returns
-------
int
The device index.
"""
pass
def to_backend_ctx(dglctx):
"""Convert a DGL context object to a backend context.
Parameters
----------
dglctx : dgl.ndarray.DGLContext
DGL context object. See _ffi.runtime_types for definition.
Returns
-------
ctx : framework-specific context object.
"""
pass
def astype(input, ty):
@@ -109,6 +109,15 @@ def device_type(ctx):
def device_id(ctx):
return ctx.device_id
def to_backend_ctx(dglctx):
dev_type = dglctx.device_type
if dev_type == 1:
return mx.cpu()
elif dev_type == 2:
return mx.gpu(dglctx.device_id)
else:
raise ValueError('Unsupported DGL device context:', dglctx)
def astype(input, ty):
return nd.cast(input, ty)
@@ -70,14 +70,24 @@ def context(input):
return input.device
def device_type(ctx):
return ctx.type
return th.device(ctx).type
def device_id(ctx):
ctx = th.device(ctx)
if ctx.index is None:
return 0
else:
return ctx.index
def to_backend_ctx(dglctx):
dev_type = dglctx.device_type
if dev_type == 1:
return th.device('cpu')
elif dev_type == 2:
return th.device('cuda', dglctx.device_id)
else:
raise ValueError('Unsupported DGL device context:', dglctx)
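# Usage sketch (hedged; assumes dgl.ndarray.cpu() returns a DGLContext whose
# device_type is 1, i.e. CPU):
#   >>> to_backend_ctx(dgl.ndarray.cpu())
#   device(type='cpu')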
def astype(input, ty):
return input.type(ty)
@@ -118,6 +118,14 @@ def device_type(ctx):
def device_id(ctx):
return tf.DeviceSpec.from_string(ctx).device_index
def to_backend_ctx(dglctx):
dev_type = dglctx.device_type
if dev_type == 1:
return "/cpu:0"
elif dev_type == 2:
return "/gpu:%d" % (dglctx.device_id)
else:
raise ValueError('Unsupported DGL device context:', dglctx)
def astype(input, ty):
return tf.cast(input, dtype=ty)
@@ -4047,14 +4047,36 @@ class DGLHeteroGraph(object):
edges = F.tensor(edges)
return F.boolean_mask(edges, e_mask)
@property
def device(self):
"""Get the device context of this graph.
Examples
--------
The following example uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> print(g.device)
device(type='cpu')
>>> g = g.to('cuda:0')
>>> print(g.device)
device(type='cuda', index=0)
Returns
-------
Device context object
"""
return F.to_backend_ctx(self._graph.ctx)
def to(self, ctx, **kwargs): # pylint: disable=invalid-name
"""Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic
"""Move ndata, edata and graph structure to the targeted device context (cpu/gpu).
Parameters
----------
ctx : framework-specific context object
ctx : Framework-specific device context object
The context to move data to.
kwargs : Key-word arguments.
Key-word arguments fed to the framework copy function.
Returns
-------
@@ -4069,15 +4091,25 @@
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
>>> g = g.to(torch.device('cuda:0'))
>>> g1 = g.to(torch.device('cuda:0'))
>>> print(g1.device)
device(type='cuda', index=0)
>>> print(g.device)
device(type='cpu')
"""
for i in range(len(self._node_frames)):
for k in self._node_frames[i].keys():
self._node_frames[i][k] = F.copy_to(self._node_frames[i][k], ctx)
for i in range(len(self._edge_frames)):
for k in self._edge_frames[i].keys():
self._edge_frames[i][k] = F.copy_to(self._edge_frames[i][k], ctx)
return self
new_nframes = []
for nframe in self._node_frames:
new_feats = {k : F.copy_to(feat, ctx) for k, feat in nframe.items()}
new_nframes.append(FrameRef(Frame(new_feats)))
new_eframes = []
for eframe in self._edge_frames:
new_feats = {k : F.copy_to(feat, ctx) for k, feat in eframe.items()}
new_eframes.append(FrameRef(Frame(new_feats)))
# TODO(minjie): replace the following line with the commented one to enable GPU graph.
new_gidx = self._graph
#new_gidx = self._graph.copy_to(utils.to_dgl_context(ctx))
return DGLHeteroGraph(new_gidx, self.ntypes, self.etypes,
new_nframes, new_eframes)
def local_var(self):
"""Return a heterograph object that can be used in a local function scope.
@@ -161,6 +161,7 @@ class HeteroGraphIndex(ObjectBase):
"""
return _CAPI_DGLHeteroDataType(self)
@property
def ctx(self):
"""Return the context of this graph index.
@@ -819,15 +819,15 @@ def compact_graphs(graphs, always_preserve=None):
# TODO(BarclayII): we ideally need to remove this constraint.
ntypes = graphs[0].ntypes
graph_dtype = graphs[0]._idtype_str
graph_ctx = graphs[0]._graph.ctx()
graph_ctx = graphs[0]._graph.ctx
for g in graphs:
assert ntypes == g.ntypes, \
("All graphs should have the same node types in the same order, got %s and %s" %
ntypes, g.ntypes)
assert graph_dtype == g._idtype_str, "Expect graph data type to be {}, but got {}".format(
graph_dtype, g._idtype_str)
assert graph_ctx == g._graph.ctx(), "Expect graph device to be {}, but got {}".format(
graph_ctx, g._graph.ctx())
assert graph_ctx == g._graph.ctx, "Expect graph device to be {}, but got {}".format(
graph_ctx, g._graph.ctx)
# Process the dictionary or tensor of "always preserve" nodes
if always_preserve is None:
@@ -6,45 +6,95 @@
#ifndef DGL_ARRAY_ARITH_H_
#define DGL_ARRAY_ARITH_H_
#ifdef __CUDACC__
#define DGLDEVICE __device__
#define DGLINLINE __forceinline__
#else
#define DGLDEVICE
#define DGLINLINE inline
#endif // __CUDACC__
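// With the macros above, the functors below compile as plain inline host
// functions in ordinary C++ translation units, and as __forceinline__
// __device__ functions when the file is compiled by nvcc.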
namespace dgl {
namespace aten {
namespace arith {
struct Add {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {
return t1 + t2;
}
};
struct Sub {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {
return t1 - t2;
}
};
struct Mul {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {
return t1 * t2;
}
};
struct Div {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
static DGLINLINE DGLDEVICE T Call(const T& t1, const T& t2) {
return t1 / t2;
}
};
struct GT {
template <typename T>
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
return t1 > t2;
}
};
struct LT {
template <typename T>
inline static bool Call(const T& t1, const T& t2) {
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
return t1 < t2;
}
};
struct GE {
template <typename T>
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
return t1 >= t2;
}
};
struct LE {
template <typename T>
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
return t1 <= t2;
}
};
struct EQ {
template <typename T>
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
return t1 == t2;
}
};
struct NE {
template <typename T>
static DGLINLINE DGLDEVICE bool Call(const T& t1, const T& t2) {
return t1 != t2;
}
};
struct Neg {
template <typename T>
static DGLINLINE DGLDEVICE T Call(const T& t1) {
return -t1;
}
};
} // namespace arith
} // namespace aten
} // namespace dgl
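/*
 * Usage sketch (hedged): these functors are the Op template argument of the
 * element-wise kernels declared in array_op.h, e.g. on CPU with 64-bit ids:
 *
 *   // ret[i] = lhs[i] + rhs[i]
 *   IdArray ret = impl::BinaryElewise<kDLCPU, int64_t, arith::Add>(lhs, rhs);
 */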
@@ -43,7 +43,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH(ctx.device_type, XPU, "Full", {
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
if (nbits == 32) {
ret = impl::Full<XPU, int32_t>(val, length, ctx);
} else if (nbits == 64) {
@@ -70,136 +70,10 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
return ret;
}
IdArray Add(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Add", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
});
});
return ret;
}
IdArray Sub(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Sub", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
});
return ret;
}
IdArray Mul(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Mul", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
});
});
return ret;
}
IdArray Div(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Div", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
});
return ret;
}
IdArray Add(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Add", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
});
});
return ret;
}
IdArray Sub(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Sub", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
});
return ret;
}
IdArray Mul(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Mul", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
});
});
return ret;
}
IdArray Div(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "Div", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
});
return ret;
}
IdArray Add(dgl_id_t lhs, IdArray rhs) {
return Add(rhs, lhs);
}
IdArray Sub(dgl_id_t lhs, IdArray rhs) {
IdArray ret;
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, "Sub", {
ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
});
return ret;
}
IdArray Mul(dgl_id_t lhs, IdArray rhs) {
return Mul(rhs, lhs);
}
IdArray Div(dgl_id_t lhs, IdArray rhs) {
IdArray ret;
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, "Div", {
ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
});
return ret;
}
BoolArray LT(IdArray lhs, dgl_id_t rhs) {
BoolArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "LT", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::LT>(lhs, rhs);
});
});
return ret;
}
IdArray HStack(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
CHECK_SAME_CONTEXT(lhs, rhs);
CHECK_SAME_DTYPE(lhs, rhs);
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, "HStack", {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::HStack<XPU, IdType>(lhs, rhs);
@@ -210,8 +84,10 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
NDArray IndexSelect(NDArray array, IdArray index) {
NDArray ret;
// TODO(BarclayII): check if array and index match in context
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "IndexSelect", {
CHECK_SAME_CONTEXT(array, index);
CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array.";
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
ret = impl::IndexSelect<XPU, DType, IdType>(array, index);
@@ -223,8 +99,9 @@ NDArray IndexSelect(NDArray array, IdArray index) {
template<typename ValueType>
ValueType IndexSelect(NDArray array, uint64_t index) {
CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
ValueType ret = 0;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
ret = impl::IndexSelect<XPU, DType>(array, index);
});
@@ -305,8 +182,10 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
///////////////////////// CSR routines //////////////////////////
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
bool ret = false;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRIsNonZero", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsNonZero", {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
@@ -314,7 +193,11 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRIsNonZero", {
CHECK_SAME_DTYPE(csr.indices, row);
CHECK_SAME_DTYPE(csr.indices, col);
CHECK_SAME_CONTEXT(csr.indices, row);
CHECK_SAME_CONTEXT(csr.indices, col);
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsNonZero", {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
@@ -329,8 +212,9 @@ bool CSRHasDuplicate(CSRMatrix csr) {
}
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
int64_t ret = 0;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowNNZ", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowNNZ", {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
@@ -338,29 +222,35 @@ int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowNNZ", {
CHECK_SAME_DTYPE(csr.indices, row);
CHECK_SAME_CONTEXT(csr.indices, row);
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowNNZ", {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
}
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowColumnIndices", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowColumnIndices", {
ret = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
});
return ret;
}
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetRowData", {
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowData", {
ret = impl::CSRGetRowData<XPU, IdType>(csr, row);
});
return ret;
}
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType>(csr, row, col);
@@ -370,6 +260,10 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
NDArray ret;
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
CHECK_SAME_CONTEXT(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, cols);
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType>(csr, rows, cols);
});
@@ -378,6 +272,10 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
std::vector<NDArray> CSRGetDataAndIndices(
CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
CHECK_SAME_CONTEXT(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, cols);
std::vector<NDArray> ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRGetDataAndIndices", {
ret = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);
@@ -414,6 +312,9 @@ COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
}
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
CHECK(start >= 0 && start < csr.num_rows) << "Invalid start index: " << start;
CHECK(end >= 0 && end <= csr.num_rows) << "Invalid end index: " << end;
CHECK_GE(end, start);
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSliceRows", {
ret = impl::CSRSliceRows<XPU, IdType>(csr, start, end);
@@ -422,6 +323,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
}
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, rows);
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSliceRows", {
ret = impl::CSRSliceRows<XPU, IdType>(csr, rows);
@@ -430,6 +333,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
}
CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
CHECK_SAME_CONTEXT(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, cols);
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSliceMatrix", {
ret = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);
/*!
* Copyright (c) 2019 by Contributors
 * \file array/array_arith.cc
* \brief DGL array arithmetic operations
*/
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include "../c_api_common.h"
#include "./array_op.h"
#include "./arith.h"
using namespace dgl::runtime;
namespace dgl {
namespace aten {
// Generate operators with both operands being NDArrays.
#define BINARY_ELEMENT_OP(name, op) \
IdArray name(IdArray lhs, IdArray rhs) { \
IdArray ret; \
CHECK_SAME_DTYPE(lhs, rhs); \
CHECK_SAME_CONTEXT(lhs, rhs); \
ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, { \
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { \
ret = impl::BinaryElewise<XPU, IdType, arith::op>(lhs, rhs); \
}); \
}); \
return ret; \
}
// Generate operators with only lhs being NDArray.
#define BINARY_ELEMENT_OP_L(name, op) \
IdArray name(IdArray lhs, int64_t rhs) { \
IdArray ret; \
ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, { \
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { \
ret = impl::BinaryElewise<XPU, IdType, arith::op>(lhs, rhs); \
}); \
}); \
return ret; \
}
// Generate operators with only rhs being NDArray.
#define BINARY_ELEMENT_OP_R(name, op) \
IdArray name(int64_t lhs, IdArray rhs) { \
IdArray ret; \
ATEN_XPU_SWITCH_CUDA(rhs->ctx.device_type, XPU, #name, { \
ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, { \
ret = impl::BinaryElewise<XPU, IdType, arith::op>(lhs, rhs); \
}); \
}); \
return ret; \
}
// Generate unary element-wise operators.
#define UNARY_ELEMENT_OP(name, op) \
IdArray name(IdArray lhs) { \
IdArray ret; \
ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, { \
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { \
ret = impl::UnaryElewise<XPU, IdType, arith::op>(lhs); \
}); \
}); \
return ret; \
}
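// For example, BINARY_ELEMENT_OP(Add, Add) below expands to an
// Add(IdArray, IdArray) function that checks dtype/context and dispatches
// to impl::BinaryElewise<XPU, IdType, arith::Add>.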
BINARY_ELEMENT_OP(Add, Add)
BINARY_ELEMENT_OP(Sub, Sub)
BINARY_ELEMENT_OP(Mul, Mul)
BINARY_ELEMENT_OP(Div, Div)
BINARY_ELEMENT_OP(GT, GT)
BINARY_ELEMENT_OP(LT, LT)
BINARY_ELEMENT_OP(GE, GE)
BINARY_ELEMENT_OP(LE, LE)
BINARY_ELEMENT_OP(EQ, EQ)
BINARY_ELEMENT_OP(NE, NE)
BINARY_ELEMENT_OP_L(Add, Add)
BINARY_ELEMENT_OP_L(Sub, Sub)
BINARY_ELEMENT_OP_L(Mul, Mul)
BINARY_ELEMENT_OP_L(Div, Div)
BINARY_ELEMENT_OP_L(GT, GT)
BINARY_ELEMENT_OP_L(LT, LT)
BINARY_ELEMENT_OP_L(GE, GE)
BINARY_ELEMENT_OP_L(LE, LE)
BINARY_ELEMENT_OP_L(EQ, EQ)
BINARY_ELEMENT_OP_L(NE, NE)
BINARY_ELEMENT_OP_R(Add, Add)
BINARY_ELEMENT_OP_R(Sub, Sub)
BINARY_ELEMENT_OP_R(Mul, Mul)
BINARY_ELEMENT_OP_R(Div, Div)
BINARY_ELEMENT_OP_R(GT, GT)
BINARY_ELEMENT_OP_R(LT, LT)
BINARY_ELEMENT_OP_R(GE, GE)
BINARY_ELEMENT_OP_R(LE, LE)
BINARY_ELEMENT_OP_R(EQ, EQ)
BINARY_ELEMENT_OP_R(NE, NE)
UNARY_ELEMENT_OP(Neg, Neg)
} // namespace aten
} // namespace dgl
///////////////// Operator overloading for NDArray /////////////////
NDArray operator + (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Add(lhs, rhs);
}
NDArray operator - (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Sub(lhs, rhs);
}
NDArray operator * (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Mul(lhs, rhs);
}
NDArray operator / (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator + (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Add(lhs, rhs);
}
NDArray operator - (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Sub(lhs, rhs);
}
NDArray operator * (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Mul(lhs, rhs);
}
NDArray operator / (const NDArray& lhs, int64_t rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator + (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Add(lhs, rhs);
}
NDArray operator - (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Sub(lhs, rhs);
}
NDArray operator * (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Mul(lhs, rhs);
}
NDArray operator / (int64_t lhs, const NDArray& rhs) {
return dgl::aten::Div(lhs, rhs);
}
NDArray operator - (const NDArray& array) {
return dgl::aten::Neg(array);
}
NDArray operator > (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::GT(lhs, rhs);
}
NDArray operator < (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::LT(lhs, rhs);
}
NDArray operator >= (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::GE(lhs, rhs);
}
NDArray operator <= (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::LE(lhs, rhs);
}
NDArray operator == (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::EQ(lhs, rhs);
}
NDArray operator != (const NDArray& lhs, const NDArray& rhs) {
return dgl::aten::NE(lhs, rhs);
}
NDArray operator > (const NDArray& lhs, int64_t rhs) {
return dgl::aten::GT(lhs, rhs);
}
NDArray operator < (const NDArray& lhs, int64_t rhs) {
return dgl::aten::LT(lhs, rhs);
}
NDArray operator >= (const NDArray& lhs, int64_t rhs) {
return dgl::aten::GE(lhs, rhs);
}
NDArray operator <= (const NDArray& lhs, int64_t rhs) {
return dgl::aten::LE(lhs, rhs);
}
NDArray operator == (const NDArray& lhs, int64_t rhs) {
return dgl::aten::EQ(lhs, rhs);
}
NDArray operator != (const NDArray& lhs, int64_t rhs) {
return dgl::aten::NE(lhs, rhs);
}
NDArray operator > (int64_t lhs, const NDArray& rhs) {
return dgl::aten::GT(lhs, rhs);
}
NDArray operator < (int64_t lhs, const NDArray& rhs) {
return dgl::aten::LT(lhs, rhs);
}
NDArray operator >= (int64_t lhs, const NDArray& rhs) {
return dgl::aten::GE(lhs, rhs);
}
NDArray operator <= (int64_t lhs, const NDArray& rhs) {
return dgl::aten::LE(lhs, rhs);
}
NDArray operator == (int64_t lhs, const NDArray& rhs) {
return dgl::aten::EQ(lhs, rhs);
}
NDArray operator != (int64_t lhs, const NDArray& rhs) {
return dgl::aten::NE(lhs, rhs);
}
@@ -34,6 +34,9 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray array);
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2);
@@ -48,6 +48,8 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
// TODO(minjie): we should split the loop into segments for better cache locality.
#pragma omp parallel for
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = Op::Call(lhs_data[i], rhs_data[i]);
}
@@ -58,18 +60,30 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
// TODO(minjie): we should split the loop into segments for better cache locality.
#pragma omp parallel for
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = Op::Call(lhs_data[i], rhs);
}
@@ -80,18 +94,30 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
// TODO(minjie): we should split the loop into segments for better cache locality.
#pragma omp parallel for
for (int64_t i = 0; i < rhs->shape[0]; ++i) {
ret_data[i] = Op::Call(lhs, rhs_data[i]);
}
@@ -102,12 +128,38 @@ template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
// TODO(minjie): we should split the loop into segments for better cache locality.
#pragma omp parallel for
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = Op::Call(lhs_data[i]);
}
return ret;
}
template IdArray UnaryElewise<kDLCPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLCPU, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// HStack /////////////////////////////