OpenDAS / dgl, commit 73a508e1 (unverified)
Authored Feb 23, 2023 by czkkkkkk; committed by GitHub on Feb 23, 2023
[Sparse] Stack SparseMatrix COO row and column coordinates into one tensor. (#5314)
Parent: 5ea04713
Showing 5 changed files with 22 additions and 24 deletions:

  dgl_sparse/include/sparse/sparse_format.h  +4 -4
  dgl_sparse/src/reduction.cc                +2 -2
  dgl_sparse/src/sparse_format.cc            +7 -7
  dgl_sparse/src/sparse_matrix.cc            +7 -9
  dgl_sparse/src/spspmm.cc                   +2 -2
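The commit message above summarizes the change: instead of carrying two length-nnz tensors, row and col, the COO structure now carries a single indices tensor of shape (2, nnz) built with torch::stack, and call sites recover the coordinates by indexing into it. Before the per-file hunks, here is a minimal standalone sketch of that packing and unpacking against plain libtorch; it is an illustration, not code from this repository.

#include <torch/torch.h>

int main() {
  // Row and column coordinates of three nonzeros in a 4 x 5 matrix.
  torch::Tensor row = torch::tensor({0, 1, 3}, torch::kInt64);
  torch::Tensor col = torch::tensor({2, 0, 4}, torch::kInt64);

  // New layout: one (2, nnz) tensor stacking the two coordinate arrays.
  torch::Tensor indices = torch::stack({row, col});
  TORCH_CHECK(indices.dim() == 2 && indices.size(0) == 2);

  // Recovering the per-axis coordinates, as the updated call sites do.
  torch::Tensor row_back = indices.index({0});
  torch::Tensor col_back = indices.index({1});
  TORCH_CHECK(row_back.equal(row) && col_back.equal(col));
  return 0;
}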
dgl_sparse/include/sparse/sparse_format.h

@@ -25,10 +25,10 @@ enum SparseFormat { kCOO, kCSR, kCSC };
 struct COO {
   /** @brief The shape of the matrix. */
   int64_t num_rows = 0, num_cols = 0;
-  /** @brief COO format row indices array of the matrix. */
-  torch::Tensor row;
-  /** @brief COO format column indices array of the matrix. */
-  torch::Tensor col;
+  /**
+   * @brief COO tensor of shape (2, nnz), stacking the row and column indices.
+   */
+  torch::Tensor indices;
   /** @brief Whether the row indices are sorted. */
   bool row_sorted = false;
   /** @brief Whether the column indices per row are sorted. */
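As a quick illustration of the new field's contract, the sketch below aggregate-initializes a struct shaped like the one above and reads nnz back from the second dimension. The col_sorted member is not visible in the captured hunk but is implied by the five-argument initializers used later in sparse_format.cc and sparse_matrix.cc, so treat the exact member list as an assumption.

#include <torch/torch.h>

// Abridged from the header above; col_sorted is inferred from the
// aggregate initializers elsewhere in this commit.
struct COO {
  int64_t num_rows = 0, num_cols = 0;
  torch::Tensor indices;  // shape (2, nnz): stacked row and column indices
  bool row_sorted = false;
  bool col_sorted = false;
};

int main() {
  torch::Tensor row = torch::tensor({0, 0, 2}, torch::kInt64);
  torch::Tensor col = torch::tensor({1, 3, 2}, torch::kInt64);
  COO coo{3, 4, torch::stack({row, col}), false, false};

  int64_t nnz = coo.indices.size(1);  // number of stored elements
  TORCH_CHECK(nnz == 3);
  return 0;
}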
dgl_sparse/src/reduction.cc

@@ -51,10 +51,10 @@ torch::Tensor ReduceAlong(
   torch::Tensor idx;
   if (dim == 0) {
     output_shape[0] = coo->num_cols;
-    idx = coo->col.view(view_dims).expand_as(value);
+    idx = coo->indices.index({1}).view(view_dims).expand_as(value);
   } else if (dim == 1) {
     output_shape[0] = coo->num_rows;
-    idx = coo->row.view(view_dims).expand_as(value);
+    idx = coo->indices.index({0}).view(view_dims).expand_as(value);
   }
   torch::Tensor out = torch::zeros(output_shape, value.options());
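The hunk only changes where idx comes from: the coordinates are now read out of coo->indices instead of the dedicated coo->row / coo->col tensors, while the view/expand_as chain and the zero-initialized output stay the same. The sketch below reproduces that indexing pattern standalone and finishes it with a scatter_add_, which is one plausible reduction; the actual ReduceAlong body beyond the lines shown in this hunk is not part of the diff.

#include <torch/torch.h>
#include <vector>

int main() {
  // A 3 x 4 sparse matrix with 4 nonzeros, each carrying a length-2 value.
  torch::Tensor indices = torch::tensor({{0, 0, 1, 2},    // row coordinates
                                         {1, 3, 0, 2}});  // column coordinates
  torch::Tensor value = torch::rand({4, 2});

  // Mirrors the dim == 1 branch: reduce along columns, one output row per
  // matrix row, so idx comes from indices.index({0}).
  std::vector<int64_t> view_dims = {4, 1};
  torch::Tensor idx = indices.index({0}).view(view_dims).expand_as(value);

  torch::Tensor out = torch::zeros({3, 2}, value.options());
  // Hypothetical final step: sum the values of each row into `out`.
  out.scatter_add_(0, idx, value);
  return 0;
}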
dgl_sparse/src/sparse_format.cc

@@ -18,14 +18,15 @@ std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo) {
   auto row = DGLArrayToTorchTensor(dgl_coo.row);
   auto col = DGLArrayToTorchTensor(dgl_coo.col);
   TORCH_CHECK(aten::IsNullArray(dgl_coo.data));
+  auto indices = torch::stack({row, col});
   return std::make_shared<COO>(
-      COO{dgl_coo.num_rows, dgl_coo.num_cols, row, col, dgl_coo.row_sorted,
+      COO{dgl_coo.num_rows, dgl_coo.num_cols, indices, dgl_coo.row_sorted,
           dgl_coo.col_sorted});
 }

 aten::COOMatrix COOToOldDGLCOO(const std::shared_ptr<COO>& coo) {
-  auto row = TorchTensorToDGLArray(coo->row);
-  auto col = TorchTensorToDGLArray(coo->col);
+  auto row = TorchTensorToDGLArray(coo->indices.index({0}));
+  auto col = TorchTensorToDGLArray(coo->indices.index({1}));
   return aten::COOMatrix(
       coo->num_rows, coo->num_cols, row, col, aten::NullArray(),
       coo->row_sorted, coo->col_sorted);

@@ -50,14 +51,13 @@ aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr) {
 torch::Tensor COOToTorchCOO(
     const std::shared_ptr<COO>& coo, torch::Tensor value) {
-  std::vector<torch::Tensor> indices = {coo->row, coo->col};
+  torch::Tensor indices = coo->indices;
   if (value.ndimension() == 2) {
     return torch::sparse_coo_tensor(
-        torch::stack(indices), value,
-        {coo->num_rows, coo->num_cols, value.size(1)});
+        indices, value, {coo->num_rows, coo->num_cols, value.size(1)});
   } else {
     return torch::sparse_coo_tensor(
-        torch::stack(indices), value, {coo->num_rows, coo->num_cols});
+        indices, value, {coo->num_rows, coo->num_cols});
   }
 }
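A side effect visible in COOToTorchCOO above: the stored (2, nnz) tensor can now be handed to torch::sparse_coo_tensor directly, where the old code had to torch::stack({coo->row, coo->col}) at every conversion. A minimal sketch of that factory call with made-up data:

#include <torch/torch.h>

int main() {
  // A stored (2, nnz) indices tensor, as kept by the updated COO struct.
  torch::Tensor indices = torch::tensor({{0, 1, 3}, {2, 0, 4}});
  torch::Tensor value = torch::rand({3});

  // The indices tensor already has the layout the sparse COO factory expects.
  torch::Tensor sp = torch::sparse_coo_tensor(indices, value, {4, 5});
  TORCH_CHECK(sp.is_sparse());
  return 0;
}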
dgl_sparse/src/sparse_matrix.cc

@@ -30,12 +30,10 @@ SparseMatrix::SparseMatrix(
   // device. Do we allow the graph structure and values are on different
   // devices?
   if (coo != nullptr) {
-    TORCH_CHECK(coo->row.dim() == 1);
-    TORCH_CHECK(coo->col.dim() == 1);
-    TORCH_CHECK(coo->row.size(0) == coo->col.size(0));
-    TORCH_CHECK(coo->row.size(0) == value.size(0));
-    TORCH_CHECK(coo->row.device() == value.device());
-    TORCH_CHECK(coo->col.device() == value.device());
+    TORCH_CHECK(coo->indices.dim() == 2);
+    TORCH_CHECK(coo->indices.size(0) == 2);
+    TORCH_CHECK(coo->indices.size(1) == value.size(0));
+    TORCH_CHECK(coo->indices.device() == value.device());
   }
   if (csr != nullptr) {
     TORCH_CHECK(csr->indptr.dim() == 1);

@@ -76,8 +74,8 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
 c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
     torch::Tensor row, torch::Tensor col, torch::Tensor value,
     const std::vector<int64_t>& shape) {
-  auto coo =
-      std::make_shared<COO>(COO{shape[0], shape[1], row, col, false, false});
+  auto coo = std::make_shared<COO>(
+      COO{shape[0], shape[1], torch::stack({row, col}), false, false});
   return SparseMatrix::FromCOOPointer(coo, value, shape);
 }

@@ -141,7 +139,7 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
 std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
   auto coo = COOPtr();
   auto val = value();
-  return std::make_tuple(coo->row, coo->col);
+  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
 }

 std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
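In the constructor, six per-tensor checks on row and col collapse into four checks on the single indices tensor. The helper below is hypothetical (it does not exist in the repository) but mirrors the invariants asserted in the first hunk above:

#include <torch/torch.h>

// Hypothetical validation helper: one (2, nnz) indices tensor replaces the
// separate row/col shape, length, and device checks.
void CheckCOOIndices(const torch::Tensor& indices, const torch::Tensor& value) {
  TORCH_CHECK(indices.dim() == 2);
  TORCH_CHECK(indices.size(0) == 2);
  TORCH_CHECK(indices.size(1) == value.size(0));
  TORCH_CHECK(indices.device() == value.device());
}

int main() {
  torch::Tensor indices = torch::tensor({{0, 2}, {1, 1}});
  torch::Tensor value = torch::rand({2});
  CheckCOOIndices(indices, value);
  return 0;
}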
dgl_sparse/src/spspmm.cc

@@ -64,8 +64,8 @@ torch::Tensor _CSRMask(
     const c10::intrusive_ptr<SparseMatrix>& sub_mat) {
   auto csr = CSRToOldDGLCSR(mat->CSRPtr());
   auto val = TorchTensorToDGLArray(value);
-  auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
-  auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
+  auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->indices.index({0}));
+  auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->indices.index({1}));
   runtime::NDArray ret = aten::CSRGetFloatingData(csr, row, col, val, 0.);
   return DGLArrayToTorchTensor(ret);
 }
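One practical property of the stacked layout, relevant to call sites like _CSRMask that unpack the coordinates before converting them to DGL arrays: selecting row 0 or row 1 of a contiguous (2, nnz) tensor yields a contiguous view, so the unpacking itself should not copy. Whether TorchTensorToDGLArray additionally requires contiguity is not shown in this diff, so take that part as an assumption. A small sketch:

#include <torch/torch.h>

int main() {
  torch::Tensor row = torch::tensor({0, 1, 3}, torch::kInt64);
  torch::Tensor col = torch::tensor({2, 0, 4}, torch::kInt64);
  torch::Tensor indices = torch::stack({row, col});  // contiguous (2, nnz)

  // Each selected row is a view into the stacked storage and is itself
  // contiguous, so no copy is needed to get the 1-D coordinate arrays back.
  torch::Tensor r = indices.index({0});
  torch::Tensor c = indices.index({1});
  TORCH_CHECK(r.is_contiguous() && c.is_contiguous());
  TORCH_CHECK(r.data_ptr() == indices.data_ptr());  // same underlying buffer
  return 0;
}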