Unverified Commit cbc34705 authored by xiangyuzhi, committed by GitHub

[Sparse] Compact C++ API (#6334)

parent d566ff8e
@@ -7,6 +7,7 @@
#define SPARSE_MATRIX_OPS_H_
#include <sparse/sparse_format.h>
#include <sparse/sparse_matrix.h>
#include <tuple>
@@ -26,6 +27,28 @@ namespace sparse {
std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
    const std::shared_ptr<COO>& lhs, const std::shared_ptr<COO>& rhs);
/**
 * @brief Compact a sparse matrix by removing rows or columns that contain
 * no non-zero elements, and relabel the indices of the compacted dimension.
 *
 * This function serves a dual purpose: it reorganizes the indices within a
 * specific dimension (rows or columns) of the sparse matrix and, if needed,
 * places certain 'leading_indices' at the beginning of the compacted
 * dimension.
*
* @param mat The sparse matrix to be compacted.
* @param dim The dimension to compact. Should be 0 or 1. Use 0 for row-wise
* compaction and 1 for column-wise compaction.
* @param leading_indices An optional tensor containing row or column ids that
* should be placed at the beginning of the compact dimension.
*
* @return A tuple containing the compacted sparse matrix and the index mapping
* of the compact dimension from the new index to the original index.
*/
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(
    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
    torch::Tensor leading_indices);
} // namespace sparse
} // namespace dgl
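To make the documented contract concrete before moving to the new files, here is a short editor's sketch (not part of the commit; the matrix contents and expected outputs are assumed values for illustration):

// Hypothetical 4 x 3 matrix `mat` whose non-zeros sit at (0, 1) and (2, 0);
// rows 1 and 3 are empty. Compacting dim 0 with leading_indices = [2]
// should produce a 2 x 3 matrix where new row 0 is original row 2 and
// new row 1 is original row 0, plus the new->old index mapping [2, 0].
auto [compacted, row_map] = Compact(mat, /*dim=*/0, torch::tensor({2}));
// row_map is expected to equal tensor([2, 0]).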
/**
* Copyright (c) 2023 by Contributors
* @file macro.h
* @brief DGL C++ sparse API macros.
*/
#ifndef DGL_SPARSE_MACRO_H_
#define DGL_SPARSE_MACRO_H_
namespace dgl {
namespace sparse {
/**
* Dispatch an operator to a templated implementation function
* according to its device:
*
* DGL_SPARSE_XPU_SWITCH(tensor.device().type(), XPU, {
* // Now XPU is a placeholder for tensor.device().type()
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#define DGL_SPARSE_XPU_SWITCH(device, XPU, op, ...) \
do { \
if ((device) == c10::DeviceType::CPU) { \
constexpr auto XPU = c10::DeviceType::CPU; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< c10::DeviceTypeName(device) << " device."; \
} \
} while (0)
/**
* Dispatch according to ID type (either int32 or int64):
*
* DGL_SPARSE_ID_TYPE_SWITCH(tensor.dtype(), IdType, {
* // Now IdType is the type corresponding to data type of the tensor.
* // For instance, one can do this for a CPU array:
* IdType *data = static_cast<IdType *>(array.data_ptr());
* });
*/
#define DGL_SPARSE_ID_TYPE_SWITCH(dtype, IdType, op, ...) \
do { \
if ((dtype) == torch::kInt32) { \
typedef int32_t IdType; \
{ __VA_ARGS__ } \
} else if ((dtype) == torch::kInt64) { \
typedef int64_t IdType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< (dtype).name() << " as ID dtype."; \
} \
} while (0)
// Macro to dispatch according to device and index type.
#define DGL_SPARSE_COO_SWITCH(coo, XPU, IdType, op, ...)           \
  DGL_SPARSE_XPU_SWITCH((coo)->indices.device().type(), XPU, op, { \
    DGL_SPARSE_ID_TYPE_SWITCH(                                     \
        (coo)->indices.dtype(), IdType, op, {{__VA_ARGS__}});      \
  });
} // namespace sparse
} // namespace dgl
#endif // DGL_SPARSE_MACRO_H_
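To make the dispatch mechanics concrete, here is a hand-written sketch of what the DGL_SPARSE_COO_SWITCH call in matrix_ops.cc below expands to (editor's illustration; whitespace and some redundant braces simplified):

// Hand expansion of
//   DGL_SPARSE_COO_SWITCH(mat->COOPtr(), XPU, IdType, "Compact", {
//     return CompactImpl<XPU, IdType>(mat, dim, leading_indices);
//   });
// Both switches are do { ... } while (0) blocks, so the expansion is a
// single statement and the inner `return` leaves the enclosing function.
do {
  if (mat->COOPtr()->indices.device().type() == c10::DeviceType::CPU) {
    constexpr auto XPU = c10::DeviceType::CPU;
    do {
      if (mat->COOPtr()->indices.dtype() == torch::kInt32) {
        typedef int32_t IdType;
        return CompactImpl<XPU, IdType>(mat, dim, leading_indices);
      } else if (mat->COOPtr()->indices.dtype() == torch::kInt64) {
        typedef int64_t IdType;
        return CompactImpl<XPU, IdType>(mat, dim, leading_indices);
      } else {
        LOG(FATAL) << "Operator " << "Compact" << " does not support "
                   << mat->COOPtr()->indices.dtype().name()
                   << " as ID dtype.";
      }
    } while (0);
  } else {
    LOG(FATAL) << "Operator " << "Compact" << " does not support "
               << c10::DeviceTypeName(mat->COOPtr()->indices.device().type())
               << " device.";
  }
} while (0);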
@@ -6,6 +6,9 @@
#include <sparse/matrix_ops.h>
#include <torch/script.h>
#include "./macro.h"
#include "./matrix_ops_impl.h"
namespace dgl {
namespace sparse {
@@ -55,5 +58,13 @@ std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
  return {ret_coo, lhs_indices, rhs_indices};
}
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(
    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
    torch::Tensor leading_indices) {
DGL_SPARSE_COO_SWITCH(mat->COOPtr(), XPU, IdType, "Compact", {
return CompactImpl<XPU, IdType>(mat, dim, leading_indices);
});
}
} // namespace sparse
} // namespace dgl
@@ -6,8 +6,22 @@
#ifndef DGL_SPARSE_MATRIX_OPS_IMPL_H_
#define DGL_SPARSE_MATRIX_OPS_IMPL_H_
#include <sparse/sparse_format.h>
#include <sparse/sparse_matrix.h>
#include <tuple>
namespace dgl {
namespace sparse {
template <c10::DeviceType XPU, typename IdType>
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> CompactImpl(
    const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
    torch::Tensor leading_indices) {
  // Placeholder only: returns the inputs unchanged until the actual
  // compaction kernel lands.
  return {mat, leading_indices};
}
} // namespace sparse
} // namespace dgl
#endif // DGL_SPARSE_MATRIX_OPS_IMPL_H_
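CompactImpl is deliberately left as a placeholder in this commit. As a hedged sketch of one way the relabeling step could later be written with stock ATen ops (editor's illustration, not the commit's implementation; `CompactRelabelSketch`, `dim_ids`, and `num_old` are invented names, and constructing the result SparseMatrix is elided because that part of the C++ API is not shown in this diff):

#include <torch/torch.h>
#include <tuple>

// dim_ids: the index along the compacted dimension of every non-zero.
// num_old: the original size of that dimension.
// Returns {relabeled non-zero indices, new->old index mapping}.
inline std::tuple<torch::Tensor, torch::Tensor> CompactRelabelSketch(
    torch::Tensor dim_ids, int64_t num_old, torch::Tensor leading_indices) {
  // Distinct ids that actually occur among the non-zeros (sorted).
  torch::Tensor uniq = std::get<0>(torch::_unique(dim_ids));
  if (leading_indices.defined() && leading_indices.numel() > 0) {
    // Drop ids already pinned to the front, then prepend leading_indices
    // (assumed to share dtype and device with dim_ids).
    torch::Tensor keep = torch::isin(
        uniq, leading_indices, /*assume_unique=*/false, /*invert=*/true);
    uniq = torch::cat({leading_indices, uniq.masked_select(keep)});
  }
  // Dense old->new lookup table: old_to_new[old_id] = new_id, -1 if removed.
  torch::Tensor old_to_new =
      torch::full({num_old}, -1, dim_ids.options());
  old_to_new.index_put_(
      {uniq.to(torch::kLong)},
      torch::arange(uniq.numel(), dim_ids.options()));
  // Relabel every non-zero and return it with the new->old mapping.
  return {old_to_new.index({dim_ids.to(torch::kLong)}), uniq};
}

The dense lookup table keeps the sketch simple and is fine for a CPU path; a hash-map relabeling would avoid the O(num_old) allocation when the dimension is much larger than the number of occupied ids.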