Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
af2325bb
Commit
af2325bb
authored
Apr 03, 2020
by
rusty1s
Browse files
bugfixes
parent
b5aa7bc0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
4 deletions
+30
-4
csrc/cuda/padding_cuda.cu
csrc/cuda/padding_cuda.cu
+6
-2
csrc/padding.cpp
csrc/padding.cpp
+1
-1
test/test_degree_padding2.py
test/test_degree_padding2.py
+23
-1
No files found.
csrc/cuda/padding_cuda.cu
View file @
af2325bb
...
@@ -91,6 +91,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
...
@@ -91,6 +91,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch
::
Tensor
>
torch
::
Tensor
>
padded_index_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
rowcount
,
padded_index_cuda
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
rowcount
,
torch
::
Tensor
binptr
)
{
torch
::
Tensor
binptr
)
{
// TODO: Add checks
cudaSetDevice
(
rowcount
.
get_device
());
cudaSetDevice
(
rowcount
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
size_t
mpc
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
;
size_t
mpc
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
;
...
@@ -148,9 +150,9 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
...
@@ -148,9 +150,9 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
int64_t
lane_idx
=
thread_idx
%
F
;
int64_t
lane_idx
=
thread_idx
%
F
;
int64_t
index_idx
=
__ldg
(
index
+
row_idx
);
int64_t
index_idx
=
__ldg
(
index
+
row_idx
);
scalar_tmp
=
fill_value
;
scalar_
t
tmp
=
fill_value
;
if
(
index_idx
!=
-
1
)
{
if
(
index_idx
!=
-
1
)
{
tmp
=
src
[
__ldg
(
col
+
index_idx
)
+
lane_idx
];
tmp
=
src
[
__ldg
(
col
+
index_idx
)
*
F
+
lane_idx
];
}
}
out
[
thread_idx
]
=
tmp
;
out
[
thread_idx
]
=
tmp
;
...
@@ -160,6 +162,8 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
...
@@ -160,6 +162,8 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
torch
::
Tensor
padded_index_select_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
col
,
torch
::
Tensor
padded_index_select_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
col
,
torch
::
Tensor
index
,
torch
::
Tensor
index
,
torch
::
Tensor
fill_value
)
{
torch
::
Tensor
fill_value
)
{
// TODO: Add checks
cudaSetDevice
(
src
.
get_device
());
cudaSetDevice
(
src
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
size_t
mpc
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
;
size_t
mpc
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
;
...
...
csrc/padding.cpp
View file @
af2325bb
...
@@ -19,7 +19,7 @@ padded_index(torch::Tensor rowptr, torch::Tensor rowcount,
...
@@ -19,7 +19,7 @@ padded_index(torch::Tensor rowptr, torch::Tensor rowcount,
torch
::
Tensor
padded_index_select
(
torch
::
Tensor
src
,
torch
::
Tensor
col
,
torch
::
Tensor
padded_index_select
(
torch
::
Tensor
src
,
torch
::
Tensor
col
,
torch
::
Tensor
index
,
torch
::
Tensor
index
,
torch
::
Tensor
fill_value
)
{
torch
::
Tensor
fill_value
)
{
return
padded_index_select
(
src
,
col
,
index
,
fill_value
);
return
padded_index_select
_cuda
(
src
,
col
,
index
,
fill_value
);
}
}
static
auto
registry
=
static
auto
registry
=
...
...
test/test_degree_padding2.py
View file @
af2325bb
...
@@ -12,6 +12,28 @@ def test_padded_index_select(device):
...
@@ -12,6 +12,28 @@ def test_padded_index_select(device):
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
row
=
torch
.
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
3
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
2
,
3
,
1
,
3
,
2
])
idx
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
binptr
=
torch
.
tensor
([
0
,
3
,
5
],
device
=
device
)
idx
,
mask
,
size
,
length
,
offset
=
torch
.
ops
.
torch_sparse
.
padded_index
(
adj
.
storage
.
rowptr
(),
adj
.
storage
.
rowcount
(),
binptr
)
print
(
size
)
print
(
length
)
print
(
offset
)
print
(
idx
)
print
(
mask
)
x
=
torch
.
tensor
([[
0
],
[
1
],
[
2
],
[
3
]],
dtype
=
torch
.
float
,
device
=
device
)
out
=
torch
.
ops
.
torch_sparse
.
padded_index_select
(
x
,
adj
.
storage
.
col
(),
idx
,
torch
.
tensor
(
0.
))
print
(
out
)
dataset
=
Planetoid
(
'/tmp/Planetoid'
,
name
=
'PubMed'
)
dataset
=
Planetoid
(
'/tmp/Planetoid'
,
name
=
'PubMed'
)
data
=
dataset
[
0
]
data
=
dataset
[
0
]
row
,
col
=
data
.
edge_index
.
to
(
device
)
row
,
col
=
data
.
edge_index
.
to
(
device
)
...
@@ -43,7 +65,7 @@ def test_padded_index_select(device):
...
@@ -43,7 +65,7 @@ def test_padded_index_select(device):
print
(
mask
[:
10
])
print
(
mask
[:
10
])
print
(
idx
[:
10
])
print
(
idx
[:
10
])
x
=
torch
.
randn
(
data
.
num_nodes
,
128
).
to
(
device
)
x
=
torch
.
randn
(
data
.
num_nodes
,
256
).
to
(
device
)
for
i
in
range
(
110
):
for
i
in
range
(
110
):
if
i
==
10
:
if
i
==
10
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment