Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
cca0044c
"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "9ed8bf14cb885509281d63731cda16637a7e0bd2"
Commit
cca0044c
authored
Dec 27, 2019
by
rusty1s
Browse files
added tbs
parent
58d0025d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
7 deletions
+9
-7
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+8
-6
test/test_segment.py
test/test_segment.py
+1
-1
No files found.
cuda/segment_kernel.cu
View file @
cca0044c
...
@@ -74,20 +74,20 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
...
@@ -74,20 +74,20 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
return
out
;
return
out
;
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
int
TB
>
__global__
void
segment_add_coo_kernel
(
const
scalar_t
*
src_data
,
__global__
void
segment_add_coo_kernel
(
const
scalar_t
*
src_data
,
const
int64_t
*
index_data
,
const
int64_t
*
index_data
,
scalar_t
*
out_data
,
size_t
numel
)
{
scalar_t
*
out_data
,
size_t
numel
)
{
int
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
lane_idx
=
thread_idx
&
(
32
-
1
);
int
lane_idx
=
thread_idx
&
(
TB
-
1
);
if
(
thread_idx
<
numel
)
{
if
(
thread_idx
<
numel
)
{
auto
idx
=
__ldg
(
index_data
+
thread_idx
);
auto
idx
=
__ldg
(
index_data
+
thread_idx
);
scalar_t
val
=
src_data
[
thread_idx
],
tmp
;
scalar_t
val
=
src_data
[
thread_idx
],
tmp
;
#pragma unroll
#pragma unroll
for
(
int
offset
=
1
;
offset
<
32
;
offset
*=
2
)
{
for
(
int
offset
=
1
;
offset
<
TB
;
offset
*=
2
)
{
tmp
=
__shfl_up_sync
(
FULL_MASK
,
val
,
offset
);
tmp
=
__shfl_up_sync
(
FULL_MASK
,
val
,
offset
);
if
(
lane_idx
>=
offset
&&
if
(
lane_idx
>=
offset
&&
idx
==
__ldg
(
index_data
+
thread_idx
-
offset
))
{
idx
==
__ldg
(
index_data
+
thread_idx
-
offset
))
{
...
@@ -95,7 +95,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
...
@@ -95,7 +95,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
}
}
}
}
if
(
lane_idx
==
3
1
||
idx
!=
__ldg
(
index_data
+
thread_idx
+
1
))
{
if
(
lane_idx
==
TB
-
1
||
idx
!=
__ldg
(
index_data
+
thread_idx
+
1
))
{
atomAdd
(
out_data
+
idx
,
val
);
atomAdd
(
out_data
+
idx
,
val
);
}
}
}
}
...
@@ -103,6 +103,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
...
@@ -103,6 +103,7 @@ __global__ void segment_add_coo_kernel(const scalar_t *src_data,
void
segment_add_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
)
{
void
segment_add_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
)
{
auto
numel
=
src
.
numel
();
auto
numel
=
src
.
numel
();
auto
avg_length
=
(
float
)
numel
/
(
float
)
out
.
numel
();
auto
index_data
=
index
.
DATA_PTR
<
int64_t
>
();
auto
index_data
=
index
.
DATA_PTR
<
int64_t
>
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
@@ -110,8 +111,9 @@ void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
...
@@ -110,8 +111,9 @@ void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
segment_add_coo_kernel
<
scalar_t
><<<
BLOCKS
(
1
,
numel
),
THREADS
,
0
,
stream
>>>
(
segment_add_coo_kernel
<
scalar_t
,
32
>
src_data
,
index_data
,
out_data
,
numel
);
<<<
BLOCKS
(
1
,
numel
),
THREADS
,
0
,
stream
>>>
(
src_data
,
index_data
,
out_data
,
numel
);
});
});
}
}
...
...
test/test_segment.py
View file @
cca0044c
...
@@ -37,7 +37,7 @@ def test_forward2(dtype, device):
...
@@ -37,7 +37,7 @@ def test_forward2(dtype, device):
def
test_benchmark
(
dtype
,
device
):
def
test_benchmark
(
dtype
,
device
):
from
torch_geometric.datasets
import
Planetoid
,
Reddit
# noqa
from
torch_geometric.datasets
import
Planetoid
,
Reddit
# noqa
data
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)[
0
].
to
(
device
)
data
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)[
0
].
to
(
device
)
#
data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
data
=
Planetoid
(
'/tmp/PubMed'
,
'PubMed'
)[
0
].
to
(
device
)
# data = Reddit('/tmp/Reddit')[0].to(device)
# data = Reddit('/tmp/Reddit')[0].to(device)
row
,
col
=
data
.
edge_index
row
,
col
=
data
.
edge_index
x
=
torch
.
randn
(
data
.
num_edges
,
device
=
device
)
x
=
torch
.
randn
(
data
.
num_edges
,
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment