Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
2515ce6d
Commit
2515ce6d
authored
Jan 29, 2020
by
rusty1s
Browse files
set device
parent
f9b00093
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
13 additions
and
0 deletions
+13
-0
cuda/convert_kernel.cu
cuda/convert_kernel.cu
+4
-0
cuda/diag_kernel.cu
cuda/diag_kernel.cu
+2
-0
cuda/spmm_kernel.cu
cuda/spmm_kernel.cu
+3
-0
cuda/spspmm_kernel.cu
cuda/spspmm_kernel.cu
+3
-0
cuda/unique_kernel.cu
cuda/unique_kernel.cu
+1
-0
No files found.
cuda/convert_kernel.cu
View file @
2515ce6d
...
...
@@ -23,6 +23,8 @@ __global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
}
torch
::
Tensor
ind2ptr_cuda
(
torch
::
Tensor
ind
,
int64_t
M
)
{
cudaSetDevice
(
ind
.
get_device
());
auto
out
=
torch
::
empty
(
M
+
1
,
ind
.
options
());
auto
ind_data
=
ind
.
DATA_PTR
<
int64_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
int64_t
>
();
...
...
@@ -46,6 +48,8 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
}
torch
::
Tensor
ptr2ind_cuda
(
torch
::
Tensor
ptr
,
int64_t
E
)
{
cudaSetDevice
(
ptr
.
get_device
());
auto
out
=
torch
::
empty
(
E
,
ptr
.
options
());
auto
ptr_data
=
ptr
.
DATA_PTR
<
int64_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
int64_t
>
();
...
...
cuda/diag_kernel.cu
View file @
2515ce6d
...
...
@@ -40,6 +40,8 @@ __global__ void non_diag_mask_kernel(const int64_t *row_data,
torch
::
Tensor
non_diag_mask_cuda
(
torch
::
Tensor
row
,
torch
::
Tensor
col
,
int64_t
M
,
int64_t
N
,
int64_t
k
)
{
cudaSetDevice
(
row
.
get_device
());
int64_t
E
=
row
.
size
(
0
);
int64_t
num_diag
=
k
<
0
?
std
::
min
(
M
+
k
,
N
)
:
std
::
min
(
M
,
N
-
k
);
...
...
cuda/spmm_kernel.cu
View file @
2515ce6d
...
...
@@ -160,6 +160,7 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch
::
optional
<
torch
::
Tensor
>
value_opt
,
torch
::
Tensor
mat
,
std
::
string
reduce
)
{
cudaSetDevice
(
rowptr
.
get_device
());
AT_ASSERTM
(
rowptr
.
dim
()
==
1
,
"Input mismatch"
);
AT_ASSERTM
(
col
.
dim
()
==
1
,
"Input mismatch"
);
if
(
value_opt
.
has_value
())
...
...
@@ -252,6 +253,8 @@ torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch
::
Tensor
col
,
torch
::
Tensor
mat
,
torch
::
Tensor
grad
,
std
::
string
reduce
)
{
cudaSetDevice
(
row
.
get_device
());
mat
=
mat
.
contiguous
();
grad
=
grad
.
contiguous
();
...
...
cuda/spspmm_kernel.cu
View file @
2515ce6d
...
...
@@ -48,6 +48,9 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch
::
optional
<
torch
::
Tensor
>
valueA
,
torch
::
Tensor
rowptrB
,
torch
::
Tensor
colB
,
torch
::
optional
<
torch
::
Tensor
>
valueB
,
int64_t
M
,
int64_t
N
,
int64_t
K
)
{
cudaSetDevice
(
rowptrA
.
get_device
());
cusparseMatDescr_t
descr
=
0
;
cusparseCreateMatDescr
(
&
descr
);
auto
handle
=
at
::
cuda
::
getCurrentCUDASparseHandle
();
...
...
cuda/unique_kernel.cu
View file @
2515ce6d
...
...
@@ -19,6 +19,7 @@ __global__ void unique_cuda_kernel(scalar_t *__restrict__ src, bool *mask,
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
unique_cuda
(
at
::
Tensor
src
)
{
cudaSetDevice
(
src
.
get_device
());
at
::
Tensor
perm
;
std
::
tie
(
src
,
perm
)
=
src
.
sort
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment