OpenDAS / dgl · Commit a23b490d
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9e4a75b1420769f890adb5a9a112d16031fb3530"
Unverified commit, authored Sep 12, 2023 by xiangyuzhi, committed by GitHub on Sep 12, 2023.

[Sparse] Sparse sample implementation (#6303)

Parent: 0440806a
Showing 5 changed files with 155 additions and 2 deletions (+155 -2).
  dgl_sparse/include/sparse/sparse_matrix.h           +33  -0
  dgl_sparse/src/python_binding.cc                      +2  -1
  dgl_sparse/src/sparse_matrix.cc                      +28  -0
  python/dgl/sparse/sparse_matrix.py                    +8  -1
  tests/python/pytorch/sparse/test_sparse_matrix.py    +84  -0
dgl_sparse/include/sparse/sparse_matrix.h

@@ -182,6 +182,39 @@ class SparseMatrix : public torch::CustomClassHolder {
   c10::intrusive_ptr<SparseMatrix> RangeSelect(
       int64_t dim, int64_t start, int64_t end);
 
+  /**
+   * @brief Create a SparseMatrix by sampling elements based on the specified
+   * dimension and sample count.
+   *
+   * If `ids` is provided, this function samples elements from the specified
+   * set of row or column IDs, resulting in a sparse matrix containing only
+   * the sampled rows or columns.
+   *
+   * @param dim Select rows (dim=0) or columns (dim=1) for sampling.
+   * @param fanout The number of elements to randomly sample from each row or
+   * column.
+   * @param ids An optional tensor containing row or column IDs from which to
+   * sample elements.
+   * @param replace Indicates whether repeated sampling of the same element
+   * is allowed. If True, repeated sampling is allowed; otherwise, it is not
+   * allowed.
+   * @param bias An optional boolean flag indicating whether to enable biasing
+   * during sampling. If True, the values of the sparse matrix will be used as
+   * bias weights, meaning that elements with higher values will be more likely
+   * to be sampled. Otherwise, all elements will be sampled uniformly,
+   * regardless of their value.
+   *
+   * @return A new SparseMatrix with the same shape as the original matrix
+   * containing the sampled elements.
+   *
+   * @note If 'replace = false' and there are fewer elements than 'fanout',
+   * all non-zero elements will be sampled.
+   * @note If 'ids' is not provided, the function will sample from
+   * all rows or columns.
+   */
+  c10::intrusive_ptr<SparseMatrix> Sample(
+      int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias);
+
   /**
    * @brief Create a SparseMatrix from a SparseMatrix using new values.
    * @param mat An existing sparse matrix
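As a quick illustration of the declaration above, the sketch below calls the new method from the Python side on a small matrix. This snippet is not part of the commit: the dgl.sparse import alias, the example matrix, and the keyword names (taken from the docstring parameters) are assumptions.

import torch
import dgl.sparse as dglsp

# A 3 x 3 matrix with six non-zero elements.
indices = torch.tensor([[0, 0, 1, 2, 2, 2],   # row ids
                        [0, 2, 1, 0, 1, 2]])  # column ids
A = dglsp.spmatrix(indices, shape=(3, 3))

# Sample at most 2 non-zeros per row (dim=0) from rows 0 and 2, without
# replacement and without value-based bias. Per the tests added below, the
# result keeps only the selected rows.
ids = torch.tensor([0, 2])
A_sub = A.sample(dim=0, fanout=2, ids=ids, replace=False, bias=False)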
dgl_sparse/src/python_binding.cc

@@ -35,7 +35,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
       .def("has_duplicate", &SparseMatrix::HasDuplicate)
       .def("is_diag", &SparseMatrix::HasDiag)
       .def("index_select", &SparseMatrix::IndexSelect)
-      .def("range_select", &SparseMatrix::RangeSelect);
+      .def("range_select", &SparseMatrix::RangeSelect)
+      .def("sample", &SparseMatrix::Sample);
   m.def("from_coo", &SparseMatrix::FromCOO)
       .def("from_csr", &SparseMatrix::FromCSR)
       .def("from_csc", &SparseMatrix::FromCSC)
dgl_sparse/src/sparse_matrix.cc

@@ -167,6 +167,34 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
   }
 }
 
+c10::intrusive_ptr<SparseMatrix> SparseMatrix::Sample(
+    int64_t dim, int64_t fanout, torch::Tensor ids, bool replace, bool bias) {
+  bool rowwise = dim == 0;
+  auto id_array = TorchTensorToDGLArray(ids);
+  auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
+  // Slicing matrix.
+  auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);
+  auto slice_value =
+      this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
+  // Reset value indices.
+  slice_csr.data = dgl::aten::NullArray();
+  auto prob =
+      bias ? TorchTensorToDGLArray(slice_value) : dgl::aten::NullArray();
+  auto slice_id =
+      dgl::aten::Range(0, id_array.NumElements(), 64, id_array->ctx);
+  // Sampling all rows on sliced matrix.
+  auto sample_coo = dgl::aten::CSRRowWiseSampling(
+      slice_csr, slice_id, fanout, prob, replace);
+  auto sample_value =
+      slice_value.index_select(0, DGLArrayToTorchTensor(sample_coo.data));
+  sample_coo.data = dgl::aten::NullArray();
+  auto ret = COOFromOldDGLCOO(sample_coo);
+  if (!rowwise) ret = COOTranspose(ret);
+  return SparseMatrix::FromCOOPointer(
+      ret, sample_value, {ret->num_rows, ret->num_cols});
+}
+
 c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
     const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
   TORCH_CHECK(
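The routine above always works row-wise on a CSR view: it slices the requested rows, runs dgl::aten::CSRRowWiseSampling with the sliced values as optional sampling probabilities, and transposes the result when dim == 1 so that column-wise sampling reuses the same kernel. The dense reference sketch below mirrors those semantics in Python for dim == 0. It is illustrative only: the helper name and the use of torch.multinomial are my own, not DGL code, and it assumes non-negative values when bias=True (the tests below ensure this via torch.abs).

import torch


def sample_rowwise_dense(dense, ids, fanout, replace=False, bias=False):
    """Dense reference of row-wise sampling on a 2-D float tensor: for each
    requested row, pick up to `fanout` non-zero columns, optionally weighted
    by the stored values."""
    rows, cols, vals = [], [], []
    for out_row, r in enumerate(ids.tolist()):
        nz = dense[r].nonzero().reshape(-1)  # candidate columns of row r
        if nz.numel() == 0:
            continue  # empty rows contribute nothing
        # With bias, the (non-negative) values act as sampling weights.
        weights = dense[r, nz] if bias else torch.ones(nz.numel())
        # Without replacement, a row cannot yield more than its non-zeros.
        k = fanout if replace else min(fanout, nz.numel())
        picked = nz[torch.multinomial(weights, k, replacement=replace)]
        rows += [out_row] * k
        cols += picked.tolist()
        vals += dense[r, picked].tolist()
    # The result keeps only the selected rows (shape: len(ids) x num_cols).
    return (
        torch.tensor([rows, cols]),
        torch.tensor(vals),
        (len(ids), dense.shape[1]),
    )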
python/dgl/sparse/sparse_matrix.py

@@ -671,7 +671,14 @@ class SparseMatrix:
         values=tensor([0, 2, 3, 3]),
         shape=(3, 2), nnz=3)
         """
-        raise NotImplementedError
+        if ids is None:
+            dim_size = self.shape[0] if dim == 0 else self.shape[1]
+            ids = torch.range(
+                0, dim_size, dtype=torch.int64, device=self.device
+            )
+        return SparseMatrix(
+            self.c_sparse_matrix.sample(dim, fanout, ids, replace, bias)
+        )
 
 
 def spmatrix(
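For completeness, a hypothetical call through this Python wrapper, passing the candidate columns explicitly (the equivalent of enumerating them with torch.arange). As above, the dgl.sparse alias and the keyword names are assumptions taken from the documented parameters.

import torch
import dgl.sparse as dglsp

indices = torch.tensor([[0, 1, 1, 2],
                        [0, 0, 1, 1]])
val = torch.tensor([1.0, 2.0, 3.0, 4.0])
A = dglsp.spmatrix(indices, val, shape=(3, 2))

# All columns as candidates, enumerated explicitly.
ids = torch.arange(A.shape[1], dtype=torch.int64, device=A.device)

# Column-wise (dim=1) sampling of at most 2 non-zeros per column; the values
# in `val` bias the choice because bias=True.
A_sub = A.sample(dim=1, fanout=2, ids=ids, replace=False, bias=True)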
tests/python/pytorch/sparse/test_sparse_matrix.py

@@ -504,6 +504,90 @@ def test_range_select(create_func, shape, dense_dim, select_dim, rang):
     assert torch.allclose(A_select_to_dense, dense_select)
 
 
+@pytest.mark.parametrize(
+    "create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
+)
+@pytest.mark.parametrize("index", [(0, 1, 2, 3, 4), (0, 1, 3), (1, 1, 2)])
+@pytest.mark.parametrize("replace", [False, True])
+@pytest.mark.parametrize("bias", [False, True])
+def test_sample_rowwise(create_func, index, replace, bias):
+    ctx = F.ctx()
+    shape = (5, 5)
+    sample_dim = 0
+    sample_num = 3
+    A = create_func(shape, 10, ctx)
+    A = val_like(A, torch.abs(A.val))
+    index = torch.tensor(index).to(ctx)
+    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)
+    A_dense = sparse_matrix_to_dense(A)
+    A_sample_to_dense = sparse_matrix_to_dense(A_sample)
+    ans_shape = (index.size(0), shape[1])
+    # Verify sample elements in origin rows
+    for i, row in enumerate(list(index)):
+        ans_ele = list(A_dense[row, :].nonzero().reshape(-1))
+        ret_ele = list(A_sample_to_dense[i, :].nonzero().reshape(-1))
+        for e in ret_ele:
+            assert e in ans_ele
+        if replace:
+            # The number of sample elements in one row should be equal to
+            # 'sample_num' if the row is not empty otherwise should be
+            # equal to 0.
+            assert list(A_sample.row).count(torch.tensor(i)) == (
+                sample_num if len(ans_ele) != 0 else 0
+            )
+        else:
+            assert len(ret_ele) == min(sample_num, len(ans_ele))
+    assert A_sample.shape == ans_shape
+    if not replace:
+        assert not A_sample.has_duplicate()
+
+
+@pytest.mark.parametrize(
+    "create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
+)
+@pytest.mark.parametrize("index", [(0, 1, 2, 3, 4), (0, 1, 3), (1, 1, 2)])
+@pytest.mark.parametrize("replace", [False, True])
+@pytest.mark.parametrize("bias", [False, True])
+def test_sample_columnwise(create_func, index, replace, bias):
+    ctx = F.ctx()
+    shape = (5, 5)
+    sample_dim = 1
+    sample_num = 3
+    A = create_func(shape, 10, ctx)
+    A = val_like(A, torch.abs(A.val))
+    index = torch.tensor(index).to(ctx)
+    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)
+    A_dense = sparse_matrix_to_dense(A)
+    A_sample_to_dense = sparse_matrix_to_dense(A_sample)
+    ans_shape = (shape[0], index.size(0))
+    # Verify sample elements in origin columns
+    for i, col in enumerate(list(index)):
+        ans_ele = list(A_dense[:, col].nonzero().reshape(-1))
+        ret_ele = list(A_sample_to_dense[:, i].nonzero().reshape(-1))
+        for e in ret_ele:
+            assert e in ans_ele
+        if replace:
+            # The number of sample elements in one column should be equal to
+            # 'sample_num' if the column is not empty otherwise should be
+            # equal to 0.
+            assert list(A_sample.col).count(torch.tensor(i)) == (
+                sample_num if len(ans_ele) != 0 else 0
+            )
+        else:
+            assert len(ret_ele) == min(sample_num, len(ans_ele))
+    assert A_sample.shape == ans_shape
+    if not replace:
+        assert not A_sample.has_duplicate()
+
+
 def test_print():
     ctx = F.ctx()