Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
cb0e5f63
Commit
cb0e5f63
authored
May 01, 2018
by
rusty1s
Browse files
cleaner
parent
b992389e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
10 deletions
+5
-10
aten/cuda/cluster_kernel.cu
aten/cuda/cluster_kernel.cu
+5
-10
No files found.
aten/cuda/cluster_kernel.cu
View file @
cb0e5f63
...
...
@@ -25,21 +25,16 @@ __global__ void grid_cuda_kernel(
at
::
Tensor
grid_cuda
(
at
::
Tensor
pos
,
at
::
Tensor
size
,
at
::
Tensor
start
,
at
::
Tensor
end
)
{
size
=
size
.
toType
(
pos
.
type
());
start
=
start
.
toType
(
pos
.
type
());
end
=
end
.
toType
(
pos
.
type
());
const
auto
num_nodes
=
pos
.
size
(
0
);
auto
cluster
=
at
::
empty
(
pos
.
type
().
toScalarType
(
at
::
kLong
),
{
num_nodes
});
AT_DISPATCH_ALL_TYPES
(
pos
.
type
(),
"grid_cuda_kernel"
,
[
&
]
{
auto
cluster_data
=
cluster
.
data
<
int64_t
>
();
auto
pos_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int
>
(
pos
);
auto
size_data
=
size
.
data
<
scalar_t
>
();
auto
start_data
=
start
.
data
<
scalar_t
>
();
auto
end_data
=
end
.
data
<
scalar_t
>
();
grid_cuda_kernel
<
scalar_t
><<<
BLOCKS
(
num_nodes
),
THREADS
>>>
(
cluster_data
,
pos_info
,
size_data
,
start_data
,
end_data
,
num_nodes
);
cluster
.
data
<
int64_t
>
(),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int
>
(
pos
),
size
.
toType
(
pos
.
type
()).
data
<
scalar_t
>
(),
start
..
toType
(
pos
.
type
()).
data
<
scalar_t
>
(),
end
.
toType
(
pos
.
type
()).
data
<
scalar_t
>
(),
num_nodes
);
});
return
cluster
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment