OpenDAS / torch-sparse · Commits · efbbce74

Commit efbbce74, authored Apr 03, 2020 by rusty1s

    working example

Parent: 2bea1c3c

Showing 4 changed files, with 147 additions and 91 deletions:
  csrc/cuda/degree_padding_cuda.cu  (+57, -68)
  csrc/cuda/degree_padding_cuda.h   (+8, -2)
  csrc/degree_padding.cpp           (+24, -3)
  test/test_degree_padding2.py      (+58, -18)
csrc/cuda/degree_padding_cuda.cu
@@ -7,84 +7,72 @@
 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS

-__global__ void bin_kernel(const int64_t *rowcount,
-                           const int64_t *bin_strategy, int64_t *bin,
-                           int64_t *one_hot, int64_t num_bins, int64_t numel) {
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  if (thread_idx < numel) {
-    auto count = rowcount[thread_idx];
-    int64_t b = -1;
-    for (int64_t i = 0; i < num_bins; i++) {
-      if (count >= __ldg(bin_strategy + 2 * i) &&
-          count <= __ldg(bin_strategy + 2 * i + 1)) {
-        b = i;
-        break;
-      }
-    }
+__global__ void sizes_kernel(const int64_t *__restrict__ sorted_rowcount,
+                             const int64_t *__restrict__ binptr,
+                             int64_t *__restrict__ size,
+                             int64_t *__restrict__ length,
+                             const int64_t num_bins, const int64_t numel) {
+  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < numel - 1; thread_idx += gridDim.x * blockDim.x) {

-    bin[thread_idx] = b;
-    if (b >= 0) {
-      one_hot[b * numel + thread_idx] = 1;
-    }
-  }
-}
+    int64_t deg1 = sorted_rowcount[thread_idx];
+    int64_t deg2 = sorted_rowcount[thread_idx + 1];

-__global__ void index_kernel(const int64_t *bin, const int64_t *cumsum,
-                             const int64_t *nodes_per_bin, int64_t *index,
-                             int64_t num_bins, int64_t numel) {
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  if (thread_idx < numel) {
-    auto b = bin[thread_idx];
-    if (b >= 0) {
-      auto idx = cumsum[b * numel + thread_idx] - 1;
-      for (int64_t i = 0; i < b; i++) {
-        idx += __ldg(nodes_per_bin + i);
+    if (deg1 != deg2) {
+      for (int64_t b = 1; b <= num_bins; b++) {
+        if (deg1 < __ldg(binptr + b) && deg2 >= __ldg(binptr + b)) {
+          size[b] = thread_idx + 1;
+          length[b - 1] = deg1;
+        }
       }
-      index[idx] = thread_idx;
     }
+    if (thread_idx + 1 == numel - 1) {
+      size[num_bins] = numel;
+      length[num_bins - 1] = deg2;
+    }
   }
 }

-std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
-                                               torch::Tensor bin_strategy) {
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr) {
   CHECK_CUDA(rowcount);
-  CHECK_CUDA(bin_strategy);
+  CHECK_CUDA(binptr);
   CHECK_INPUT(rowcount.dim() == 1);
-  CHECK_INPUT(bin_strategy.dim() == 2 && bin_strategy.size(1) == 2);
+  CHECK_INPUT(binptr.dim() == 1);

   cudaSetDevice(rowcount.get_device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

-  int64_t num_bins = bin_strategy.size(0);
-  auto bin = torch::empty({rowcount.numel()}, rowcount.options());
-  auto one_hot =
-      torch::zeros({num_bins, rowcount.numel()}, rowcount.options());
+  torch::Tensor sorted_rowcount, perm;
+  std::tie(sorted_rowcount, perm) = rowcount.sort();

-  auto stream = at::cuda::getCurrentCUDAStream();
-  bin_kernel<<<BLOCKS(rowcount.numel()), THREADS, 0, stream>>>(
-      rowcount.data_ptr<int64_t>(), bin_strategy.data_ptr<int64_t>(),
-      bin.data_ptr<int64_t>(), one_hot.data_ptr<int64_t>(), num_bins,
-      rowcount.numel());
+  auto size = torch::zeros({binptr.numel()}, binptr.options());
+  auto length = torch::zeros({binptr.numel() - 1}, binptr.options());

-  auto cumsum = one_hot.cumsum(1);
-  auto d_nodes_per_bin = cumsum.select(1, rowcount.numel() - 1).contiguous();
-  auto h_nodes_per_bin = d_nodes_per_bin.cpu();
+  sizes_kernel<<<std::min(BLOCKS(rowcount.numel() - 1), mpc * 8), THREADS, 0,
+                 stream>>>(
+      sorted_rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
+      size.data_ptr<int64_t>(), length.data_ptr<int64_t>(), length.numel(),
+      rowcount.numel());

-  auto h_size = h_nodes_per_bin.sum().data_ptr<int64_t>()[0];
-  auto index = torch::empty({h_size}, rowcount.options());
+  size = size.cpu();
+  size = size.narrow(0, 1, length.numel()) - size.narrow(0, 0, length.numel());
+  auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());

-  index_kernel<<<BLOCKS(bin.numel()), THREADS, 0, stream>>>(
-      bin.data_ptr<int64_t>(), cumsum.data_ptr<int64_t>(),
-      d_nodes_per_bin.data_ptr<int64_t>(), index.data_ptr<int64_t>(),
-      num_bins, rowcount.numel());
+  length = length.cpu();
+  int64_t *length_data = length.data_ptr<int64_t>();
+  std::vector<int64_t> lengths(length.numel());
+  std::copy(length_data, length_data + length.numel(), lengths.begin());

-  auto sizes = at::IntArrayRef(h_nodes_per_bin.data_ptr<int64_t>(), num_bins);
-  return index.split_with_sizes(sizes);
+  return std::make_tuple(perm.split_with_sizes(sizes), lengths);
 }

-__global__ void padded_mask_select_kernel(const int64_t *rowptr,
-                                          const int64_t *col,
-                                          const int64_t *index,
-                                          int64_t *out_idx, bool *mask,
-                                          int64_t length, int64_t numel) {
+__global__ void padded_mask_select_kernel(
+    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
+    const int64_t *__restrict__ index, int64_t *__restrict__ out_idx,
+    bool *__restrict__ mask, const int64_t length, const int64_t numel) {

   int64_t lane_idx, row_idx, row_start, row_end, col_idx;
   for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
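Note: the rewritten `bin_assignment_cuda` sorts nodes by degree, splits the sorted permutation at the degree boundaries given by `binptr`, and returns one permutation tensor per bin together with each bin's padding length (the largest degree occurring in it). The following is a minimal CPU-side PyTorch sketch of those semantics, written for orientation only; it is not code from this commit, and it reports length 0 for empty bins.

```python
import torch

def bin_assignment_reference(rowcount, binptr):
    # Sort nodes by degree; `perm` maps sorted positions back to node ids.
    sorted_rowcount, perm = rowcount.sort()
    num_bins = binptr.numel() - 1
    sizes, lengths = [], []
    for b in range(num_bins):
        lo, hi = binptr[b].item(), binptr[b + 1].item()
        # Bin b holds all nodes with lo <= degree < hi.
        in_bin = (sorted_rowcount >= lo) & (sorted_rowcount < hi)
        sizes.append(int(in_bin.sum()))
        # Pad each bin to the largest degree it contains (0 if empty).
        lengths.append(int(sorted_rowcount[in_bin].max()) if sizes[-1] else 0)
    return perm.split(sizes), lengths
```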
@@ -104,10 +92,11 @@ __global__ void padded_mask_select_kernel(const int64_t *rowptr,
 }

 template <typename scalar_t>
-__global__ void padded_index_select_kernel(const scalar_t *src,
-                                           const int64_t *index, scalar_t *out,
-                                           scalar_t fill_value, int64_t dim,
-                                           int64_t numel) {
+__global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
+                                           const int64_t *__restrict__ index,
+                                           scalar_t *__restrict__ out,
+                                           scalar_t fill_value,
+                                           const int64_t dim, const int64_t numel) {

   int64_t index_idx, dim_idx, col;
   for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -136,22 +125,22 @@ padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
   CHECK_INPUT(col.dim() == 1);
   CHECK_INPUT(index.dim() == 1);
   CHECK_INPUT(fill_value.numel() == 1);

   cudaSetDevice(src.get_device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

   auto out_idx = torch::empty({index.size(0), length}, index.options());
-  auto out =
-      torch::empty({index.size(0), length, src.size(-1)}, src.options());
   auto mask = torch::empty({index.size(0), length, 1},
                            src.options().dtype(torch::kBool));

-  auto stream = at::cuda::getCurrentCUDAStream();
-  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
   padded_mask_select_kernel<<<
       std::min((out_idx.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS, 0,
       stream>>>(rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
                 index.data_ptr<int64_t>(), out_idx.data_ptr<int64_t>(),
                 mask.data_ptr<bool>(), length, out_idx.numel());

+  auto out =
+      torch::empty({index.size(0), length, src.size(-1)}, src.options());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
     scalar_t *fill;
     if (fill_value.is_cuda()) {
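Note: `padded_index_select_cuda` gathers, for every node in a bin's permutation, the features of its neighbors from the CSR structure `(rowptr, col)` into a dense `[num_nodes, length, dim]` tensor plus a boolean validity mask, padding short rows with `fill_value`. A dense Python reference sketch of that contract (my own illustration of the assumed semantics, not the commit's code):

```python
import torch

def padded_index_select_reference(src, rowptr, col, index, length,
                                  fill_value=0.0):
    # For each node n in `index`, copy the features of its (at most `length`)
    # neighbors into row i of `out`; `mask` marks the valid slots.
    out = src.new_full((index.numel(), length, src.size(-1)), fill_value)
    mask = torch.zeros(index.numel(), length, 1, dtype=torch.bool,
                       device=src.device)
    for i, n in enumerate(index.tolist()):
        row_start, row_end = rowptr[n].item(), rowptr[n + 1].item()
        nbrs = col[row_start:min(row_end, row_start + length)]
        out[i, :nbrs.numel()] = src[nbrs]
        mask[i, :nbrs.numel()] = True
    return out, mask
```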
csrc/cuda/degree_padding_cuda.h
@@ -2,9 +2,15 @@
 #include <torch/extension.h>

-std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
-                                               torch::Tensor bin_strategy);
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr);

 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor index,
                          int64_t length, torch::Tensor fill_value);
+
+// std::tuple<torch::Tensor, torch::Tensor> padded_index_select_cuda2(
+//     torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
+//     torch::Tensor bin, torch::Tensor index, std::vector<int64_t> node_counts,
+//     std::vector<int64_t> lengths, torch::Tensor fill_value);
csrc/degree_padding.cpp
@@ -9,11 +9,11 @@
 PyMODINIT_FUNC PyInit__degree_padding(void) { return NULL; }
 #endif

-std::vector<torch::Tensor> bin_assignment(torch::Tensor rowcount,
-                                          torch::Tensor bin_strategy) {
+std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
+bin_assignment(torch::Tensor rowcount, torch::Tensor binptr) {
   if (rowcount.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return bin_assignment_cuda(rowcount, bin_strategy);
+    return bin_assignment_cuda(rowcount, binptr);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif

@@ -38,7 +38,28 @@ padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
   }
 }

+// std::tuple<torch::Tensor, torch::Tensor>
+// padded_index_select2(torch::Tensor src, torch::Tensor rowptr, torch::Tensor
+// col,
+//                      torch::Tensor bin, torch::Tensor index,
+//                      std::vector<int64_t> node_counts,
+//                      std::vector<int64_t> lengths, torch::Tensor fill_value)
+// {
+//   if (src.device().is_cuda()) {
+// #ifdef WITH_CUDA
+//     return padded_index_select_cuda2(src, rowptr, col, bin, index,
+//     node_counts,
+//                                      lengths, fill_value);
+// #else
+//     AT_ERROR("Not compiled with CUDA support");
+// #endif
+//   } else {
+//     AT_ERROR("Not implemented yet");
+//   }
+// }
+
 static auto registry =
     torch::RegisterOperators()
         .op("torch_sparse::bin_assignment", &bin_assignment)
         .op("torch_sparse::padded_index_select", &padded_index_select);
+//         .op("torch_sparse::padded_index_select2", &padded_index_select2);
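Note: once registered through `torch::RegisterOperators`, both operators are reachable from Python as `torch.ops.torch_sparse.*`, which is how the test below calls them. A hedged usage sketch, assuming the extension from this commit is built and importable (the example tensors are mine):

```python
import torch
import torch_sparse  # noqa: F401  # loading the package registers the ops

rowcount = torch.tensor([1, 3, 7, 2, 9], device='cuda')
binptr = torch.tensor([0, 4, 11], device='cuda')

# One permutation tensor per degree bin, plus the padded length of each bin.
perms, lengths = torch.ops.torch_sparse.bin_assignment(rowcount, binptr)
```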
test/test_degree_padding2.py
@@ -9,34 +9,71 @@ devices = [torch.device('cuda')]

 @pytest.mark.parametrize('device', devices)
 def test_padded_index_select(device):
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)

     dataset = Planetoid('/tmp/Planetoid', name='PubMed')
     data = dataset[0]
-    row, col = data.edge_index.to(device)
+    row = torch.arange(data.num_nodes).view(-1, 1).repeat(1, 4).view(-1)
+    col = torch.randint(0, data.num_nodes, (row.size(0), ))
+    row, col = row.to(device), col.to(device)

     adj = SparseTensor(row=row, col=col)
     rowcount = adj.storage.rowcount().to(device)
     rowptr = adj.storage.rowptr().to(device)

-    bin_strategy = torch.tensor([[1, 4]]).to(device)
-    # bin_strategy = torch.tensor([[1, 5], [6, 12], [13, 19], [20, 30]],
-    #                             device=device)
-    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
-    lengths = bin_strategy[:, 1].view(-1).tolist()
-    print(lengths)
+    bin_strategy = torch.tensor([[1, 4], [4, 11], [11, 30]]).to(device)
+    binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)
+
+    deg = degree(row, dtype=torch.long)
+    print(deg.size(), deg.min(), deg.float().mean(), deg.max())
+    bins = torch.bincount(deg)
+    print(bins)
+    nonzero = bins.nonzero().flatten()
+    print(nonzero)
+    print(bins[nonzero])
+    print(bins.size())
+    print(bins[:200])
+
+    for i in range(110):
+        if i == 10:
+            start.record()
+        perms, lengths = torch.ops.torch_sparse.bin_assignment(rowcount, binptr)
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+    return
+
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    for i in range(110):
+        if i == 10:
+            start.record()
+        rowcount.sort()
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))

     x = torch.randn(data.num_nodes, 128).to(device)
+    for i in range(110):
+        if i == 10:
+            start.record()
+        x.index_select(0, col)
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+
+    for i in range(110):
+        if i == 10:
+            start.record()
+        for perm, length in zip(perms, lengths):
+            torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perm,
+                                                       length,
+                                                       torch.tensor(0.))
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))

+    for perm, length in zip(perms, lengths):
+        out, mask = torch.ops.torch_sparse.padded_index_select(
+            x, rowptr, col, perm, length, torch.tensor(0.))
+        print(out.size(), mask.size(), out.numel(), (out != 0).sum().item())
+    return
+
+    lengths = bin_strategy[:, 1].view(-1).tolist()
     for dim in [32, 64, 128, 256, 512, 1024]:
         print(f'--- Dim: {dim} ---')
@@ -45,6 +82,10 @@ def test_padded_index_select(device):
         for i in range(110):
             if i == 10:
                 start.record()
+            perms = torch.ops.torch_sparse.bin_assignment(
+                rowcount, bin_strategy)
+            print(perms)
+            return
             for perm, length in zip(perms, lengths):
                 out1, _ = torch.ops.torch_sparse.padded_index_select(
                     x, rowptr, col, perm, length, torch.tensor(0.))
@@ -67,4 +108,3 @@ def test_padded_index_select(device):
         end.record()
         torch.cuda.synchronize()
         print(start.elapsed_time(end))
-        print(torch.allclose(out1.view(-1, dim), out3))
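Note: every benchmark block in this test follows the same pattern: run 110 iterations, call `start.record()` at iteration 10 so the first ten iterations act as warm-up, and synchronize before reading the elapsed time. A standalone sketch of that pattern (the helper name is mine):

```python
import torch

def time_cuda(fn, iters=110, warmup=10):
    # Warm up for `warmup` iterations, then time the rest with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for i in range(iters):
        if i == warmup:
            start.record()
        fn()
    end.record()
    # Event recording is asynchronous; synchronize before querying.
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds
```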