OpenDAS / torch-sparse · Commits

Commit 2bea1c3c, authored Mar 31, 2020 by rusty1s
Parent: 3639bfab

degree padding super fast
Showing 5 changed files with 161 additions and 40 deletions:

csrc/cuda/degree_padding_cuda.cu  (+67, -33)
csrc/cuda/degree_padding_cuda.h   (+2, -2)
csrc/degree_padding.cpp           (+4, -2)
test/test_degree_padding.py       (+18, -3)
test/test_degree_padding2.py      (+70, -0)
csrc/cuda/degree_padding_cuda.cu

@@ -4,7 +4,7 @@
 #include "utils.cuh"

-#define THREADS 256
+#define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS

 __global__ void bin_kernel(const int64_t *rowcount,
                            const int64_t *bin_strategy,
@@ -80,37 +80,53 @@ std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
   return index.split_with_sizes(sizes);
 }

-template <typename scalar_t, int64_t TB>
-__global__ void
-padded_index_select_kernel(const scalar_t *src, const int64_t *rowptr,
-                           const int64_t *col, const int64_t *index,
-                           scalar_t *out, bool *mask, int64_t length,
-                           int64_t dim, int64_t numel) {
-
-  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  auto dim_idx = thread_idx % dim;
-  auto lane_idx = (thread_idx / dim) % TB;
-  auto index_idx = thread_idx / (TB * dim);
-
-  if (thread_idx < numel) {
-    auto row_idx = __ldg(index + index_idx);
-    auto row_start = __ldg(rowptr + row_idx);
-    auto row_end = __ldg(rowptr + row_idx + 1);
-
-    for (int64_t c = lane_idx; c < row_end - row_start; c += TB) {
-      auto x = src[__ldg(col + row_start + c) * dim + dim_idx];
-      out[index_idx * dim * length + c * dim + dim_idx] = x;
-      // mask[index_idx * dim * length + c * dim + dim_idx] = true;
-    }
-  }
-}
-
-#define TB 4
+__global__ void
+padded_mask_select_kernel(const int64_t *rowptr, const int64_t *col,
+                          const int64_t *index, int64_t *out_idx, bool *mask,
+                          int64_t length, int64_t numel) {
+
+  int64_t lane_idx, row_idx, row_start, row_end, col_idx;
+  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < numel; thread_idx += gridDim.x * blockDim.x) {
+    lane_idx = thread_idx % length;
+    row_idx = index[thread_idx / length];
+    row_start = rowptr[row_idx];
+    row_end = rowptr[row_idx + 1];
+
+    col_idx = -1;
+    if (lane_idx < row_end - row_start) {
+      col_idx = col[row_start + lane_idx];
+    }
+
+    out_idx[thread_idx] = col_idx;
+    mask[thread_idx] = col_idx == -1;
+  }
+}
+
+template <typename scalar_t>
+__global__ void
+padded_index_select_kernel(const scalar_t *src, const int64_t *index,
+                           scalar_t *out, scalar_t fill_value, int64_t dim,
+                           int64_t numel) {
+
+  int64_t index_idx, dim_idx, col;
+  for (int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < numel; thread_idx += gridDim.x * blockDim.x) {
+    index_idx = thread_idx / dim;
+    dim_idx = thread_idx % dim;
+
+    col = __ldg(index + index_idx);
+    if (col >= 0) {
+      fill_value = src[col * dim + dim_idx];
+    }
+    out[thread_idx] = fill_value;
+  }
+}

 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor index,
-                         int64_t length) {
+                         int64_t length, torch::Tensor fill_value) {
   CHECK_CUDA(src);
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);
@@ -119,20 +135,38 @@ padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
   CHECK_INPUT(rowptr.dim() == 1);
   CHECK_INPUT(col.dim() == 1);
   CHECK_INPUT(index.dim() == 1);
+  CHECK_INPUT(fill_value.numel() == 1);
   cudaSetDevice(src.get_device());

-  auto out = torch::zeros({index.size(0), length, src.size(-1)},
-                          src.options());
-  auto mask = torch::zeros({index.size(0), length},
-                           src.options().dtype(torch::kBool));
+  auto out_idx = torch::empty({index.size(0), length}, index.options());
+  auto mask = torch::empty({index.size(0), length, 1},
+                           src.options().dtype(torch::kBool));

   auto stream = at::cuda::getCurrentCUDAStream();
+  int64_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
+  padded_mask_select_kernel<<<
+      std::min((out_idx.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS,
+      0, stream>>>(rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
+                   index.data_ptr<int64_t>(), out_idx.data_ptr<int64_t>(),
+                   mask.data_ptr<bool>(), length, out_idx.numel());
+
+  auto out = torch::empty({index.size(0), length, src.size(-1)},
+                          src.options());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
-    padded_index_select_kernel<scalar_t, TB>
-        <<<BLOCKS(index.numel() * src.size(-1) * TB), THREADS, 0, stream>>>(
-            src.data_ptr<scalar_t>(), rowptr.data_ptr<int64_t>(),
-            col.data_ptr<int64_t>(), index.data_ptr<int64_t>(),
-            out.data_ptr<scalar_t>(), mask.data_ptr<bool>(), length,
-            src.size(-1), index.numel() * src.size(-1) * TB);
+    scalar_t *fill;
+    if (fill_value.is_cuda()) {
+      fill = (scalar_t *)malloc(sizeof(scalar_t));
+      cudaMemcpy(fill, fill_value.data_ptr<scalar_t>(), sizeof(scalar_t),
+                 cudaMemcpyDeviceToHost);
+    } else {
+      fill = fill_value.data_ptr<scalar_t>();
+    }
+
+    padded_index_select_kernel<scalar_t>
+        <<<std::min((out.numel() + THREADS - 1) / THREADS, mpc * 8), THREADS,
+           0, stream>>>(src.data_ptr<scalar_t>(), out_idx.data_ptr<int64_t>(),
+                        out.data_ptr<scalar_t>(), fill[0], src.size(-1),
+                        out.numel());
   });

   return std::make_tuple(out, mask);
...
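For orientation: the rewritten CUDA path is a two-pass gather. padded_mask_select_kernel first materializes, for every (row, slot) pair, the source column index (-1 for padding slots) together with a boolean padding mask; padded_index_select_kernel then performs a flat gather that substitutes the fill value wherever the index is -1. A minimal pure-PyTorch sketch of the same semantics, for reference only (the helper name is hypothetical and not part of torch-sparse):

import torch

def padded_index_select_ref(src, rowptr, col, index, length, fill_value):
    # src: [N, F] features; rowptr/col: CSR adjacency; index: rows to gather;
    # length: padded number of neighbor slots per row; fill_value: 1-element tensor.
    out = torch.full((index.numel(), length, src.size(-1)),
                     fill_value.item(), dtype=src.dtype)
    mask = torch.ones(index.numel(), length, 1, dtype=torch.bool)
    for i, row in enumerate(index.tolist()):
        start, end = rowptr[row].item(), rowptr[row + 1].item()
        deg = min(end - start, length)
        out[i, :deg] = src[col[start:start + deg]]   # gather real neighbors
        mask[i, :deg] = False                        # True marks padding slots
    return out, mask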
csrc/cuda/degree_padding_cuda.h

@@ -6,5 +6,5 @@ std::vector<torch::Tensor> bin_assignment_cuda(torch::Tensor rowcount,
                                                torch::Tensor bin_strategy);

 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                          torch::Tensor col, torch::Tensor index,
-                         int64_t length);
+                         int64_t length, torch::Tensor fill_value);
csrc/degree_padding.cpp

@@ -24,10 +24,12 @@ std::vector<torch::Tensor> bin_assignment(torch::Tensor rowcount,
 std::tuple<torch::Tensor, torch::Tensor>
 padded_index_select(torch::Tensor src, torch::Tensor rowptr,
                     torch::Tensor col,
-                    torch::Tensor index, int64_t length) {
+                    torch::Tensor index, int64_t length,
+                    torch::Tensor fill_value) {
   if (src.device().is_cuda()) {
 #ifdef WITH_CUDA
-    return padded_index_select_cuda(src, rowptr, col, index, length);
+    return padded_index_select_cuda(src, rowptr, col, index, length,
+                                    fill_value);
 #else
     AT_ERROR("Not compiled with CUDA support");
 #endif
...
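With the dispatch updated, every caller now passes the fill value as a one-element tensor. A minimal sketch of the new call, assuming torch-sparse was built with CUDA support so that the custom op is registered (the toy graph values are illustrative only):

import torch
from torch_sparse import SparseTensor  # importing registers torch.ops.torch_sparse

# Toy CSR graph: 3 nodes with degrees 2, 1 and 3.
row = torch.tensor([0, 0, 1, 2, 2, 2], device='cuda')
col = torch.tensor([1, 2, 0, 0, 1, 2], device='cuda')
adj = SparseTensor(row=row, col=col)

x = torch.randn(3, 8, device='cuda')     # node features
index = torch.arange(3, device='cuda')   # rows to gather

# New signature: trailing one-element fill_value tensor (here: pad with zeros).
out, mask = torch.ops.torch_sparse.padded_index_select(
    x, adj.storage.rowptr(), adj.storage.col(), index, 3, torch.tensor(0.))
# out: [3, 3, 8] padded neighborhoods; mask is True at padding slots.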
test/test_degree_padding.py

@@ -68,11 +68,24 @@ def test_bin_assignment(device):
     x = torch.randn(dataset[0].num_nodes, 512).to(device)
     rowptr = adj.storage.rowptr().to(device)
     col = col.to(device)

+    for i in range(102):
+        if i == 2:
+            start.record()
+        for perm, count in zip(perms, bin_count):
+            torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perm,
+                                                       count,
+                                                       torch.tensor(0.))
+    end.record()
+    torch.cuda.synchronize()
+    print(start.elapsed_time(end))
+    print('-----------')
+
     for i in range(102):
         if i == 2:
             start.record()
-        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[0],
-                                                   bin_count[0])
+        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[0],
+                                                   bin_count[0],
+                                                   torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
     print(start.elapsed_time(end))
@@ -80,7 +93,8 @@ def test_bin_assignment(device):
         if i == 2:
             start.record()
-        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[1],
-                                                   bin_count[1])
+        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[1],
+                                                   bin_count[1],
+                                                   torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
     print(start.elapsed_time(end))
@@ -88,7 +102,8 @@ def test_bin_assignment(device):
         if i == 2:
             start.record()
-        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[2],
-                                                   bin_count[2])
+        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[2],
+                                                   bin_count[2],
+                                                   torch.tensor(0.))
     end.record()
     torch.cuda.synchronize()
     print(start.elapsed_time(end))
test/test_degree_padding2.py  (new file @ 2bea1c3c, +70 lines)

import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import degree

devices = [torch.device('cuda')]


@pytest.mark.parametrize('device', devices)
def test_padded_index_select(device):
    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    data = dataset[0]
    row, col = data.edge_index.to(device)

    row = torch.arange(data.num_nodes).view(-1, 1).repeat(1, 4).view(-1)
    col = torch.randint(0, data.num_nodes, (row.size(0), ))
    row, col = row.to(device), col.to(device)

    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    rowptr = adj.storage.rowptr().to(device)

    bin_strategy = torch.tensor([[1, 4]]).to(device)
    # bin_strategy = torch.tensor([[1, 5], [6, 12], [13, 19], [20, 30]],
    #                             device=device)

    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    lengths = bin_strategy[:, 1].view(-1).tolist()
    print(lengths)

    deg = degree(row, dtype=torch.long)
    print(deg.size(), deg.min(), deg.float().mean(), deg.max())
    bins = torch.bincount(deg)
    print(bins)
    nonzero = bins.nonzero().flatten()
    print(nonzero)
    print(bins[nonzero])

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for dim in [32, 64, 128, 256, 512, 1024]:
        print(f'--- Dim: {dim} ---')

        x = torch.randn(adj.size(0), dim).to(device)

        for i in range(110):
            if i == 10:
                start.record()
            for perm, length in zip(perms, lengths):
                out1, _ = torch.ops.torch_sparse.padded_index_select(
                    x, rowptr, col, perm, length, torch.tensor(0.))
        end.record()
        torch.cuda.synchronize()
        print(start.elapsed_time(end))

        for i in range(110):
            if i == 10:
                start.record()
            out2 = x.index_select(0, row)
        end.record()
        torch.cuda.synchronize()
        print(start.elapsed_time(end))

        for i in range(110):
            if i == 10:
                start.record()
            out3 = x.index_select(0, col)
        end.record()
        torch.cuda.synchronize()
        print(start.elapsed_time(end))

        print(torch.allclose(out1.view(-1, dim), out3))
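The mask returned next to out is shaped [num_rows, length, 1], so it broadcasts against the padded features. A sketch (not part of this commit) of how downstream code might use it, e.g. to average only over real neighbors:

import torch

def masked_mean(out, mask):
    # out: [num_rows, length, dim]; mask: [num_rows, length, 1], True at
    # padding slots that were filled with fill_value.
    valid = (~mask).float()                  # 1.0 for real neighbors
    denom = valid.sum(dim=1).clamp(min=1.0)  # avoid division by zero
    return (out * valid).sum(dim=1) / denom  # [num_rows, dim]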