Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
3639bfab
Commit
3639bfab
authored
Mar 30, 2020
by
rusty1s
Browse files
initial commit
parent
ae928282
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
286 additions
and
1 deletion
+286
-1
csrc/cuda/degree_padding_cuda.cu
csrc/cuda/degree_padding_cuda.cu
+139
-0
csrc/cuda/degree_padding_cuda.h
csrc/cuda/degree_padding_cuda.h
+10
-0
csrc/degree_padding.cpp
csrc/degree_padding.cpp
+42
-0
test/test_degree_padding.py
test/test_degree_padding.py
+94
-0
torch_sparse/__init__.py
torch_sparse/__init__.py
+1
-1
No files found.
csrc/cuda/degree_padding_cuda.cu
0 → 100644
View file @
3639bfab
#include "degree_padding_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 256
// Number of 1-D thread blocks needed to cover N threads (ceil division).
// Both the argument and the whole expansion are parenthesized so that
// expressions like `BLOCKS(a + b)` or `2 * BLOCKS(n)` expand correctly.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
// One thread per row: find the first bin whose inclusive [lo, hi] degree
// interval contains this row's degree, record the bin id in `bin`
// (-1 if no interval matches), and set the row's one-hot flag for that bin.
__global__ void bin_kernel(const int64_t *rowcount,
                           const int64_t *bin_strategy, int64_t *bin,
                           int64_t *one_hot, int64_t num_bins, int64_t numel) {
  const int64_t tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid >= numel)
    return;

  const int64_t count = rowcount[tid];

  int64_t assigned = -1;
  for (int64_t i = 0; i < num_bins; i++) {
    const int64_t lo = __ldg(bin_strategy + 2 * i);
    const int64_t hi = __ldg(bin_strategy + 2 * i + 1);
    if (count >= lo && count <= hi) {
      assigned = i;
      break;
    }
  }

  bin[tid] = assigned;
  if (assigned >= 0)
    one_hot[assigned * numel + tid] = 1;
}
// Scatter each assigned row id into its bin-contiguous slot of `index`.
// Slot = (inclusive prefix count of the row within its bin - 1)
//      + total size of all preceding bins. Rows with bin == -1 are dropped.
__global__ void index_kernel(const int64_t *bin, const int64_t *cumsum,
                             const int64_t *nodes_per_bin, int64_t *index,
                             int64_t num_bins, int64_t numel) {
  const int64_t tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid >= numel)
    return;

  const int64_t b = bin[tid];
  if (b < 0)
    return; // Row matched no bin interval.

  int64_t offset = cumsum[b * numel + tid] - 1;
  for (int64_t i = 0; i < b; i++)
    offset += __ldg(nodes_per_bin + i);
  index[offset] = tid;
}
// Partition row indices into degree bins. `bin_strategy` has shape
// [num_bins, 2] with inclusive [lo, hi] degree ranges. Returns one index
// tensor per bin; rows whose degree matches no interval are omitted.
std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
                                               torch::Tensor bin_strategy) {
  CHECK_CUDA(rowcount);
  CHECK_CUDA(bin_strategy);
  CHECK_INPUT(rowcount.dim() == 1);
  CHECK_INPUT(bin_strategy.dim() == 2 && bin_strategy.size(1) == 2);
  cudaSetDevice(rowcount.get_device());

  const int64_t num_bins = bin_strategy.size(0);
  const int64_t num_rows = rowcount.numel();

  auto bin = torch::empty({num_rows}, rowcount.options());
  auto one_hot = torch::zeros({num_bins, num_rows}, rowcount.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  bin_kernel<<<BLOCKS(num_rows), THREADS, 0, stream>>>(
      rowcount.data_ptr<int64_t>(), bin_strategy.data_ptr<int64_t>(),
      bin.data_ptr<int64_t>(), one_hot.data_ptr<int64_t>(), num_bins,
      num_rows);

  // Inclusive prefix sum along each bin row; its last column therefore
  // holds the number of rows assigned to each bin.
  auto cumsum = one_hot.cumsum(1);
  auto d_nodes_per_bin = cumsum.select(1, num_rows - 1).contiguous();
  auto h_nodes_per_bin = d_nodes_per_bin.cpu(); // Host copy for split sizes.
  auto h_size = h_nodes_per_bin.sum().data_ptr<int64_t>()[0];

  auto index = torch::empty({h_size}, rowcount.options());
  index_kernel<<<BLOCKS(bin.numel()), THREADS, 0, stream>>>(
      bin.data_ptr<int64_t>(), cumsum.data_ptr<int64_t>(),
      d_nodes_per_bin.data_ptr<int64_t>(), index.data_ptr<int64_t>(), num_bins,
      num_rows);

  auto sizes = at::IntArrayRef(h_nodes_per_bin.data_ptr<int64_t>(), num_bins);
  return index.split_with_sizes(sizes);
}
// Gathers the `src` feature rows of each selected node's neighbors into a
// zero-padded [num_index, length, dim] output. TB threads cooperate on one
// (row, feature-channel) pair, striding over the neighbor list by TB.
//
// Fix: `mask` (shape [num_index, length], bool) was allocated and returned
// by the host function but never written, so callers always saw an
// all-false validity mask. The previously commented-out write also scaled
// the index by `dim`, which would overrun the mask buffer. We now mark each
// valid (row, position) slot exactly once, indexed without the `dim` factor.
template <typename scalar_t, int64_t TB>
__global__ void
padded_index_select_kernel(const scalar_t *src, const int64_t *rowptr,
                           const int64_t *col, const int64_t *index,
                           scalar_t *out, bool *mask, int64_t length,
                           int64_t dim, int64_t numel) {
  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  auto dim_idx = thread_idx % dim;          // Feature channel.
  auto lane_idx = (thread_idx / dim) % TB;  // Cooperative lane within TB.
  auto index_idx = thread_idx / (TB * dim); // Which selected row.

  if (thread_idx < numel) {
    auto row_idx = __ldg(index + index_idx);
    auto row_start = __ldg(rowptr + row_idx);
    auto row_end = __ldg(rowptr + row_idx + 1);

    for (int64_t c = lane_idx; c < row_end - row_start; c += TB) {
      auto x = src[__ldg(col + row_start + c) * dim + dim_idx];
      out[index_idx * dim * length + c * dim + dim_idx] = x;
      // Mark this (row, position) slot as valid (non-padding). Guarded so
      // only one channel's thread writes each bool.
      if (dim_idx == 0)
        mask[index_idx * length + c] = true;
    }
  }
}
#define TB 4
// For every node in `index`, gather the `src` features of its neighbors
// (CSR graph given by `rowptr`/`col`) into a zero-padded tensor of shape
// [index.numel(), length, feat]; also returns a [index.numel(), length]
// boolean mask tensor.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length) {
  CHECK_CUDA(src);
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(index.dim() == 1);
  cudaSetDevice(src.get_device());

  const auto feat_dim = src.size(-1);
  auto out = torch::zeros({index.size(0), length, feat_dim}, src.options());
  auto mask = torch::zeros({index.size(0), length},
                           src.options().dtype(torch::kBool));

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    // TB lanes cooperate per (row, channel) pair, hence the extra factor.
    const auto num_threads = index.numel() * feat_dim * TB;
    padded_index_select_kernel<scalar_t, TB>
        <<<BLOCKS(num_threads), THREADS, 0, stream>>>(
            src.data_ptr<scalar_t>(), rowptr.data_ptr<int64_t>(),
            col.data_ptr<int64_t>(), index.data_ptr<int64_t>(),
            out.data_ptr<scalar_t>(), mask.data_ptr<bool>(), length, feat_dim,
            num_threads);
  });

  return std::make_tuple(out, mask);
}
csrc/cuda/degree_padding_cuda.h
0 → 100644
View file @
3639bfab
#pragma once
#include <torch/extension.h>
// Assigns each row (by its degree in `rowcount`) to the first matching
// inclusive [lo, hi] interval of `bin_strategy` ([num_bins, 2]) and returns,
// per bin, the indices of the rows that fell into it.
std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
                                               torch::Tensor bin_strategy);

// Gathers neighbor features of the nodes in `index` from a CSR graph
// (`rowptr`, `col`) into a zero-padded output of row length `length`,
// together with a boolean validity mask.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length);
csrc/degree_padding.cpp
0 → 100644
View file @
3639bfab
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include "cuda/degree_padding_cuda.h"
#endif
#ifdef _WIN32
// Stub module-init entry point so Windows can load this extension as a
// Python module; operators are registered via torch::RegisterOperators below.
PyMODINIT_FUNC PyInit__degree_padding(void) { return NULL; }
#endif
// Device dispatcher for bin assignment; only a CUDA implementation exists.
std::vector<torch::Tensor> bin_assignment(torch::Tensor rowcount,
                                          torch::Tensor bin_strategy) {
  if (!rowcount.device().is_cuda())
    AT_ERROR("Not implemented yet");
#ifdef WITH_CUDA
  return bin_assignment_cuda(rowcount, bin_strategy);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Device dispatcher for padded index select; only a CUDA implementation
// exists.
std::tuple<torch::Tensor, torch::Tensor>
padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
                    torch::Tensor index, int64_t length) {
  if (!src.device().is_cuda())
    AT_ERROR("Not implemented yet");
#ifdef WITH_CUDA
  return padded_index_select_cuda(src, rowptr, col, index, length);
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
// Register both operators under the `torch_sparse` namespace so Python can
// reach them as torch.ops.torch_sparse.<name>.
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::bin_assignment", &bin_assignment)
        .op("torch_sparse::padded_index_select", &padded_index_select);
test/test_degree_padding.py
0 → 100644
View file @
3639bfab
import
pytest
import
torch
from
torch_sparse
import
SparseTensor
from
torch_geometric.datasets
import
Planetoid
devices = [torch.device('cuda')]


@pytest.mark.parametrize('device', devices)
def test_bin_assignment(device):
    # Smoke test on a tiny hand-written degree vector.
    rowcount = torch.tensor([2, 3, 6, 4, 5, 7, 8, 1], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 8]], device=device)
    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    print()
    print(perms)

    # Benchmark on a real graph.
    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    row, col = dataset[0].edge_index
    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    # bin_strategy = torch.tensor([[1, 7], [8, 12]], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 13], [14, 22]], device=device)
    bin_count = [4, 13, 22]

    # src = torch.tensor([
    #     [1, 1],
    #     [2, 2],
    #     [3, 3],
    #     [4, 4],
    #     [5, 5],
    #     [6, 6],
    #     [7, 7],
    #     [8, 8],
    # ], dtype=torch.float, device=device)
    # rowptr = torch.tensor([0, 2, 5, 8, 10], device=device)
    # col = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 1], device=device)
    # index = torch.tensor([1, 2, 3], device=device)
    # out, mask = torch.ops.torch_sparse.padded_index_select(
    #     src, rowptr, col, index, 4)
    # print(out)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Time bin_assignment: 2 warm-up iterations, 100 timed.
    for i in range(102):
        if i == 2:
            start.record()
        perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))
    print('-------------')

    # Baseline: plain index_select over all edges.
    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        x = x.index_select(0, col)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    # Time padded_index_select separately for each degree bin.
    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    rowptr = adj.storage.rowptr().to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[0],
                                                   bin_count[0])
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[1],
                                                   bin_count[1])
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[2],
                                                   bin_count[2])
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))
torch_sparse/__init__.py
View file @
3639bfab
...
@@ -9,7 +9,7 @@ expected_torch_version = (1, 4)
...
@@ -9,7 +9,7 @@ expected_torch_version = (1, 4)
try
:
try
:
for
library
in
[
for
library
in
[
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
,
'_metis'
,
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
,
'_metis'
,
'_rw'
,
'_saint'
'_rw'
,
'_saint'
,
'_degree_padding'
]:
]:
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment