Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
f9ce729f
"vscode:/vscode.git/clone" did not exist on "b4a1ed85440d4d9c1cafbe118ca6c034000c85f9"
Commit
f9ce729f
authored
Feb 10, 2020
by
rusty1s
Browse files
fixed a crucial bug in ptr2ind
parent
1f175220
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
4 deletions
+5
-4
csrc/cuda/convert_cuda.cu
csrc/cuda/convert_cuda.cu
+2
-2
test/test_cat.py
test/test_cat.py
+2
-1
test/test_matmul.py
test/test_matmul.py
+1
-1
No files found.
csrc/cuda/convert_cuda.cu
View file @
f9ce729f
...
@@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
...
@@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
auto
ptr_data
=
ptr
.
data_ptr
<
int64_t
>
();
auto
ptr_data
=
ptr
.
data_ptr
<
int64_t
>
();
auto
out_data
=
out
.
data_ptr
<
int64_t
>
();
auto
out_data
=
out
.
data_ptr
<
int64_t
>
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
ptr2ind_kernel
<<<
(
ptr
.
numel
()
+
THREADS
-
1
)
/
THREADS
,
THREADS
,
0
,
stream
>>>
(
ptr2ind_kernel
<<<
(
ptr
.
numel
()
-
1
+
THREADS
-
1
)
/
THREADS
,
THREADS
,
0
,
ptr_data
,
out_data
,
E
,
ptr
.
numel
());
stream
>>>
(
ptr_data
,
out_data
,
E
,
ptr
.
numel
()
-
1
);
return
out
;
return
out
;
}
}
test/test_cat.py
View file @
f9ce729f
...
@@ -39,7 +39,8 @@ def test_cat(device):
...
@@ -39,7 +39,8 @@ def test_cat(device):
assert
out
.
storage
.
has_rowptr
()
assert
out
.
storage
.
has_rowptr
()
assert
out
.
storage
.
num_cached_keys
()
==
5
assert
out
.
storage
.
num_cached_keys
()
==
5
mat1
=
mat1
.
set_value_
(
torch
.
randn
((
mat1
.
nnz
(),
4
),
device
=
device
))
value
=
torch
.
randn
((
mat1
.
nnz
(),
4
),
device
=
device
)
mat1
=
mat1
.
set_value_
(
value
,
layout
=
'coo'
)
out
=
cat
([
mat1
,
mat1
],
dim
=-
1
)
out
=
cat
([
mat1
,
mat1
],
dim
=-
1
)
assert
out
.
storage
.
value
().
size
()
==
(
mat1
.
nnz
(),
8
)
assert
out
.
storage
.
value
().
size
()
==
(
mat1
.
nnz
(),
8
)
assert
out
.
storage
.
has_row
()
assert
out
.
storage
.
has_row
()
...
...
test/test_matmul.py
View file @
f9ce729f
...
@@ -40,7 +40,7 @@ def test_spmm(dtype, device, reduce):
...
@@ -40,7 +40,7 @@ def test_spmm(dtype, device, reduce):
out
=
matmul
(
src
,
other
,
reduce
)
out
=
matmul
(
src
,
other
,
reduce
)
out
.
backward
(
grad_out
)
out
.
backward
(
grad_out
)
assert
torch
.
allclose
(
expected
,
out
)
assert
torch
.
allclose
(
expected
,
out
,
atol
=
1e-6
)
assert
torch
.
allclose
(
expected_grad_value
,
value
.
grad
,
atol
=
1e-6
)
assert
torch
.
allclose
(
expected_grad_value
,
value
.
grad
,
atol
=
1e-6
)
assert
torch
.
allclose
(
expected_grad_other
,
other
.
grad
,
atol
=
1e-6
)
assert
torch
.
allclose
(
expected_grad_other
,
other
.
grad
,
atol
=
1e-6
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment