/*! * Copyright (c) 2019 by Contributors * \file array/array_op.h * \brief Array operator templates */ #ifndef DGL_ARRAY_ARRAY_OP_H_ #define DGL_ARRAY_ARRAY_OP_H_ #include #include #include #include #include namespace dgl { namespace aten { namespace impl { template IdArray Full(IdType val, int64_t length, DGLContext ctx); template IdArray Range(IdType low, IdType high, DGLContext ctx); template IdArray AsNumBits(IdArray arr, uint8_t bits); template IdArray BinaryElewise(IdArray lhs, IdArray rhs); template IdArray BinaryElewise(IdArray lhs, IdType rhs); template IdArray BinaryElewise(IdType lhs, IdArray rhs); template IdArray UnaryElewise(IdArray array); template NDArray IndexSelect(NDArray array, IdArray index); template DType IndexSelect(NDArray array, int64_t index); template IdArray NonZero(BoolArray bool_arr); template std::pair Sort(IdArray array, int num_bits); template NDArray Scatter(NDArray array, IdArray indices); template void Scatter_(IdArray index, NDArray value, NDArray out); template NDArray Repeat(NDArray array, IdArray repeats); template IdArray Relabel_(const std::vector& arrays); template NDArray Concat(const std::vector& arrays); template std::tuple Pack(NDArray array, DType pad_value); template std::pair ConcatSlices(NDArray array, IdArray lengths); template IdArray CumSum(IdArray array, bool prepend_zero); template IdArray NonZero(NDArray array); // sparse arrays template bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col); template runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col); template bool CSRHasDuplicate(CSRMatrix csr); template int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row); template runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row); template runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row); template runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row); template bool CSRIsSorted(CSRMatrix csr); template runtime::NDArray CSRGetData( CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids, runtime::NDArray weights, DType filler); template runtime::NDArray CSRGetData( CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, runtime::NDArray weights, DType filler) { return CSRGetData(csr, rows, cols, false, weights, filler); } template NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { return CSRGetData(csr, rows, cols, true, NullArray(rows->dtype), -1); } template std::vector CSRGetDataAndIndices( CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); template CSRMatrix CSRTranspose(CSRMatrix csr); // Convert CSR to COO template COOMatrix CSRToCOO(CSRMatrix csr); // Convert CSR to COO using data array as order template COOMatrix CSRToCOODataAsOrder(CSRMatrix csr); template CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end); template CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); template CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); template void CSRSort_(CSRMatrix* csr); template std::pair CSRSortByTag( const CSRMatrix &csr, IdArray tag_array, int64_t num_tags); template CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); template COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); template CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); // FloatType is the type of probability data. template COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace); // FloatType is the type of probability data. template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector& num_samples, FloatArray prob, bool replace, bool etype_sorted); template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace); template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector& num_samples, bool replace, bool etype_sorted); // FloatType is the type of weight data. template COOMatrix CSRRowWiseTopk( CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending); template COOMatrix CSRRowWiseSamplingBiased( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset, FloatArray bias, bool replace); template std::pair CSRGlobalUniformNegativeSampling( const CSRMatrix& csr, int64_t num_samples, int num_trials, bool exclude_self_loops, bool replace, double redundancy); // Union CSRMatrixes template CSRMatrix UnionCsr(const std::vector& csrs); template std::tuple CSRToSimple(CSRMatrix csr); /////////////////////////////////////////////////////////////////////////////////////////// template bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col); template runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col); template bool COOHasDuplicate(COOMatrix coo); template int64_t COOGetRowNNZ(COOMatrix coo, int64_t row); template runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row); template std::pair COOGetRowDataAndIndices(COOMatrix coo, int64_t row); template std::vector COOGetDataAndIndices( COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); template runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); template COOMatrix COOTranspose(COOMatrix coo); template CSRMatrix COOToCSR(COOMatrix coo); template COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end); template COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); template COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); template std::pair COOCoalesce(COOMatrix coo); template COOMatrix DisjointUnionCoo(const std::vector& coos); template void COOSort_(COOMatrix* mat, bool sort_column); template std::pair COOIsSorted(COOMatrix coo); template COOMatrix COORemove(COOMatrix coo, IdArray entries); // FloatType is the type of probability data. template COOMatrix COORowWiseSampling( COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace); // FloatType is the type of probability data. template COOMatrix COORowWisePerEtypeSampling( COOMatrix mat, IdArray rows, IdArray etypes, const std::vector& num_samples, FloatArray prob, bool replace, bool etype_sorted); template COOMatrix COORowWiseSamplingUniform( COOMatrix mat, IdArray rows, int64_t num_samples, bool replace); template COOMatrix COORowWisePerEtypeSamplingUniform( COOMatrix mat, IdArray rows, IdArray etypes, const std::vector& num_samples, bool replace, bool etype_sorted); // FloatType is the type of weight data. template COOMatrix COORowWiseTopk( COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending); ///////////////////////// Graph Traverse routines ////////////////////////// template Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source); template Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source); template Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr); template Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); template Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, IdArray source, const bool has_reverse_edge, const bool has_nontree_edge, const bool return_labels); template COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking); } // namespace impl } // namespace aten } // namespace dgl #endif // DGL_ARRAY_ARRAY_OP_H_