Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
cbc34705
Unverified
Commit
cbc34705
authored
Sep 18, 2023
by
xiangyuzhi
Committed by
GitHub
Sep 18, 2023
Browse files
[Sparse] Compact C++ API (#6334)
parent
d566ff8e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
114 additions
and
1 deletion
+114
-1
dgl_sparse/include/sparse/matrix_ops.h
dgl_sparse/include/sparse/matrix_ops.h
+23
-0
dgl_sparse/src/macro.h
dgl_sparse/src/macro.h
+65
-0
dgl_sparse/src/matrix_ops.cc
dgl_sparse/src/matrix_ops.cc
+11
-0
dgl_sparse/src/matrix_ops_impl.h
dgl_sparse/src/matrix_ops_impl.h
+15
-1
No files found.
dgl_sparse/include/sparse/matrix_ops.h
View file @
cbc34705
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#define SPARSE_MATRIX_OPS_H_
#define SPARSE_MATRIX_OPS_H_
#include <sparse/sparse_format.h>
#include <sparse/sparse_format.h>
#include <sparse/sparse_matrix.h>
#include <tuple>
#include <tuple>
...
@@ -26,6 +27,28 @@ namespace sparse {
...
@@ -26,6 +27,28 @@ namespace sparse {
std
::
tuple
<
std
::
shared_ptr
<
COO
>
,
torch
::
Tensor
,
torch
::
Tensor
>
COOIntersection
(
std
::
tuple
<
std
::
shared_ptr
<
COO
>
,
torch
::
Tensor
,
torch
::
Tensor
>
COOIntersection
(
const
std
::
shared_ptr
<
COO
>&
lhs
,
const
std
::
shared_ptr
<
COO
>&
rhs
);
const
std
::
shared_ptr
<
COO
>&
lhs
,
const
std
::
shared_ptr
<
COO
>&
rhs
);
/**
* @brief Compact sparse matrix by removing rows or columns without non-zero
* elements in the sparse matrix and relabeling indices of the dimension.
*
* This function serves a dual purpose: it allows you to reorganize the
* indices within a specific dimension (rows or columns) of the sparse matrix
* and, if needed, place certain 'leading_indices' at the beginning of the
* compact dimension.
*
* @param mat The sparse matrix to be compacted.
* @param dim The dimension to compact. Should be 0 or 1. Use 0 for row-wise
* compaction and 1 for column-wise compaction.
* @param leading_indices An optional tensor containing row or column ids that
* should be placed at the beginning of the compact dimension.
*
* @return A tuple containing the compacted sparse matrix and the index mapping
* of the compact dimension from the new index to the original index.
*/
std
::
tuple
<
c10
::
intrusive_ptr
<
SparseMatrix
>
,
torch
::
Tensor
>
Compact
(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
mat
,
int64_t
dim
,
torch
::
Tensor
leading_indices
);
}
// namespace sparse
}
// namespace sparse
}
// namespace dgl
}
// namespace dgl
...
...
dgl_sparse/src/macro.h
0 → 100644
View file @
cbc34705
/**
* Copyright (c) 2023 by Contributors
* @file macro.h
* @brief DGL C++ sparse API macros.
*/
#ifndef DGL_SPARSE_MACRO_H_
#define DGL_SPARSE_MACRO_H_
namespace
dgl
{
namespace
sparse
{
/**
* Dispatch an operator to a templated implementation function
* according to its device:
*
* DGL_SPARSE_XPU_SWITCH(tensor.device().type(), XPU, {
* // Now XPU is a placeholder for tensor.device().type()
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#define DGL_SPARSE_XPU_SWITCH(device, XPU, op, ...) \
do { \
if ((device) == c10::DeviceType::CPU) { \
constexpr auto XPU = c10::DeviceType::CPU; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< c10::DeviceTypeName(device) << " device."; \
} \
} while (0)
/**
* Dispatch according to ID type (either int32 or int64):
*
* DGL_SPARSE_ID_TYPE_SWITCH(tensor.dtype(), IdType, {
* // Now IdType is the type corresponding to data type of the tensor.
* // For instance, one can do this for a CPU array:
* IdType *data = static_cast<IdType *>(array.data_ptr());
* });
*/
#define DGL_SPARSE_ID_TYPE_SWITCH(dtype, IdType, op, ...) \
do { \
if ((dtype) == torch::kInt32) { \
typedef int32_t IdType; \
{ __VA_ARGS__ } \
} else if ((dtype) == torch::kInt64) { \
typedef int64_t IdType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< (dtype).name() << " as ID dtype."; \
} \
} while (0)
// Macro to dispatch according to device and index type.
#define DGL_SPARSE_COO_SWITCH(coo, XPU, IdType, op, ...) \
DGL_SPARSE_XPU_SWITCH(coo->indices.device().type(), XPU, op, { \
DGL_SPARSE_ID_TYPE_SWITCH( \
(coo)->indices.dtype(), IdType, op, {{__VA_ARGS__}}); \
});
}
// namespace sparse
}
// namespace dgl
#endif // DGL_SPARSE_MACRO_H_
dgl_sparse/src/matrix_ops.cc
View file @
cbc34705
...
@@ -6,6 +6,9 @@
...
@@ -6,6 +6,9 @@
#include <sparse/matrix_ops.h>
#include <sparse/matrix_ops.h>
#include <torch/script.h>
#include <torch/script.h>
#include "./macro.h"
#include "./matrix_ops_impl.h"
namespace
dgl
{
namespace
dgl
{
namespace
sparse
{
namespace
sparse
{
...
@@ -55,5 +58,13 @@ std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
...
@@ -55,5 +58,13 @@ std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
return
{
ret_coo
,
lhs_indices
,
rhs_indices
};
return
{
ret_coo
,
lhs_indices
,
rhs_indices
};
}
}
std
::
tuple
<
c10
::
intrusive_ptr
<
SparseMatrix
>
,
torch
::
Tensor
>
Compact
(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
mat
,
uint64_t
dim
,
torch
::
Tensor
leading_indices
)
{
DGL_SPARSE_COO_SWITCH
(
mat
->
COOPtr
(),
XPU
,
IdType
,
"Compact"
,
{
return
CompactImpl
<
XPU
,
IdType
>
(
mat
,
dim
,
leading_indices
);
});
}
}
// namespace sparse
}
// namespace sparse
}
// namespace dgl
}
// namespace dgl
dgl_sparse/src/matrix_ops_impl.h
View file @
cbc34705
...
@@ -6,8 +6,22 @@
...
@@ -6,8 +6,22 @@
#ifndef DGL_SPARSE_MATRIX_OPS_IMPL_H_
#ifndef DGL_SPARSE_MATRIX_OPS_IMPL_H_
#define DGL_SPARSE_MATRIX_OPS_IMPL_H_
#define DGL_SPARSE_MATRIX_OPS_IMPL_H_
#include <sparse/sparse_format.h>
#include <tuple>
namespace
dgl
{
namespace
dgl
{
namespace
sparse
{}
namespace
sparse
{
template
<
c10
::
DeviceType
XPU
,
typename
IdType
>
std
::
tuple
<
c10
::
intrusive_ptr
<
SparseMatrix
>
,
torch
::
Tensor
>
CompactImpl
(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
mat
,
int64_t
dim
,
torch
::
Tensor
leading_indices
)
{
// Place holder only.
return
{
mat
,
leading_indices
};
}
}
// namespace sparse
}
// namespace dgl
}
// namespace dgl
#endif // DGL_SPARSE_MATRIX_OPS_IMPL_H_
#endif // DGL_SPARSE_MATRIX_OPS_IMPL_H_
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment