Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a03dec05
Unverified
Commit
a03dec05
authored
Mar 09, 2023
by
czkkkkkk
Committed by
GitHub
Mar 09, 2023
Browse files
[Sparse] Support Diag sparse format in C++ (#5432)
* [Sparse] Support Diag sparse format in C++ * update * Update
parent
b7ce4b6a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
192 additions
and
15 deletions
+192
-15
dgl_sparse/include/sparse/sparse_format.h
dgl_sparse/include/sparse/sparse_format.h
+21
-1
dgl_sparse/include/sparse/sparse_matrix.h
dgl_sparse/include/sparse/sparse_matrix.h
+33
-3
dgl_sparse/src/elemenwise_op.cc
dgl_sparse/src/elemenwise_op.cc
+4
-0
dgl_sparse/src/python_binding.cc
dgl_sparse/src/python_binding.cc
+1
-0
dgl_sparse/src/sparse_format.cc
dgl_sparse/src/sparse_format.cc
+33
-0
dgl_sparse/src/sparse_matrix.cc
dgl_sparse/src/sparse_matrix.cc
+61
-11
dgl_sparse/src/spspmm.cc
dgl_sparse/src/spspmm.cc
+39
-0
No files found.
dgl_sparse/include/sparse/sparse_format.h
View file @
a03dec05
...
@@ -19,7 +19,7 @@ namespace dgl {
...
@@ -19,7 +19,7 @@ namespace dgl {
namespace
sparse
{
namespace
sparse
{
/** @brief SparseFormat enumeration. */
/** @brief SparseFormat enumeration. */
enum
SparseFormat
{
kCOO
,
kCSR
,
kCSC
};
enum
SparseFormat
{
kCOO
,
kCSR
,
kCSC
,
kDiag
};
/** @brief COO sparse structure. */
/** @brief COO sparse structure. */
struct
COO
{
struct
COO
{
...
@@ -50,6 +50,11 @@ struct CSR {
...
@@ -50,6 +50,11 @@ struct CSR {
bool
sorted
=
false
;
bool
sorted
=
false
;
};
};
struct
Diag
{
/** @brief The dense shape of the matrix. */
int64_t
num_rows
=
0
,
num_cols
=
0
;
};
/** @brief Convert an old DGL COO format to a COO in the sparse library. */
/** @brief Convert an old DGL COO format to a COO in the sparse library. */
std
::
shared_ptr
<
COO
>
COOFromOldDGLCOO
(
const
aten
::
COOMatrix
&
dgl_coo
);
std
::
shared_ptr
<
COO
>
COOFromOldDGLCOO
(
const
aten
::
COOMatrix
&
dgl_coo
);
...
@@ -90,6 +95,21 @@ std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo);
...
@@ -90,6 +95,21 @@ std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo);
/** @brief Convert a CSR format to CSC format. */
/** @brief Convert a CSR format to CSC format. */
std
::
shared_ptr
<
CSR
>
CSRToCSC
(
const
std
::
shared_ptr
<
CSR
>&
csr
);
std
::
shared_ptr
<
CSR
>
CSRToCSC
(
const
std
::
shared_ptr
<
CSR
>&
csr
);
/** @brief Convert a Diag format to COO format. */
std
::
shared_ptr
<
COO
>
DiagToCOO
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
c10
::
TensorOptions
&
indices_options
);
/** @brief Convert a Diag format to CSR format. */
std
::
shared_ptr
<
CSR
>
DiagToCSR
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
c10
::
TensorOptions
&
indices_options
);
/** @brief Convert a Diag format to CSC format. */
std
::
shared_ptr
<
CSR
>
DiagToCSC
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
c10
::
TensorOptions
&
indices_options
);
/** @brief COO transposition. */
/** @brief COO transposition. */
std
::
shared_ptr
<
COO
>
COOTranspose
(
const
std
::
shared_ptr
<
COO
>&
coo
);
std
::
shared_ptr
<
COO
>
COOTranspose
(
const
std
::
shared_ptr
<
COO
>&
coo
);
...
...
dgl_sparse/include/sparse/sparse_matrix.h
View file @
a03dec05
...
@@ -38,8 +38,8 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -38,8 +38,8 @@ class SparseMatrix : public torch::CustomClassHolder {
*/
*/
SparseMatrix
(
SparseMatrix
(
const
std
::
shared_ptr
<
COO
>&
coo
,
const
std
::
shared_ptr
<
CSR
>&
csr
,
const
std
::
shared_ptr
<
COO
>&
coo
,
const
std
::
shared_ptr
<
CSR
>&
csr
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
std
::
vector
<
int64_t
>&
shape
);
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
/**
/**
* @brief Construct a SparseMatrix from a COO format.
* @brief Construct a SparseMatrix from a COO format.
...
@@ -77,6 +77,18 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -77,6 +77,18 @@ class SparseMatrix : public torch::CustomClassHolder {
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/**
* @brief Construct a SparseMatrix from a Diag format.
* @param diag The Diag format
* @param value Values of the sparse matrix
* @param shape Shape of the sparse matrix
*
* @return SparseMatrix
*/
static
c10
::
intrusive_ptr
<
SparseMatrix
>
FromDiagPointer
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
/**
/**
* @brief Create a SparseMatrix from tensors in COO format.
* @brief Create a SparseMatrix from tensors in COO format.
* @param indices COO coordinates with shape (2, nnz).
* @param indices COO coordinates with shape (2, nnz).
...
@@ -115,6 +127,16 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -115,6 +127,16 @@ class SparseMatrix : public torch::CustomClassHolder {
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
value
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
const
std
::
vector
<
int64_t
>&
shape
);
/**
* @brief Create a SparseMatrix with Diag format.
* @param value Values of the sparse matrix
* @param shape Shape of the sparse matrix
*
* @return SparseMatrix
*/
static
c10
::
intrusive_ptr
<
SparseMatrix
>
FromDiag
(
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
);
/**
/**
* @brief Create a SparseMatrix from a SparseMatrix using new values.
* @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix
* @param mat An existing sparse matrix
...
@@ -142,6 +164,11 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -142,6 +164,11 @@ class SparseMatrix : public torch::CustomClassHolder {
std
::
shared_ptr
<
CSR
>
CSRPtr
();
std
::
shared_ptr
<
CSR
>
CSRPtr
();
/** @return CSC of the sparse matrix. The CSC is created if not exists. */
/** @return CSC of the sparse matrix. The CSC is created if not exists. */
std
::
shared_ptr
<
CSR
>
CSCPtr
();
std
::
shared_ptr
<
CSR
>
CSCPtr
();
/**
* @return Diagonal format of the sparse matrix. An error will be raised if
* it does not have a diagonal format.
*/
std
::
shared_ptr
<
Diag
>
DiagPtr
();
/** @brief Check whether this sparse matrix has COO format. */
/** @brief Check whether this sparse matrix has COO format. */
inline
bool
HasCOO
()
const
{
return
coo_
!=
nullptr
;
}
inline
bool
HasCOO
()
const
{
return
coo_
!=
nullptr
;
}
...
@@ -149,6 +176,8 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -149,6 +176,8 @@ class SparseMatrix : public torch::CustomClassHolder {
inline
bool
HasCSR
()
const
{
return
csr_
!=
nullptr
;
}
inline
bool
HasCSR
()
const
{
return
csr_
!=
nullptr
;
}
/** @brief Check whether this sparse matrix has CSC format. */
/** @brief Check whether this sparse matrix has CSC format. */
inline
bool
HasCSC
()
const
{
return
csc_
!=
nullptr
;
}
inline
bool
HasCSC
()
const
{
return
csc_
!=
nullptr
;
}
/** @brief Check whether this sparse matrix has Diag format. */
inline
bool
HasDiag
()
const
{
return
diag_
!=
nullptr
;
}
/** @return {row, col} tensors in the COO format. */
/** @return {row, col} tensors in the COO format. */
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
COOTensors
();
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
COOTensors
();
...
@@ -191,9 +220,10 @@ class SparseMatrix : public torch::CustomClassHolder {
...
@@ -191,9 +220,10 @@ class SparseMatrix : public torch::CustomClassHolder {
/** @brief Create the CSC format for the sparse matrix internally */
/** @brief Create the CSC format for the sparse matrix internally */
void
_CreateCSC
();
void
_CreateCSC
();
// COO/CSC/CSR pointers. Nullptr indicates non-existence.
// COO/CSC/CSR
/Diag
pointers. Nullptr indicates non-existence.
std
::
shared_ptr
<
COO
>
coo_
;
std
::
shared_ptr
<
COO
>
coo_
;
std
::
shared_ptr
<
CSR
>
csr_
,
csc_
;
std
::
shared_ptr
<
CSR
>
csr_
,
csc_
;
std
::
shared_ptr
<
Diag
>
diag_
;
// Value of the SparseMatrix
// Value of the SparseMatrix
torch
::
Tensor
value_
;
torch
::
Tensor
value_
;
// Shape of the SparseMatrix
// Shape of the SparseMatrix
...
...
dgl_sparse/src/elemenwise_op.cc
View file @
a03dec05
...
@@ -22,6 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
...
@@ -22,6 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
A
,
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
A
,
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
B
)
{
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
B
)
{
ElementwiseOpSanityCheck
(
A
,
B
);
ElementwiseOpSanityCheck
(
A
,
B
);
if
(
A
->
HasDiag
()
&&
B
->
HasDiag
())
{
return
SparseMatrix
::
FromDiagPointer
(
A
->
DiagPtr
(),
A
->
value
()
+
B
->
value
(),
A
->
shape
());
}
auto
torch_A
=
COOToTorchCOO
(
A
->
COOPtr
(),
A
->
value
());
auto
torch_A
=
COOToTorchCOO
(
A
->
COOPtr
(),
A
->
value
());
auto
torch_B
=
COOToTorchCOO
(
B
->
COOPtr
(),
B
->
value
());
auto
torch_B
=
COOToTorchCOO
(
B
->
COOPtr
(),
B
->
value
());
auto
sum
=
(
torch_A
+
torch_B
).
coalesce
();
auto
sum
=
(
torch_A
+
torch_B
).
coalesce
();
...
...
dgl_sparse/src/python_binding.cc
View file @
a03dec05
...
@@ -36,6 +36,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
...
@@ -36,6 +36,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
m
.
def
(
"from_coo"
,
&
SparseMatrix
::
FromCOO
)
m
.
def
(
"from_coo"
,
&
SparseMatrix
::
FromCOO
)
.
def
(
"from_csr"
,
&
SparseMatrix
::
FromCSR
)
.
def
(
"from_csr"
,
&
SparseMatrix
::
FromCSR
)
.
def
(
"from_csc"
,
&
SparseMatrix
::
FromCSC
)
.
def
(
"from_csc"
,
&
SparseMatrix
::
FromCSC
)
.
def
(
"from_diag"
,
&
SparseMatrix
::
FromDiag
)
.
def
(
"spsp_add"
,
&
SpSpAdd
)
.
def
(
"spsp_add"
,
&
SpSpAdd
)
.
def
(
"reduce"
,
&
Reduce
)
.
def
(
"reduce"
,
&
Reduce
)
.
def
(
"sum"
,
&
ReduceSum
)
.
def
(
"sum"
,
&
ReduceSum
)
...
...
dgl_sparse/src/sparse_format.cc
View file @
a03dec05
...
@@ -99,6 +99,39 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) {
...
@@ -99,6 +99,39 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) {
return
CSRFromOldDGLCSR
(
dgl_csc
);
return
CSRFromOldDGLCSR
(
dgl_csc
);
}
}
std
::
shared_ptr
<
COO
>
DiagToCOO
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
c10
::
TensorOptions
&
indices_options
)
{
int64_t
nnz
=
std
::
min
(
diag
->
num_rows
,
diag
->
num_cols
);
auto
indices
=
torch
::
arange
(
nnz
,
indices_options
).
repeat
({
2
,
1
});
return
std
::
make_shared
<
COO
>
(
COO
{
diag
->
num_rows
,
diag
->
num_cols
,
indices
,
true
,
true
});
}
std
::
shared_ptr
<
CSR
>
DiagToCSR
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
c10
::
TensorOptions
&
indices_options
)
{
int64_t
nnz
=
std
::
min
(
diag
->
num_rows
,
diag
->
num_cols
);
auto
indptr
=
torch
::
full
(
diag
->
num_rows
+
1
,
nnz
,
indices_options
);
torch
::
arange_out
(
indptr
,
nnz
+
1
);
auto
indices
=
torch
::
arange
(
nnz
,
indices_options
);
return
std
::
make_shared
<
CSR
>
(
CSR
{
diag
->
num_rows
,
diag
->
num_cols
,
indptr
,
indices
,
torch
::
optional
<
torch
::
Tensor
>
(),
true
});
}
std
::
shared_ptr
<
CSR
>
DiagToCSC
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
c10
::
TensorOptions
&
indices_options
)
{
int64_t
nnz
=
std
::
min
(
diag
->
num_rows
,
diag
->
num_cols
);
auto
indptr
=
torch
::
full
(
diag
->
num_cols
+
1
,
nnz
,
indices_options
);
torch
::
arange_out
(
indptr
,
nnz
+
1
);
auto
indices
=
torch
::
arange
(
nnz
,
indices_options
);
return
std
::
make_shared
<
CSR
>
(
CSR
{
diag
->
num_cols
,
diag
->
num_rows
,
indptr
,
indices
,
torch
::
optional
<
torch
::
Tensor
>
(),
true
});
}
std
::
shared_ptr
<
COO
>
COOTranspose
(
const
std
::
shared_ptr
<
COO
>&
coo
)
{
std
::
shared_ptr
<
COO
>
COOTranspose
(
const
std
::
shared_ptr
<
COO
>&
coo
)
{
auto
dgl_coo
=
COOToOldDGLCOO
(
coo
);
auto
dgl_coo
=
COOToOldDGLCOO
(
coo
);
auto
dgl_coo_tr
=
aten
::
COOTranspose
(
dgl_coo
);
auto
dgl_coo_tr
=
aten
::
COOTranspose
(
dgl_coo
);
...
...
dgl_sparse/src/sparse_matrix.cc
View file @
a03dec05
...
@@ -17,12 +17,18 @@ namespace sparse {
...
@@ -17,12 +17,18 @@ namespace sparse {
SparseMatrix
::
SparseMatrix
(
SparseMatrix
::
SparseMatrix
(
const
std
::
shared_ptr
<
COO
>&
coo
,
const
std
::
shared_ptr
<
CSR
>&
csr
,
const
std
::
shared_ptr
<
COO
>&
coo
,
const
std
::
shared_ptr
<
CSR
>&
csr
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
const
std
::
shared_ptr
<
Diag
>&
diag
,
const
std
::
vector
<
int64_t
>&
shape
)
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
)
:
coo_
(
coo
),
csr_
(
csr
),
csc_
(
csc
),
value_
(
value
),
shape_
(
shape
)
{
:
coo_
(
coo
),
csr_
(
csr
),
csc_
(
csc
),
diag_
(
diag
),
value_
(
value
),
shape_
(
shape
)
{
TORCH_CHECK
(
TORCH_CHECK
(
coo
!=
nullptr
||
csr
!=
nullptr
||
csc
!=
nullptr
,
"At least "
,
coo
!=
nullptr
||
csr
!=
nullptr
||
csc
!=
nullptr
||
diag
!=
nullptr
,
"one of CSR/COO/CSC is required to construct a SparseMatrix."
)
"At least one of CSR/COO/CSC/Diag is required to construct a "
"SparseMatrix."
)
TORCH_CHECK
(
TORCH_CHECK
(
shape
.
size
()
==
2
,
"The shape of a sparse matrix should be "
,
shape
.
size
()
==
2
,
"The shape of a sparse matrix should be "
,
"2-dimensional."
);
"2-dimensional."
);
...
@@ -51,24 +57,37 @@ SparseMatrix::SparseMatrix(
...
@@ -51,24 +57,37 @@ SparseMatrix::SparseMatrix(
TORCH_CHECK
(
csc
->
indptr
.
device
()
==
value
.
device
());
TORCH_CHECK
(
csc
->
indptr
.
device
()
==
value
.
device
());
TORCH_CHECK
(
csc
->
indices
.
device
()
==
value
.
device
());
TORCH_CHECK
(
csc
->
indices
.
device
()
==
value
.
device
());
}
}
if
(
diag
!=
nullptr
)
{
TORCH_CHECK
(
value
.
size
(
0
)
==
std
::
min
(
diag
->
num_rows
,
diag
->
num_cols
));
}
}
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCOOPointer
(
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCOOPointer
(
const
std
::
shared_ptr
<
COO
>&
coo
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
COO
>&
coo
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
)
{
const
std
::
vector
<
int64_t
>&
shape
)
{
return
c10
::
make_intrusive
<
SparseMatrix
>
(
coo
,
nullptr
,
nullptr
,
value
,
shape
);
return
c10
::
make_intrusive
<
SparseMatrix
>
(
coo
,
nullptr
,
nullptr
,
nullptr
,
value
,
shape
);
}
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCSRPointer
(
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCSRPointer
(
const
std
::
shared_ptr
<
CSR
>&
csr
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csr
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
)
{
const
std
::
vector
<
int64_t
>&
shape
)
{
return
c10
::
make_intrusive
<
SparseMatrix
>
(
nullptr
,
csr
,
nullptr
,
value
,
shape
);
return
c10
::
make_intrusive
<
SparseMatrix
>
(
nullptr
,
csr
,
nullptr
,
nullptr
,
value
,
shape
);
}
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCSCPointer
(
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCSCPointer
(
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
shared_ptr
<
CSR
>&
csc
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
)
{
const
std
::
vector
<
int64_t
>&
shape
)
{
return
c10
::
make_intrusive
<
SparseMatrix
>
(
nullptr
,
nullptr
,
csc
,
value
,
shape
);
return
c10
::
make_intrusive
<
SparseMatrix
>
(
nullptr
,
nullptr
,
csc
,
nullptr
,
value
,
shape
);
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromDiagPointer
(
const
std
::
shared_ptr
<
Diag
>&
diag
,
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
)
{
return
c10
::
make_intrusive
<
SparseMatrix
>
(
nullptr
,
nullptr
,
nullptr
,
diag
,
value
,
shape
);
}
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCOO
(
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromCOO
(
...
@@ -97,6 +116,12 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
...
@@ -97,6 +116,12 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
return
SparseMatrix
::
FromCSCPointer
(
csc
,
value
,
shape
);
return
SparseMatrix
::
FromCSCPointer
(
csc
,
value
,
shape
);
}
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
FromDiag
(
torch
::
Tensor
value
,
const
std
::
vector
<
int64_t
>&
shape
)
{
auto
diag
=
std
::
make_shared
<
Diag
>
(
Diag
{
shape
[
0
],
shape
[
1
]});
return
SparseMatrix
::
FromDiagPointer
(
diag
,
value
,
shape
);
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
ValLike
(
c10
::
intrusive_ptr
<
SparseMatrix
>
SparseMatrix
::
ValLike
(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
mat
,
torch
::
Tensor
value
)
{
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
mat
,
torch
::
Tensor
value
)
{
TORCH_CHECK
(
TORCH_CHECK
(
...
@@ -136,6 +161,13 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
...
@@ -136,6 +161,13 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
return
csc_
;
return
csc_
;
}
}
std
::
shared_ptr
<
Diag
>
SparseMatrix
::
DiagPtr
()
{
TORCH_CHECK
(
diag_
!=
nullptr
,
"Cannot get Diag sparse format from a non-diagonal sparse matrix"
);
return
diag_
;
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
SparseMatrix
::
COOTensors
()
{
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
SparseMatrix
::
COOTensors
()
{
auto
coo
=
COOPtr
();
auto
coo
=
COOPtr
();
return
std
::
make_tuple
(
coo
->
indices
.
index
({
0
}),
coo
->
indices
.
index
({
1
}));
return
std
::
make_tuple
(
coo
->
indices
.
index
({
0
}),
coo
->
indices
.
index
({
1
}));
...
@@ -175,7 +207,13 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
...
@@ -175,7 +207,13 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {
void
SparseMatrix
::
_CreateCOO
()
{
void
SparseMatrix
::
_CreateCOO
()
{
if
(
HasCOO
())
return
;
if
(
HasCOO
())
return
;
if
(
HasCSR
())
{
if
(
HasDiag
())
{
auto
indices_options
=
torch
::
TensorOptions
()
.
dtype
(
torch
::
kInt64
)
.
layout
(
torch
::
kStrided
)
.
device
(
this
->
device
());
coo_
=
DiagToCOO
(
diag_
,
indices_options
);
}
else
if
(
HasCSR
())
{
coo_
=
CSRToCOO
(
csr_
);
coo_
=
CSRToCOO
(
csr_
);
}
else
if
(
HasCSC
())
{
}
else
if
(
HasCSC
())
{
coo_
=
CSCToCOO
(
csc_
);
coo_
=
CSCToCOO
(
csc_
);
...
@@ -186,7 +224,13 @@ void SparseMatrix::_CreateCOO() {
...
@@ -186,7 +224,13 @@ void SparseMatrix::_CreateCOO() {
void
SparseMatrix
::
_CreateCSR
()
{
void
SparseMatrix
::
_CreateCSR
()
{
if
(
HasCSR
())
return
;
if
(
HasCSR
())
return
;
if
(
HasCOO
())
{
if
(
HasDiag
())
{
auto
indices_options
=
torch
::
TensorOptions
()
.
dtype
(
torch
::
kInt64
)
.
layout
(
torch
::
kStrided
)
.
device
(
this
->
device
());
csr_
=
DiagToCSR
(
diag_
,
indices_options
);
}
else
if
(
HasCOO
())
{
csr_
=
COOToCSR
(
coo_
);
csr_
=
COOToCSR
(
coo_
);
}
else
if
(
HasCSC
())
{
}
else
if
(
HasCSC
())
{
csr_
=
CSCToCSR
(
csc_
);
csr_
=
CSCToCSR
(
csc_
);
...
@@ -197,7 +241,13 @@ void SparseMatrix::_CreateCSR() {
...
@@ -197,7 +241,13 @@ void SparseMatrix::_CreateCSR() {
void
SparseMatrix
::
_CreateCSC
()
{
void
SparseMatrix
::
_CreateCSC
()
{
if
(
HasCSC
())
return
;
if
(
HasCSC
())
return
;
if
(
HasCOO
())
{
if
(
HasDiag
())
{
auto
indices_options
=
torch
::
TensorOptions
()
.
dtype
(
torch
::
kInt64
)
.
layout
(
torch
::
kStrided
)
.
device
(
this
->
device
());
csc_
=
DiagToCSC
(
diag_
,
indices_options
);
}
else
if
(
HasCOO
())
{
csc_
=
COOToCSC
(
coo_
);
csc_
=
COOToCSC
(
coo_
);
}
else
if
(
HasCSR
())
{
}
else
if
(
HasCSR
())
{
csc_
=
CSRToCSC
(
csr_
);
csc_
=
CSRToCSC
(
csr_
);
...
...
dgl_sparse/src/spspmm.cc
View file @
a03dec05
...
@@ -116,10 +116,49 @@ tensor_list SpSpMMAutoGrad::backward(
...
@@ -116,10 +116,49 @@ tensor_list SpSpMMAutoGrad::backward(
return
{
torch
::
Tensor
(),
lhs_val_grad
,
torch
::
Tensor
(),
rhs_val_grad
};
return
{
torch
::
Tensor
(),
lhs_val_grad
,
torch
::
Tensor
(),
rhs_val_grad
};
}
}
c10
::
intrusive_ptr
<
SparseMatrix
>
DiagSpSpMM
(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
lhs_mat
,
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
rhs_mat
)
{
if
(
lhs_mat
->
HasDiag
()
&&
rhs_mat
->
HasDiag
())
{
// Diag @ Diag
const
int64_t
m
=
lhs_mat
->
shape
()[
0
];
const
int64_t
n
=
lhs_mat
->
shape
()[
1
];
const
int64_t
p
=
rhs_mat
->
shape
()[
1
];
const
int64_t
common_diag_len
=
std
::
min
({
m
,
n
,
p
});
const
int64_t
new_diag_len
=
std
::
min
(
m
,
p
);
auto
slice
=
torch
::
indexing
::
Slice
(
0
,
common_diag_len
);
auto
new_val
=
lhs_mat
->
value
().
index
({
slice
})
*
rhs_mat
->
value
().
index
({
slice
});
new_val
=
torch
::
constant_pad_nd
(
new_val
,
{
0
,
new_diag_len
-
common_diag_len
},
0
);
return
SparseMatrix
::
FromDiag
(
new_val
,
{
m
,
p
});
}
if
(
lhs_mat
->
HasDiag
()
&&
!
rhs_mat
->
HasDiag
())
{
// Diag @ Sparse
auto
row
=
rhs_mat
->
Indices
().
index
({
0
});
auto
val
=
lhs_mat
->
value
().
index_select
(
0
,
row
)
*
rhs_mat
->
value
();
return
SparseMatrix
::
ValLike
(
rhs_mat
,
val
);
}
if
(
!
lhs_mat
->
HasDiag
()
&&
rhs_mat
->
HasDiag
())
{
// Sparse @ Diag
auto
col
=
lhs_mat
->
Indices
().
index
({
1
});
auto
val
=
rhs_mat
->
value
().
index_select
(
0
,
col
)
*
lhs_mat
->
value
();
return
SparseMatrix
::
ValLike
(
lhs_mat
,
val
);
}
TORCH_CHECK
(
false
,
"For DiagSpSpMM, at least one of the sparse matries need to have kDiag "
"format"
);
return
c10
::
intrusive_ptr
<
SparseMatrix
>
();
}
c10
::
intrusive_ptr
<
SparseMatrix
>
SpSpMM
(
c10
::
intrusive_ptr
<
SparseMatrix
>
SpSpMM
(
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
lhs_mat
,
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
lhs_mat
,
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
rhs_mat
)
{
const
c10
::
intrusive_ptr
<
SparseMatrix
>&
rhs_mat
)
{
_SpSpMMSanityCheck
(
lhs_mat
,
rhs_mat
);
_SpSpMMSanityCheck
(
lhs_mat
,
rhs_mat
);
if
(
lhs_mat
->
HasDiag
()
||
rhs_mat
->
HasDiag
())
{
return
DiagSpSpMM
(
lhs_mat
,
rhs_mat
);
}
auto
results
=
SpSpMMAutoGrad
::
apply
(
auto
results
=
SpSpMMAutoGrad
::
apply
(
lhs_mat
,
lhs_mat
->
value
(),
rhs_mat
,
rhs_mat
->
value
());
lhs_mat
,
lhs_mat
->
value
(),
rhs_mat
,
rhs_mat
->
value
());
std
::
vector
<
int64_t
>
ret_shape
({
lhs_mat
->
shape
()[
0
],
rhs_mat
->
shape
()[
1
]});
std
::
vector
<
int64_t
>
ret_shape
({
lhs_mat
->
shape
()[
0
],
rhs_mat
->
shape
()[
1
]});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment