/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/array_op.h
 * \brief Array operator templates
 *
 * Declarations (no definitions) of the implementation templates behind the
 * id-array and sparse-matrix (CSR / COO) operators. Everything in this header
 * lives in namespace dgl::aten::impl.
 *
 * NOTE(review): in this copy every `template` keyword is followed directly by
 * a return type, and both `#include` directives have no target. That looks
 * like the `<...>` text (template parameter lists and header names) was lost
 * during extraction. As written, this header will not compile. Restore the
 * parameter lists and include targets from the upstream source -- TODO confirm.
 */
#ifndef DGL_ARRAY_ARRAY_OP_H_
#define DGL_ARRAY_ARRAY_OP_H_

// NOTE(review): both include targets are missing (see the file header note).
#include
#include

namespace dgl {
namespace aten {
namespace impl {

// ---------------------------------------------------------------------------
// Dense IdArray operators
// ---------------------------------------------------------------------------

// Builds an IdArray from a fill value `val`, a `length` and a target context
// `ctx`. Presumably the result holds `length` copies of `val` -- TODO confirm.
template
IdArray Full(IdType val, int64_t length, DLContext ctx);

// Builds an IdArray on `ctx` from the bounds `low` and `high`. Presumably this
// is the consecutive ids between them. Confirm whether `high` is inclusive.
template
IdArray Range(IdType low, IdType high, DLContext ctx);

// Returns `arr` expressed with `bits`-bit ids. This is presumably an integer
// width conversion (e.g. 32 <-> 64) -- TODO confirm against the definition.
template
IdArray AsNumBits(IdArray arr, uint8_t bits);

// Element-wise binary operator, in three overloads:
//   array (op) array, array (op) scalar, scalar (op) array.
// The operator itself does not appear in any of these signatures. It is
// presumably chosen by a template argument lost from this copy (see the file
// header note).
template
IdArray BinaryElewise(IdArray lhs, IdArray rhs);

template
IdArray BinaryElewise(IdArray lhs, IdType rhs);

template
IdArray BinaryElewise(IdType lhs, IdArray rhs);

// Combines two IdArrays into a single IdArray. Presumably this is a
// concatenation of `arr1` followed by `arr2` -- confirm.
template
IdArray HStack(IdArray arr1, IdArray arr2);

// Indexing, in two overloads:
//  - the IdArray-index form returns an IdArray (presumably a gather of
//    `array` at each position in `index`);
//  - the scalar-index form returns a single int64_t value.
template
IdArray IndexSelect(IdArray array, IdArray index);

template
int64_t IndexSelect(IdArray array, int64_t index);

// Takes a list of IdArrays and returns one IdArray. The trailing underscore
// suggests the inputs are modified in place, presumably relabeled to a
// compact id range, with the returned array mapping new ids back to the
// originals. NOTE(review): this is inferred from the name only -- confirm
// against the definition.
template
IdArray Relabel_(const std::vector& arrays);

// ---------------------------------------------------------------------------
// Sparse arrays: CSR operators
// ---------------------------------------------------------------------------

// Tests for a stored entry at (row, col). The scalar overload returns a single
// bool. The NDArray overload takes row/col arrays and returns an NDArray,
// presumably one result per (row[i], col[i]) pair.
template
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);

template
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);

// Returns a bool. Presumably it reports whether any (row, col) position is
// stored more than once -- confirm.
template
bool CSRHasDuplicate(CSRMatrix csr);

// Number of stored entries in a row. The scalar overload returns int64_t. The
// NDArray overload returns an NDArray, presumably one count per requested row.
template
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);

template
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);

// Per-row accessors for a single `row`: its column indices and its data
// values, each returned as an NDArray.
template
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);

template
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);

// Data lookup by (row, col), in scalar and NDArray overloads. Both overloads
// return an NDArray, even the scalar one, presumably because a position can
// hold more than one stored entry.
template
runtime::NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col);

template
runtime::NDArray CSRGetData(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

// Like the NDArray form of CSRGetData, but returns a std::vector, presumably
// of NDArrays holding the matched row indices, column indices and data.
// NOTE(review): the vector's element type was lost from this copy -- confirm.
template
std::vector CSRGetDataAndIndices(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

// Returns the transpose of `csr` as a new CSRMatrix.
template
CSRMatrix CSRTranspose(CSRMatrix csr);

// Convert CSR to COO.
template
COOMatrix CSRToCOO(CSRMatrix csr);

// Convert CSR to COO, using the CSR data array to order the output entries.
template
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);

// Row slicing, in two overloads: by the row bounds `start`/`end`, or by an
// explicit NDArray of row ids. Check the definition for whether `end` is
// exclusive.
template
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);

template
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);

// Sub-matrix selection by both a row-id array and a column-id array.
template
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

// Returns nothing. Any effect must therefore go through the arrays `csr`
// refers to; presumably it sorts column indices within each row in place --
// TODO confirm.
template
void CSRSort(CSRMatrix csr);

// ---------------------------------------------------------------------------
// Sparse arrays: COO operators
// ---------------------------------------------------------------------------

// COO counterpart of CSRHasDuplicate.
template
bool COOHasDuplicate(COOMatrix coo);

// Convert COO to CSR.
template
CSRMatrix COOToCSR(COOMatrix coo);

}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_ARRAY_OP_H_