OpenDAS / dgl · Commit 4463b3d6 (unverified)
Authored Mar 31, 2023 by czkkkkkk, committed by GitHub on Mar 31, 2023
[Sparse] Support strided tensor when calling old DGL APIs (#5506)
Parent: d45eafd4
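The change is small: the Torch-to-DGL tensor conversion now calls .contiguous() itself, and the explicit .contiguous() calls at the call sites are dropped. The sketch below is not part of the commit; it only illustrates, with standard PyTorch APIs, what a "strided tensor" is here: transpose() returns a view with permuted strides rather than a compact copy, and .contiguous() materializes a row-major copy only when one is needed.

import torch

# A transpose is a strided view: same storage, permuted strides, not contiguous.
x = torch.randn(3, 4)
y = x.transpose(0, 1)
print(x.is_contiguous(), x.stride())  # True  (4, 1)
print(y.is_contiguous(), y.stride())  # False (1, 4)

# .contiguous() makes a compact row-major copy only when the layout requires it.
z = y.contiguous()
print(z.is_contiguous(), z.stride())  # True  (3, 1)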
Showing 6 changed files with 52 additions and 11 deletions.
  dgl_sparse/src/matmul.cc                     +2  -2
  dgl_sparse/src/sddmm.cc                      +3  -4
  dgl_sparse/src/utils.h                       +1  -1
  tests/python/pytorch/sparse/test_matmul.py   +3  -0
  tests/python/pytorch/sparse/test_sddmm.py    +13 -1
  tests/python/pytorch/sparse/utils.py         +30 -3
dgl_sparse/src/matmul.cc
@@ -83,9 +83,9 @@ torch::Tensor SDDMMNoAutoGrad(
   if (mat1.dim() >= 3) {
     shape.push_back(mat1.size(2));
     // (N, K, B) -> (N, B, K)
-    mat1 = mat1.transpose(1, 2).contiguous();
+    mat1 = mat1.transpose(1, 2);
     // (M, K, B) -> (M, B, K)
-    mat2_tr = mat2_tr.transpose(1, 2).contiguous();
+    mat2_tr = mat2_tr.transpose(1, 2);
   }
   auto ret = torch::zeros(shape, mat1.options());
   const std::string op = "dot";
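With TorchTensorToDGLArray now handling contiguity (see the utils.h hunk below), the explicit copies after each batched transpose are redundant, so the operands stay as strided views until they cross the Torch-to-DGL boundary. A small Python sketch of the (N, K, B) -> (N, B, K) step, purely illustrative and not taken from the commit:

import torch

# Swapping the last two dimensions reshapes (N, K, B) into (N, B, K); the
# result is a non-contiguous view of the same storage, not a copy.
mat1 = torch.randn(5, 8, 2)                   # (N, K, B)
mat1_t = mat1.transpose(1, 2)                 # (N, B, K)
print(mat1_t.shape)                           # torch.Size([5, 2, 8])
print(mat1_t.is_contiguous())                 # False
print(mat1_t.data_ptr() == mat1.data_ptr())   # True: no data was moved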
dgl_sparse/src/sddmm.cc
@@ -68,7 +68,7 @@ void _SDDMMSanityCheck(
 torch::Tensor SDDMMAutoGrad::forward(
     AutogradContext* ctx, const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
     torch::Tensor mat1, torch::Tensor mat2) {
-  auto mat2_tr = mat2.transpose(0, 1).contiguous();
+  auto mat2_tr = mat2.transpose(0, 1);
   auto ret = SDDMMNoAutoGrad(sparse_mat, mat1, mat2_tr);
   torch::Tensor cache_mat1, cache_mat2;
   if (mat1.requires_grad()) {

@@ -94,13 +94,12 @@ tensor_list SDDMMAutoGrad::backward(
   torch::Tensor mat1_grad, mat2_grad;
   if (ctx->saved_data["mat1_requires_grad"].toBool()) {
     // SDDMM(M, A, B) = C. dA = SpMM(dC, B^T)
-    mat1_grad = SpMMNoAutoGrad(
-        sparse_mat, grad, mat2.transpose(0, 1).contiguous(), false);
+    mat1_grad = SpMMNoAutoGrad(sparse_mat, grad, mat2.transpose(0, 1), false);
   }
   if (ctx->saved_data["mat2_requires_grad"].toBool()) {
     // SDDMM(M, A, B) = C. dB = SpMM(dC^T, A)^T
     auto mat2_tr_grad = SpMMNoAutoGrad(sparse_mat, grad, mat1, true);
-    mat2_grad = mat2_tr_grad.transpose(0, 1).contiguous();
+    mat2_grad = mat2_tr_grad.transpose(0, 1);
   }
   return {torch::Tensor(), mat1_grad, mat2_grad};
 }
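The comments in backward state the gradient rules this code relies on: for C = SDDMM(M, A, B), dA = SpMM(dC, B^T) and dB = SpMM(dC^T, A)^T. The dense PyTorch analogue below is an illustrative check, not code from the commit; it models the sparsity pattern M as an elementwise mask and confirms the same identities against autograd:

import torch

# Dense analogue: C = mask * (A @ B). Autograd's gradients should match
# dA = (mask * dC) @ B^T and dB = ((mask * dC)^T @ A)^T.
A = torch.randn(4, 3, requires_grad=True)
B = torch.randn(3, 5, requires_grad=True)
mask = (torch.rand(4, 5) < 0.5).float()
C = mask * (A @ B)
dC = torch.randn_like(C)
C.backward(dC)
assert torch.allclose(A.grad, (mask * dC) @ B.T, atol=1e-5)
assert torch.allclose(B.grad, ((mask * dC).T @ A).T, atol=1e-5)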
dgl_sparse/src/utils.h
@@ -52,7 +52,7 @@ inline static void ElementwiseOpSanityCheck(
 /** @brief Convert a Torch tensor to a DGL array. */
 inline static runtime::NDArray TorchTensorToDGLArray(torch::Tensor tensor) {
-  return runtime::DLPackConvert::FromDLPack(at::toDLPack(tensor));
+  return runtime::DLPackConvert::FromDLPack(at::toDLPack(tensor.contiguous()));
 }

 /** @brief Convert a DGL array to a Torch tensor. */
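This one-line hunk is the core of the fix: every tensor handed to an old DGL API now passes through .contiguous() at the DLPack boundary, so strided inputs are compacted exactly once, while already-contiguous tensors pass through unchanged because PyTorch returns the tensor itself in that case. A Python-side sketch of that behaviour, assuming only standard torch APIs rather than DGL's converter:

import torch
from torch.utils import dlpack

x = torch.randn(4, 4)
# Already contiguous: .contiguous() is a no-op that returns the same storage.
assert x.contiguous().data_ptr() == x.data_ptr()

# A transposed view is strided; .contiguous() copies it into fresh storage.
y = x.t()
assert not y.is_contiguous()
assert y.contiguous().data_ptr() != y.data_ptr()

# Handing the compacted tensor to DLPack mirrors what the C++ helper now does.
capsule = dlpack.to_dlpack(y.contiguous())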
tests/python/pytorch/sparse/test_matmul.py
@@ -13,6 +13,7 @@ from .utils import (
     rand_coo,
     rand_csc,
     rand_csr,
+    rand_stride,
     sparse_matrix_to_dense,
     sparse_matrix_to_torch_sparse,
 )

@@ -30,6 +31,7 @@ def test_spmm(create_func, shape, nnz, out_dim):
     else:
         X = torch.randn(shape[1], requires_grad=True, device=dev)
+    X = rand_stride(X)
     sparse_result = matmul(A, X)
     grad = torch.randn_like(sparse_result)
     sparse_result.backward(grad)

@@ -56,6 +58,7 @@ def test_bspmm(create_func, shape, nnz):
     dev = F.ctx()
     A = create_func(shape, nnz, dev, 2)
     X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)
+    X = rand_stride(X)
     sparse_result = matmul(A, X)
     grad = torch.randn_like(sparse_result)
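test_spmm and test_bspmm now route the dense operand through rand_stride, so matmul is exercised with a non-contiguous X as well. The property being relied on is that a strided operand gives the same result as a contiguous one holding the same values; the snippet below illustrates that with a plain dense matmul and is not part of the test suite:

import torch

A = torch.randn(4, 6)
X = torch.randn(6, 10)
# Build a non-contiguous view holding the same values as X.
X_strided = torch.stack([X] * 2, dim=-1)[..., 0]
assert not X_strided.is_contiguous()
assert torch.equal(X_strided, X)
# The operation must not care about the memory layout of its input.
assert torch.allclose(A @ X_strided, A @ X)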
tests/python/pytorch/sparse/test_sddmm.py
@@ -6,7 +6,13 @@ import torch
 from dgl.sparse import bsddmm, sddmm
-from .utils import clone_detach_and_grad, rand_coo, rand_csc, rand_csr
+from .utils import (
+    clone_detach_and_grad,
+    rand_coo,
+    rand_csc,
+    rand_csr,
+    rand_stride,
+)


 @pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])

@@ -23,6 +29,9 @@ def test_sddmm(create_func, shape, nnz, hidden):
     B = torch.rand(shape[0], requires_grad=True, device=dev)
     C = torch.rand(shape[1], requires_grad=True, device=dev)
+    B = rand_stride(B)
+    C = rand_stride(C)
+
     A_val_clone = clone_detach_and_grad(A.val)
     dense_B = clone_detach_and_grad(B)
     dense_C = clone_detach_and_grad(C)

@@ -58,6 +67,9 @@ def test_bsddmm(create_func, shape, nnz, nz_dim):
     B = torch.rand(shape[0], hidden, nz_dim, requires_grad=True, device=dev)
     C = torch.rand(hidden, shape[1], nz_dim, requires_grad=True, device=dev)
+    B = rand_stride(B)
+    C = rand_stride(C)
+
     A_val_clone = clone_detach_and_grad(A.val)
     dense_B = clone_detach_and_grad(B)
     dense_C = clone_detach_and_grad(C)
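B and C are replaced by their strided copies before the sddmm/bsddmm call, and the dense reference path then clones those same copies, so both paths see identical strided operands. One detail worth spelling out, in a standalone sketch that is not taken from the test code: rand_stride detaches the strided copy and re-enables requires_grad, so it becomes a fresh leaf tensor whose .grad the test can compare directly.

import torch

B = torch.rand(8, requires_grad=True)
# Strided copy, detached from B's graph and re-marked as a leaf needing grad.
B_strided = torch.stack([B] * 2, dim=-1)[..., 0].detach().requires_grad_()
(B_strided * 3.0).sum().backward()
print(B_strided.is_leaf)   # True
print(B_strided.grad)      # tensor of threes, shape (8,)
print(B.grad)              # None: the original leaf is not part of this graph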
tests/python/pytorch/sparse/utils.py
 import numpy as np
 import torch
-from dgl.sparse import from_coo, from_csc, from_csr, SparseMatrix
+from dgl.sparse import from_csc, from_csr, SparseMatrix, spmatrix

 np.random.seed(42)
 torch.random.manual_seed(42)

@@ -13,6 +13,16 @@ def clone_detach_and_grad(t):
     return t


+def rand_stride(t):
+    """Add stride to the last dimension of a tensor."""
+    stride = np.random.randint(2, 4)
+    ret = torch.stack([t] * stride, dim=-1)[..., 0]
+    ret = ret.detach()
+    if torch.is_floating_point(t):
+        ret.requires_grad_()
+    return ret
+
+
 def rand_coo(shape, nnz, dev, nz_dim=None):
     # Create a sparse matrix without duplicate entries.
     nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)

@@ -23,7 +33,10 @@ def rand_coo(shape, nnz, dev, nz_dim=None):
         val = torch.randn(nnz, device=dev, requires_grad=True)
     else:
         val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
-    return from_coo(row, col, val, shape)
+    indices = torch.stack([row, col])
+    indices = rand_stride(indices)
+    val = rand_stride(val)
+    return spmatrix(indices, val, shape)


 def rand_csr(shape, nnz, dev, nz_dim=None):

@@ -42,6 +55,9 @@ def rand_csr(shape, nnz, dev, nz_dim=None):
     indptr = torch.cumsum(indptr, 0)
     row_sorted, row_sorted_idx = torch.sort(row)
     indices = col[row_sorted_idx]
+    indptr = rand_stride(indptr)
+    indices = rand_stride(indices)
+    val = rand_stride(val)
     return from_csr(indptr, indices, val, shape=shape)

@@ -61,6 +77,9 @@ def rand_csc(shape, nnz, dev, nz_dim=None):
     indptr = torch.cumsum(indptr, 0)
     col_sorted, col_sorted_idx = torch.sort(col)
     indices = row[col_sorted_idx]
+    indptr = rand_stride(indptr)
+    indices = rand_stride(indices)
+    val = rand_stride(val)
     return from_csc(indptr, indices, val, shape=shape)

@@ -69,7 +88,9 @@ def rand_coo_uncoalesced(shape, nnz, dev):
     row = torch.randint(shape[0], (nnz,), device=dev)
     col = torch.randint(shape[1], (nnz,), device=dev)
     val = torch.randn(nnz, device=dev, requires_grad=True)
-    return from_coo(row, col, val, shape)
+    indices = torch.stack([row, col])
+    indices = rand_stride(indices)
+    return spmatrix(indices, val, shape)


 def rand_csr_uncoalesced(shape, nnz, dev):

@@ -83,6 +104,9 @@ def rand_csr_uncoalesced(shape, nnz, dev):
     indptr = torch.cumsum(indptr, 0)
     row_sorted, row_sorted_idx = torch.sort(row)
     indices = col[row_sorted_idx]
+    indptr = rand_stride(indptr)
+    indices = rand_stride(indices)
+    val = rand_stride(val)
     return from_csr(indptr, indices, val, shape=shape)

@@ -97,6 +121,9 @@ def rand_csc_uncoalesced(shape, nnz, dev):
     indptr = torch.cumsum(indptr, 0)
     col_sorted, col_sorted_idx = torch.sort(col)
     indices = row[col_sorted_idx]
+    indptr = rand_stride(indptr)
+    indices = rand_stride(indices)
+    val = rand_stride(val)
     return from_csc(indptr, indices, val, shape=shape)
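The whole patch to the test helpers hinges on rand_stride. The trick: stacking k copies of a tensor along a new last axis produces a contiguous buffer in which the original values sit k elements apart, and slicing [..., 0] then yields a view with a last-dimension stride of k, i.e. the same values in a non-contiguous layout. A standalone sketch of the same idea (the helper name and the fixed k are illustrative, not from the diff):

import torch


def make_strided(t, k=3):
    """Return a view of t's values whose last-dimension stride is k."""
    return torch.stack([t] * k, dim=-1)[..., 0]


val = torch.randn(5)
strided = make_strided(val)
print(strided.stride(), strided.is_contiguous())  # (3,) False
print(torch.equal(strided, val))                  # True: same values, new layout

The COO helpers also switch from from_coo(row, col, ...) to spmatrix(indices, ...) with a stacked (2, nnz) indices tensor, so the index tensor itself can be passed through rand_stride and exercised in strided form as well.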