/*! * Copyright (c) 2019 by Contributors * \file array/array.cc * \brief DGL array utilities implementation */ #include #include "../c_api_common.h" #include "./array_op.h" #include "./arith.h" namespace dgl { using runtime::NDArray; namespace aten { IdArray NewIdArray(int64_t length, DLContext ctx, uint8_t nbits) { return IdArray::Empty({length}, DLDataType{kDLInt, nbits, 1}, ctx); } IdArray Clone(IdArray arr) { IdArray ret = NewIdArray(arr->shape[0], arr->ctx, arr->dtype.bits); ret.CopyFrom(arr); return ret; } IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) { IdArray ret; ATEN_XPU_SWITCH(ctx.device_type, XPU, { if (nbits == 32) { ret = impl::Range(low, high, ctx); } else if (nbits == 64) { ret = impl::Range(low, high, ctx); } else { LOG(FATAL) << "Only int32 or int64 is supported."; } }); return ret; } IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) { IdArray ret; ATEN_XPU_SWITCH(ctx.device_type, XPU, { if (nbits == 32) { ret = impl::Full(val, length, ctx); } else if (nbits == 64) { ret = impl::Full(val, length, ctx); } else { LOG(FATAL) << "Only int32 or int64 is supported."; } }); return ret; } IdArray AsNumBits(IdArray arr, uint8_t bits) { IdArray ret; ATEN_XPU_SWITCH(arr->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, { ret = impl::AsNumBits(arr, bits); }); }); return ret; } IdArray Add(IdArray lhs, IdArray rhs) { IdArray ret; CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context"; CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype"; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Sub(IdArray lhs, IdArray rhs) { IdArray ret; CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context"; CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype"; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Mul(IdArray lhs, IdArray rhs) { IdArray ret; CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context"; CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype"; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Div(IdArray lhs, IdArray rhs) { IdArray ret; CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context"; CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype"; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Add(IdArray lhs, dgl_id_t rhs) { IdArray ret; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Sub(IdArray lhs, dgl_id_t rhs) { IdArray ret; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Mul(IdArray lhs, dgl_id_t rhs) { IdArray ret; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Div(IdArray lhs, dgl_id_t rhs) { IdArray ret; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Add(dgl_id_t lhs, IdArray rhs) { return Add(rhs, lhs); } IdArray Sub(dgl_id_t lhs, IdArray rhs) { IdArray ret; ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray Mul(dgl_id_t lhs, IdArray rhs) { return Mul(rhs, lhs); } IdArray Div(dgl_id_t lhs, IdArray rhs) { IdArray ret; ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } BoolArray LT(IdArray lhs, dgl_id_t rhs) { BoolArray ret; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::BinaryElewise(lhs, rhs); }); }); return ret; } IdArray HStack(IdArray lhs, IdArray rhs) { IdArray ret; CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context"; CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype"; ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { ret = impl::HStack(lhs, rhs); }); }); return ret; } NDArray IndexSelect(NDArray array, IdArray index) { NDArray ret; // TODO(BarclayII): check if array and index match in context ATEN_XPU_SWITCH(array->ctx.device_type, XPU, { ATEN_DTYPE_SWITCH(array->dtype, DType, "values", { ATEN_ID_TYPE_SWITCH(index->dtype, IdType, { ret = impl::IndexSelect(array, index); }); }); }); return ret; } template ValueType IndexSelect(NDArray array, uint64_t index) { ValueType ret = 0; ATEN_XPU_SWITCH(array->ctx.device_type, XPU, { ATEN_DTYPE_SWITCH(array->dtype, DType, "values", { ret = impl::IndexSelect(array, index); }); }); return ret; } template int32_t IndexSelect(NDArray array, uint64_t index); template int64_t IndexSelect(NDArray array, uint64_t index); template uint32_t IndexSelect(NDArray array, uint64_t index); template uint64_t IndexSelect(NDArray array, uint64_t index); template float IndexSelect(NDArray array, uint64_t index); template double IndexSelect(NDArray array, uint64_t index); NDArray Scatter(NDArray array, IdArray indices) { NDArray ret; ATEN_XPU_SWITCH(array->ctx.device_type, XPU, { ATEN_DTYPE_SWITCH(array->dtype, DType, "values", { ATEN_ID_TYPE_SWITCH(indices->dtype, IdType, { ret = impl::Scatter(array, indices); }); }); }); return ret; } NDArray Repeat(NDArray array, IdArray repeats) { NDArray ret; ATEN_XPU_SWITCH(array->ctx.device_type, XPU, { ATEN_DTYPE_SWITCH(array->dtype, DType, "values", { ATEN_ID_TYPE_SWITCH(repeats->dtype, IdType, { ret = impl::Repeat(array, repeats); }); }); }); return ret; } IdArray Relabel_(const std::vector& arrays) { IdArray ret; ATEN_XPU_SWITCH(arrays[0]->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, { ret = impl::Relabel_(arrays); }); }); return ret; } template std::tuple Pack(NDArray array, ValueType pad_value) { std::tuple ret; ATEN_XPU_SWITCH(array->ctx.device_type, XPU, { ATEN_DTYPE_SWITCH(array->dtype, DType, "array", { ret = impl::Pack(array, static_cast(pad_value)); }); }); return ret; } template std::tuple Pack(NDArray, int32_t); template std::tuple Pack(NDArray, int64_t); template std::tuple Pack(NDArray, uint32_t); template std::tuple Pack(NDArray, uint64_t); template std::tuple Pack(NDArray, float); template std::tuple Pack(NDArray, double); std::pair ConcatSlices(NDArray array, IdArray lengths) { std::pair ret; ATEN_XPU_SWITCH(array->ctx.device_type, XPU, { ATEN_DTYPE_SWITCH(array->dtype, DType, "array", { ATEN_ID_TYPE_SWITCH(lengths->dtype, IdType, { ret = impl::ConcatSlices(array, lengths); }); }); }); return ret; } ///////////////////////// CSR routines ////////////////////////// bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) { bool ret = false; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRIsNonZero(csr, row, col); }); return ret; } NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) { NDArray ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRIsNonZero(csr, row, col); }); return ret; } bool CSRHasDuplicate(CSRMatrix csr) { bool ret = false; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRHasDuplicate(csr); }); return ret; } int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) { int64_t ret = 0; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRGetRowNNZ(csr, row); }); return ret; } NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) { NDArray ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRGetRowNNZ(csr, row); }); return ret; } NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) { NDArray ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRGetRowColumnIndices(csr, row); }); return ret; } NDArray CSRGetRowData(CSRMatrix csr, int64_t row) { NDArray ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRGetRowData(csr, row); }); return ret; } NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) { NDArray ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRGetData(csr, row, col); }); return ret; } NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { NDArray ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRGetData(csr, rows, cols); }); return ret; } std::vector CSRGetDataAndIndices( CSRMatrix csr, NDArray rows, NDArray cols) { std::vector ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRGetDataAndIndices(csr, rows, cols); }); return ret; } CSRMatrix CSRTranspose(CSRMatrix csr) { CSRMatrix ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRTranspose(csr); }); return ret; } COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) { COOMatrix ret; if (data_as_order) { ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { ret = impl::CSRToCOODataAsOrder(csr); }); }); } else { ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { ret = impl::CSRToCOO(csr); }); }); } return ret; } CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) { CSRMatrix ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRSliceRows(csr, start, end); }); return ret; } CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { CSRMatrix ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRSliceRows(csr, rows); }); return ret; } CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) { CSRMatrix ret; ATEN_CSR_SWITCH(csr, XPU, IdType, { ret = impl::CSRSliceMatrix(csr, rows, cols); }); return ret; } void CSRSort_(CSRMatrix* csr) { ATEN_CSR_SWITCH(*csr, XPU, IdType, { impl::CSRSort_(csr); }); } COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) { COOMatrix ret; ATEN_CSR_SWITCH(mat, XPU, IdType, { if (!prob.defined() || prob->shape[0] == 0) { ret = impl::CSRRowWiseSamplingUniform(mat, rows, num_samples, replace); } else { ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", { ret = impl::CSRRowWiseSampling( mat, rows, num_samples, prob, replace); }); } }); return ret; } COOMatrix CSRRowWiseTopk( CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) { COOMatrix ret; ATEN_CSR_SWITCH(mat, XPU, IdType, { ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", { ret = impl::CSRRowWiseTopk( mat, rows, k, weight, ascending); }); }); return ret; } ///////////////////////// COO routines ////////////////////////// bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) { bool ret = false; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOIsNonZero(coo, row, col); }); return ret; } NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) { NDArray ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOIsNonZero(coo, row, col); }); return ret; } bool COOHasDuplicate(COOMatrix coo) { bool ret = false; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOHasDuplicate(coo); }); return ret; } int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) { int64_t ret = 0; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOGetRowNNZ(coo, row); }); return ret; } NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) { NDArray ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOGetRowNNZ(coo, row); }); return ret; } std::pair COOGetRowDataAndIndices(COOMatrix coo, int64_t row) { std::pair ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOGetRowDataAndIndices(coo, row); }); return ret; } NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) { NDArray ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOGetData(coo, row, col); }); return ret; } std::vector COOGetDataAndIndices( COOMatrix coo, NDArray rows, NDArray cols) { std::vector ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOGetDataAndIndices(coo, rows, cols); }); return ret; } COOMatrix COOTranspose(COOMatrix coo) { COOMatrix ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOTranspose(coo); }); return ret; } CSRMatrix COOToCSR(COOMatrix coo) { CSRMatrix ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOToCSR(coo); }); return ret; } COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) { COOMatrix ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOSliceRows(coo, start, end); }); return ret; } COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) { COOMatrix ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOSliceRows(coo, rows); }); return ret; } COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) { COOMatrix ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOSliceMatrix(coo, rows, cols); }); return ret; } COOMatrix COOSort(COOMatrix mat, bool sort_column) { COOMatrix ret; ATEN_COO_SWITCH(mat, XPU, IdType, { ret = impl::COOSort(mat, sort_column); }); return ret; } COOMatrix COORowWiseSampling( COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) { COOMatrix ret; ATEN_COO_SWITCH(mat, XPU, IdType, { if (!prob.defined() || prob->shape[0] == 0) { ret = impl::COORowWiseSamplingUniform(mat, rows, num_samples, replace); } else { ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", { ret = impl::COORowWiseSampling( mat, rows, num_samples, prob, replace); }); } }); return ret; } COOMatrix COORowWiseTopk( COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) { COOMatrix ret; ATEN_COO_SWITCH(mat, XPU, IdType, { ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", { ret = impl::COORowWiseTopk( mat, rows, k, weight, ascending); }); }); return ret; } std::pair COOCoalesce(COOMatrix coo) { std::pair ret; ATEN_COO_SWITCH(coo, XPU, IdType, { ret = impl::COOCoalesce(coo); }); return ret; } } // namespace aten } // namespace dgl