Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
98c4d2c6
Commit
98c4d2c6
authored
Apr 05, 2020
by
rusty1s
Browse files
backward implementation
parent
354ef5e5
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
109 additions
and
12 deletions
+109
-12
csrc/cuda/atomics.cuh
csrc/cuda/atomics.cuh
+21
-0
csrc/cuda/degree_padding_cuda.h
csrc/cuda/degree_padding_cuda.h
+0
-5
csrc/cuda/padding_cuda.cu
csrc/cuda/padding_cuda.cu
+46
-1
csrc/cuda/padding_cuda.h
csrc/cuda/padding_cuda.h
+3
-0
csrc/padding.cpp
csrc/padding.cpp
+25
-1
test/test_degree_padding2.py
test/test_degree_padding2.py
+14
-5
No files found.
csrc/cuda/atomics.cuh
0 → 100644
View file @
98c4d2c6
#pragma once

// float overload: hardware atomicAdd for float exists on every supported
// architecture, so this is a thin forwarding wrapper kept only so callers
// can use one overloaded name for both float and double.
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}
// double overload. A native double-precision atomicAdd requires compute
// capability >= 6.0 and CUDA >= 8.0; on older targets we fall back to the
// canonical atomicCAS emulation loop (as published in the CUDA C
// Programming Guide).
static inline __device__ void atomAdd(double *address, double val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
  // Reinterpret the 64-bit double as an unsigned long long so it can be
  // fed through atomicCAS (which has no double overload).
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // Attempt to swap in (assumed + val); if another thread updated the
    // location in the meantime, atomicCAS returns the new value and we
    // retry with it.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
#else
  atomicAdd(address, val);
#endif
}
csrc/cuda/degree_padding_cuda.h
View file @
98c4d2c6
...
...
// Gathers rows of `src` through the CSR structure (rowptr/col) into
// fixed-length padded segments of `length` entries, presumably writing
// `fill_value` into padding slots and returning the padded values
// together with an auxiliary tensor.
// NOTE(review): semantics inferred from the signature only — confirm
// against the .cu implementation.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length, torch::Tensor fill_value);
csrc/cuda/padding_cuda.cu
View file @
98c4d2c6
...
...
@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include "atomics.cuh"
#include "utils.cuh"
#define THREADS 1024
...
...
@@ -225,7 +226,7 @@ torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
size_t
E
=
index
.
numel
();
size_t
F
=
src
.
size
(
-
1
);
auto
out
=
torch
::
empty
(
E
*
F
,
src
.
options
());
auto
out
=
torch
::
empty
(
{(
int
)
E
,
(
int
)
F
}
,
src
.
options
());
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"padded_index_select_kernel"
,
[
&
]
{
scalar_t
*
fill
;
...
...
@@ -245,3 +246,47 @@ torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
return
out
;
}
// Scatter-add kernel: for every element of the [E, F] `src` matrix, add it
// into row `index[e]` of `out`, skipping padding slots (index == -1).
// Uses a grid-stride loop so any launch size covers all E * F elements.
template <typename scalar_t>
__global__ void
padded_index_scatter_kernel(const scalar_t *__restrict__ src,
                            const int64_t *__restrict__ index,
                            scalar_t *__restrict__ out, const size_t E,
                            const size_t F) {
  const size_t numel = E * F;
  for (ptrdiff_t i = blockDim.x * blockIdx.x + threadIdx.x; i < numel;
       i += gridDim.x * blockDim.x) {
    // One index entry per row of F consecutive elements.
    const int64_t row = __ldg(index + i / F);
    if (row == -1)
      continue; // padding slot: contributes nothing
    // Concurrent rows may collide on the same output entry, hence atomics.
    atomAdd(out + row * F + i % F, src[i]);
  }
}
// Backward companion of padded_index_select: scatter-adds the rows of
// `src` (one row per padded slot) into an [N, F] zero-initialized output,
// skipping padding slots whose index is -1.
//
// src:   [E, F] CUDA tensor (floating point — gradients).
// index: [E] int64 CUDA tensor mapping each slot to an output row, or -1.
// N:     number of output rows (size(0) of the forward input).
// Returns the [N, F] accumulated tensor.
torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
                                        int64_t N) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);
  CHECK_INPUT(src.size(0) == index.size(0));
  cudaSetDevice(src.get_device());

  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  size_t E = index.numel();
  size_t F = src.size(-1);

  // zeros (not empty): rows never touched by the scatter must read 0.
  auto out = torch::zeros({N, (int)F}, src.options());

  // Guard against an empty launch: a grid dimension of 0 is an invalid
  // kernel configuration, so return the zeroed output directly.
  if (E == 0 || F == 0)
    return out;

  // Floating types only: this is a gradient path and atomAdd is provided
  // for float/double.
  AT_DISPATCH_FLOATING_TYPES(
      src.scalar_type(), "padded_index_scatter_kernel", [&] {
        padded_index_scatter_kernel<scalar_t>
            <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
                src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
                out.data_ptr<scalar_t>(), E, F);
      });

  return out;
}
csrc/cuda/padding_cuda.h
View file @
98c4d2c6
...
...
@@ -9,3 +9,6 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
// Gathers rows of `src` according to `index`; padding slots presumably
// receive `fill_value` — NOTE(review): confirm against padding_cuda.cu.
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
                                       torch::Tensor fill_value);

// Backward of the select: scatter-adds rows of `src` into an [N, F]
// zero-initialized tensor, skipping slots where index == -1.
torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
                                        int64_t N);
csrc/padding.cpp
View file @
98c4d2c6
...
...
@@ -16,9 +16,33 @@ padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
return
padded_index_cuda
(
rowptr
,
col
,
rowcount
,
binptr
);
}
using
torch
::
autograd
::
AutogradContext
;
using
torch
::
autograd
::
Variable
;
using
torch
::
autograd
::
variable_list
;
// Autograd wrapper around the padded index-select CUDA kernels: forward
// gathers rows of `src` according to `index` (padding slots filled with
// `fill_value`); backward scatter-adds the incoming gradient back into a
// tensor with as many rows as the forward input.
class PaddedIndexSelect : public torch::autograd::Function<PaddedIndexSelect> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, Variable fill_value) {
    // Remember the number of input rows so backward can size its output,
    // and keep `index` for the scatter.
    ctx->saved_data["N"] = src.size(0);
    ctx->save_for_backward({index});
    auto result = padded_index_select_cuda(src, index, fill_value);
    return {result};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    const auto saved = ctx->get_saved_variables();
    const auto num_rows = ctx->saved_data["N"].toInt();
    auto grad_src = padded_index_scatter_cuda(grad_outs[0], saved[0], num_rows);
    // Only `src` is differentiable; `index` and `fill_value` get no grad.
    return {grad_src, Variable(), Variable()};
  }
};
// Public entry point: routes through the autograd Function so gradients
// flow back to `src`.
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
                                  torch::Tensor fill_value) {
  auto outs = PaddedIndexSelect::apply(src, index, fill_value);
  return outs[0];
}
static
auto
registry
=
...
...
test/test_degree_padding2.py
View file @
98c4d2c6
...
...
@@ -64,6 +64,17 @@ def test_padded_index_select(device):
# print(mask[:10])
# print(idx[:10])
x
=
torch
.
randn
(
adj
.
size
(
0
),
512
).
to
(
device
)
data
=
torch
.
ops
.
torch_sparse
.
padded_index
(
rowptr
,
col
,
rowcount
,
binptr
)
node_perm
,
row_perm
,
col_perm
,
mask
,
node_sizes
,
edge_sizes
=
data
out
=
torch
.
ops
.
torch_sparse
.
padded_index_select
(
x
,
col_perm
,
torch
.
tensor
(
0.
))
outs
=
out
.
split
(
edge_sizes
)
for
out
,
size
in
zip
(
outs
,
node_sizes
):
print
(
out
.
view
(
size
,
-
1
,
x
.
size
(
-
1
)).
shape
)
for
i
in
range
(
110
):
if
i
==
10
:
start
.
record
()
...
...
@@ -71,15 +82,13 @@ def test_padded_index_select(device):
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
'padded index'
,
start
.
elapsed_time
(
end
))
return
x
=
torch
.
randn
(
data
.
num_nodes
,
512
).
to
(
device
)
for
i
in
range
(
110
):
if
i
==
10
:
start
.
record
()
torch
.
ops
.
torch_sparse
.
padded_index_select
(
x
,
col
,
idx
,
torch
.
tensor
(
0.
))
out
=
torch
.
ops
.
torch_sparse
.
padded_index_select
(
x
,
col_perm
,
torch
.
tensor
(
0.
))
out
.
split
(
edge_sizes
)
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
'padded index select'
,
start
.
elapsed_time
(
end
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment