Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
11c866ab
Unverified
Commit
11c866ab
authored
Nov 08, 2022
by
Hongzhi (Steve), Chen
Committed by
GitHub
Nov 08, 2022
Browse files
fix (#4841)
Co-authored-by:
Steve
<
ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal
>
parent
0d687968
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
41 additions
and
41 deletions
+41
-41
dgl_sparse/include/sparse/elementwise_op.h
dgl_sparse/include/sparse/elementwise_op.h
+6
-6
dgl_sparse/include/sparse/sparse_matrix.h
dgl_sparse/include/sparse/sparse_matrix.h
+29
-29
dgl_sparse/src/elemenwise_op.cc
dgl_sparse/src/elemenwise_op.cc
+1
-1
dgl_sparse/src/python_binding.cc
dgl_sparse/src/python_binding.cc
+1
-1
dgl_sparse/src/sparse_matrix.cc
dgl_sparse/src/sparse_matrix.cc
+1
-1
dgl_sparse/src/utils.h
dgl_sparse/src/utils.h
+3
-3
No files found.
dgl_sparse/include/sparse/elementwise_op.h
View file @
11c866ab
/*
!
/*
*
* Copyright (c) 2022 by Contributors
* Copyright (c) 2022 by Contributors
*
\
file sparse/elementwise_op.h
*
@
file sparse/elementwise_op.h
*
\
brief DGL C++ sparse elementwise operators
*
@
brief DGL C++ sparse elementwise operators
*/
*/
#ifndef SPARSE_ELEMENTWISE_OP_H_
#ifndef SPARSE_ELEMENTWISE_OP_H_
#define SPARSE_ELEMENTWISE_OP_H_
#define SPARSE_ELEMENTWISE_OP_H_
...
@@ -13,13 +13,13 @@ namespace dgl {
...
@@ -13,13 +13,13 @@ namespace dgl {
namespace
sparse
{
namespace
sparse
{
// TODO(zhenkun): support addition of matrices with different sparsity.
// TODO(zhenkun): support addition of matrices with different sparsity.
/*
!
/*
*
* @brief Adds two sparse matrices. Currently does not support two matrices with
* @brief Adds two sparse matrices. Currently does not support two matrices with
* different sparsity.
* different sparsity.
*
*
* @param A SparseMatrix
* @param A SparseMatrix
* @param B SparseMatrix
* @param B SparseMatrix
*
*
* @return SparseMatrix
* @return SparseMatrix
*/
*/
c10
::
intrusive_ptr
<
SparseMatrix
>
SpSpAdd
(
c10
::
intrusive_ptr
<
SparseMatrix
>
SpSpAdd
(
...
...
dgl_sparse/include/sparse/sparse_matrix.h
View file @
11c866ab
/*
!
/*
*
* Copyright (c) 2022 by Contributors
* Copyright (c) 2022 by Contributors
* @file sparse/sparse_matrix.h
* @file sparse/sparse_matrix.h
* @brief DGL C++ sparse matrix header
* @brief DGL C++ sparse matrix header
...
@@ -15,10 +15,10 @@
...
@@ -15,10 +15,10 @@
namespace
dgl
{
namespace
dgl
{
namespace
sparse
{
namespace
sparse
{
/*
!
@brief SparseFormat enumeration */
/*
*
@brief SparseFormat enumeration */
enum
SparseFormat
{
kCOO
,
kCSR
,
kCSC
};
enum
SparseFormat
{
kCOO
,
kCSR
,
kCSC
};
/*
!
@brief CSR sparse structure */
/*
*
@brief CSR sparse structure */
struct
CSR
{
struct
CSR
{
// CSR format index pointer array of the matrix
// CSR format index pointer array of the matrix
torch
::
Tensor
indptr
;
torch
::
Tensor
indptr
;
...
@@ -34,7 +34,7 @@ struct CSR {
...
@@ -34,7 +34,7 @@ struct CSR {
torch
::
optional
<
torch
::
Tensor
>
value_indices
;
torch
::
optional
<
torch
::
Tensor
>
value_indices
;
};
};
/*
!
@brief COO sparse structure */
/*
*
@brief COO sparse structure */
struct
COO
{
struct
COO
{
// COO format row array of the matrix
// COO format row array of the matrix
torch
::
Tensor
row
;
torch
::
Tensor
row
;
...
@@ -42,10 +42,10 @@ struct COO {
...
@@ -42,10 +42,10 @@ struct COO {
torch
::
Tensor
col
;
torch
::
Tensor
col
;
};
};
/*
!
@brief SparseMatrix bound to Python */
/*
*
@brief SparseMatrix bound to Python */
class
SparseMatrix
:
public
torch
::
CustomClassHolder
{
class
SparseMatrix
:
public
torch
::
CustomClassHolder
{
public:
public:
/*
!
/*
*
* @brief General constructor to construct a sparse matrix for different
* @brief General constructor to construct a sparse matrix for different
* sparse formats. At least one of the sparse formats should be provided,
* sparse formats. At least one of the sparse formats should be provided,
* while others could be nullptrs.
* while others could be nullptrs.
...
@@ -61,7 +61,7 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -61,7 +61,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/*
!
/*
*
* @brief Construct a SparseMatrix from a COO format.
* @brief Construct a SparseMatrix from a COO format.
* @param coo The COO format
* @param coo The COO format
* @param value Values of the sparse matrix
* @param value Values of the sparse matrix
...
@@ -73,7 +73,7 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -73,7 +73,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const
std
::
shared_ptr
<
COO
>&
coo
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
COO
>&
coo
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/*
!
/*
*
* @brief Construct a SparseMatrix from a CSR format.
* @brief Construct a SparseMatrix from a CSR format.
* @param csr The CSR format
* @param csr The CSR format
* @param value Values of the sparse matrix
* @param value Values of the sparse matrix
...
@@ -85,7 +85,7 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -85,7 +85,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const
std
::
shared_ptr
<
CSR
>&
csr
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csr
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/*
!
/*
*
* @brief Construct a SparseMatrix from a CSC format.
* @brief Construct a SparseMatrix from a CSC format.
* @param csc The CSC format
* @param csc The CSC format
* @param value Values of the sparse matrix
* @param value Values of the sparse matrix
...
@@ -97,44 +97,44 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -97,44 +97,44 @@ class SparseMatrix : public torch::CustomClassHolder {
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/*
!
@return Value of the sparse matrix. */
/*
*
@return Value of the sparse matrix. */
inline
torch
::
Tensor
value
()
const
{
return
value_
;
}
inline
torch
::
Tensor
value
()
const
{
return
value_
;
}
/*
!
@return Shape of the sparse matrix. */
/*
*
@return Shape of the sparse matrix. */
inline
const
std
::
vector
<
int64_t
>&
shape
()
const
{
return
shape_
;
}
inline
const
std
::
vector
<
int64_t
>&
shape
()
const
{
return
shape_
;
}
/*
!
@return Number of non-zero values */
/*
*
@return Number of non-zero values */
inline
int64_t
nnz
()
const
{
return
value_
.
size
(
0
);
}
inline
int64_t
nnz
()
const
{
return
value_
.
size
(
0
);
}
/*
!
@return Non-zero value data type */
/*
*
@return Non-zero value data type */
inline
caffe2
::
TypeMeta
dtype
()
const
{
return
value_
.
dtype
();
}
inline
caffe2
::
TypeMeta
dtype
()
const
{
return
value_
.
dtype
();
}
/*
!
@return Device of the sparse matrix */
/*
*
@return Device of the sparse matrix */
inline
torch
::
Device
device
()
const
{
return
value_
.
device
();
}
inline
torch
::
Device
device
()
const
{
return
value_
.
device
();
}
/*
!
@return COO of the sparse matrix. The COO is created if not exists. */
/*
*
@return COO of the sparse matrix. The COO is created if not exists. */
std
::
shared_ptr
<
COO
>
COOPtr
();
std
::
shared_ptr
<
COO
>
COOPtr
();
/*
!
@return CSR of the sparse matrix. The CSR is created if not exists. */
/*
*
@return CSR of the sparse matrix. The CSR is created if not exists. */
std
::
shared_ptr
<
CSR
>
CSRPtr
();
std
::
shared_ptr
<
CSR
>
CSRPtr
();
/*
!
@return CSC of the sparse matrix. The CSC is created if not exists. */
/*
*
@return CSC of the sparse matrix. The CSC is created if not exists. */
std
::
shared_ptr
<
CSR
>
CSCPtr
();
std
::
shared_ptr
<
CSR
>
CSCPtr
();
/*
!
@brief Check whether this sparse matrix has COO format. */
/*
*
@brief Check whether this sparse matrix has COO format. */
inline
bool
HasCOO
()
const
{
return
coo_
!=
nullptr
;
}
inline
bool
HasCOO
()
const
{
return
coo_
!=
nullptr
;
}
/*
!
@brief Check whether this sparse matrix has CSR format. */
/*
*
@brief Check whether this sparse matrix has CSR format. */
inline
bool
HasCSR
()
const
{
return
csr_
!=
nullptr
;
}
inline
bool
HasCSR
()
const
{
return
csr_
!=
nullptr
;
}
/*
!
@brief Check whether this sparse matrix has CSC format. */
/*
*
@brief Check whether this sparse matrix has CSC format. */
inline
bool
HasCSC
()
const
{
return
csc_
!=
nullptr
;
}
inline
bool
HasCSC
()
const
{
return
csc_
!=
nullptr
;
}
/*
!
@return {row, col, value} tensors in the COO format. */
/*
*
@return {row, col, value} tensors in the COO format. */
std
::
vector
<
torch
::
Tensor
>
COOTensors
();
std
::
vector
<
torch
::
Tensor
>
COOTensors
();
/*
!
@return {row, col, value} tensors in the CSR format. */
/*
*
@return {row, col, value} tensors in the CSR format. */
std
::
vector
<
torch
::
Tensor
>
CSRTensors
();
std
::
vector
<
torch
::
Tensor
>
CSRTensors
();
/*
!
@return {row, col, value} tensors in the CSC format. */
/*
*
@return {row, col, value} tensors in the CSC format. */
std
::
vector
<
torch
::
Tensor
>
CSCTensors
();
std
::
vector
<
torch
::
Tensor
>
CSCTensors
();
private:
private:
/*
!
@brief Create the COO format for the sparse matrix internally */
/*
*
@brief Create the COO format for the sparse matrix internally */
void
_CreateCOO
();
void
_CreateCOO
();
/*
!
@brief Create the CSR format for the sparse matrix internally */
/*
*
@brief Create the CSR format for the sparse matrix internally */
void
_CreateCSR
();
void
_CreateCSR
();
/*
!
@brief Create the CSC format for the sparse matrix internally */
/*
*
@brief Create the CSC format for the sparse matrix internally */
void
_CreateCSC
();
void
_CreateCSC
();
// COO/CSC/CSR pointers. Nullptr indicates non-existence.
// COO/CSC/CSR pointers. Nullptr indicates non-existence.
...
@@ -146,7 +146,7 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -146,7 +146,7 @@ class SparseMatrix : public torch::CustomClassHolder {
const
std
::
vector
<
int64_t
>
shape_
;
const
std
::
vector
<
int64_t
>
shape_
;
};
};
/*
!
/*
*
* @brief Create a SparseMatrix from tensors in COO format.
* @brief Create a SparseMatrix from tensors in COO format.
* @param row Row indices of the COO.
* @param row Row indices of the COO.
* @param col Column indices of the COO.
* @param col Column indices of the COO.
...
@@ -159,7 +159,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
...
@@ -159,7 +159,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
torch
::
Tensor
row
,
torch
::
Tensor
col
,
torch
::
Tensor
value
,
torch
::
Tensor
row
,
torch
::
Tensor
col
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/*
!
/*
*
* @brief Create a SparseMatrix from tensors in CSR format.
* @brief Create a SparseMatrix from tensors in CSR format.
* @param indptr Index pointer array of the CSR
* @param indptr Index pointer array of the CSR
* @param indices Indices array of the CSR
* @param indices Indices array of the CSR
...
@@ -172,7 +172,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
...
@@ -172,7 +172,7 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
value
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/*
!
/*
*
* @brief Create a SparseMatrix from tensors in CSC format.
* @brief Create a SparseMatrix from tensors in CSC format.
* @param indptr Index pointer array of the CSC
* @param indptr Index pointer array of the CSC
* @param indices Indices array of the CSC
* @param indices Indices array of the CSC
...
...
dgl_sparse/src/elemenwise_op.cc
View file @
11c866ab
/*
!
/*
*
* Copyright (c) 2022 by Contributors
* Copyright (c) 2022 by Contributors
* @file elementwise_op.cc
* @file elementwise_op.cc
* @brief DGL C++ sparse elementwise operator implementation
* @brief DGL C++ sparse elementwise operator implementation
...
...
dgl_sparse/src/python_binding.cc
View file @
11c866ab
/*
!
/*
*
* Copyright (c) 2022 by Contributors
* Copyright (c) 2022 by Contributors
* @file python_binding.cc
* @file python_binding.cc
* @brief DGL sparse library Python binding
* @brief DGL sparse library Python binding
...
...
dgl_sparse/src/sparse_matrix.cc
View file @
11c866ab
/*
!
/*
*
* Copyright (c) 2022 by Contributors
* Copyright (c) 2022 by Contributors
* @file sparse_matrix.cc
* @file sparse_matrix.cc
* @brief DGL C++ sparse matrix implementations
* @brief DGL C++ sparse matrix implementations
...
...
dgl_sparse/src/utils.h
View file @
11c866ab
/*
!
/*
*
* Copyright (c) 2022 by Contributors
* Copyright (c) 2022 by Contributors
* @file utils.h
* @file utils.h
* @brief DGL C++ sparse API utilities
* @brief DGL C++ sparse API utilities
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
namespace
dgl
{
namespace
dgl
{
namespace
sparse
{
namespace
sparse
{
/*
!
@brief Find a proper sparse format for two sparse matrices. It chooses
/*
*
@brief Find a proper sparse format for two sparse matrices. It chooses
* COO if anyone of the sparse matrices has COO format. If none of them has
* COO if anyone of the sparse matrices has COO format. If none of them has
* COO, it tries CSR and CSC in the same manner. */
* COO, it tries CSR and CSC in the same manner. */
inline
static
SparseFormat
FindAnyExistingFormat
(
inline
static
SparseFormat
FindAnyExistingFormat
(
...
@@ -29,7 +29,7 @@ inline static SparseFormat FindAnyExistingFormat(
...
@@ -29,7 +29,7 @@ inline static SparseFormat FindAnyExistingFormat(
return
fmt
;
return
fmt
;
}
}
/*
!
@brief Check whether two matrices has the same dtype and shape for
/*
*
@brief Check whether two matrices has the same dtype and shape for
* elementwise operators. */
* elementwise operators. */
inline
static
void
ElementwiseOpSanityCheck
(
inline
static
void
ElementwiseOpSanityCheck
(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
A
,
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
A
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment