Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
0b3069fe
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "9ebb75dad2bd0f1d1633b7af50b9cd03db379987"
Commit
0b3069fe
authored
Dec 30, 2019
by
rusty1s
Browse files
shared memory version
parent
4d7b32c5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
2 deletions
+44
-2
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+42
-0
test/test_segment.py
test/test_segment.py
+2
-2
No files found.
cuda/segment_kernel.cu
View file @
0b3069fe
...
@@ -89,6 +89,48 @@ __global__ void segment_add_csr_broadcast_kernel(
...
@@ -89,6 +89,48 @@ __global__ void segment_add_csr_broadcast_kernel(
}
}
}
}
// Segment-sum kernel over a CSR index pointer, using shared memory to combine
// the partial sums of TB cooperating threads per output row.
//
// Thread mapping (as established by the index arithmetic below):
//   * the y-dimension enumerates (row, lane) pairs: TB consecutive y-threads
//     share one output row `row_idx`, and `lane_idx` is the position inside
//     that group;
//   * the x-dimension enumerates columns `col_idx`.
//
// Each lane accumulates every TB-th source row in [row_start, row_end) for its
// column, the TB partial sums are combined through `vals`, and lane 0 writes
// the result to out_data[row_idx * K + col_idx].
//
// Parameters:
//   src_data     source values, read at `offset + K * i + col_idx`.
//   indptr_info  CSR index pointer; entries `row_idx` and `row_idx + 1` along
//                the last dimension delimit the source rows of an output row.
//   out_data     output buffer, written at `row_idx * K + col_idx`.
//   N            number of output rows (rows with row_idx >= N are skipped).
//   K            number of columns (columns with col_idx >= K are skipped).
//   E            multiplied with K to form the per-batch source offset;
//                presumably the number of source rows per batch -- verify
//                against the caller.
//
// Launch assumptions (NOTE(review): the launch site is not visible here --
// confirm): blockDim.x <= 32 and blockDim.y <= 32 so that the 32x32 shared
// tile is not overrun, and blockDim.y is a multiple of TB so that a lane group
// never straddles two thread blocks.
template <typename scalar_t, int TB>
__global__ void segment_add_csr_broadcast_kernel2(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  // `thread_idx & (TB - 1)` is only a valid modulo for powers of two, and the
  // shared tile below holds at most 32 lanes per column.
  static_assert(TB > 0 && TB <= 32 && (TB & (TB - 1)) == 0,
                "TB must be a power of two no larger than 32");

  int thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);
  int col_idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Indexed [threadIdx.y][threadIdx.x] so that threads adjacent in
  // threadIdx.x touch adjacent shared-memory elements.
  __shared__ scalar_t vals[32][32];

  const bool row_valid = row_idx < N;
  const bool col_valid = col_idx < K;

  // Out-of-range threads keep a zero partial sum. They must still reach every
  // __syncthreads() below: a barrier inside a branch that only part of the
  // block takes is undefined behaviour.
  scalar_t val = (scalar_t)0;

  if (row_valid && col_valid) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    // Re-use `offset` as the start of this row's batch inside src_data.
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;

    for (int i = row_start + lane_idx; i < row_end; i += TB) {
      val += src_data[offset + K * i + col_idx];
    }
  }

  vals[threadIdx.y][threadIdx.x] = val;
  __syncthreads();

  // Tree reduction over the TB lanes of each group. At step `i` only lanes
  // [0, i) write, and they read lanes [i, 2i), so no slot is read and written
  // in the same step and the partner index never leaves the lane group.
#pragma unroll
  for (int i = TB / 2; i > 0; i /= 2) {
    if (lane_idx < i) {
      vals[threadIdx.y][threadIdx.x] += vals[threadIdx.y + i][threadIdx.x];
    }
    __syncthreads();
  }

  // Lane 0 now holds the sum of all TB partial sums for (row_idx, col_idx).
  if (row_valid && col_valid && lane_idx == 0) {
    out_data[row_idx * K + col_idx] = vals[threadIdx.y][threadIdx.x];
  }
}
at
::
Tensor
segment_add_csr_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
)
{
at
::
Tensor
segment_add_csr_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
)
{
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
());
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
());
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
...
...
test/test_segment.py
View file @
0b3069fe
...
@@ -28,7 +28,7 @@ def test_forward2(dtype, device):
...
@@ -28,7 +28,7 @@ def test_forward2(dtype, device):
# indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
# indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
out
=
segment_add_csr
(
src
,
indptr
)
out
=
segment_add_csr
(
src
,
indptr
)
#
print('CSR', out)
print
(
'CSR'
,
out
)
# index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
# index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
# out = segment_add_coo(src, index)
# out = segment_add_coo(src, index)
...
@@ -95,7 +95,7 @@ def test_benchmark(dtype, device):
...
@@ -95,7 +95,7 @@ def test_benchmark(dtype, device):
assert
torch
.
allclose
(
out1
,
out4
,
atol
=
1e-2
)
assert
torch
.
allclose
(
out1
,
out4
,
atol
=
1e-2
)
x
=
torch
.
randn
((
data
.
num_edges
,
1024
),
device
=
device
)
x
=
torch
.
randn
((
data
.
num_edges
,
32
),
device
=
device
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment