Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
e1125f13
Commit
e1125f13
authored
May 01, 2018
by
rusty1s
Browse files
typos
parent
6d8410a1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
aten/cuda/cluster_kernel.cu
aten/cuda/cluster_kernel.cu
+5
-5
No files found.
aten/cuda/cluster_kernel.cu
View file @
e1125f13
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
grid_cuda_kernel
(
__global__
void
grid_cuda_kernel
(
int64_t
*
cluster
,
const
at
::
cuda
::
detail
::
TensorInfo
<
scalar_t
,
int
>
pos
,
int64_t
*
cluster
,
at
::
cuda
::
detail
::
TensorInfo
<
scalar_t
,
int
>
pos
,
const
scalar_t
*
__restrict__
size
,
const
scalar_t
*
__restrict__
start
,
scalar_t
*
__restrict__
size
,
scalar_t
*
__restrict__
start
,
const
scalar_t
*
__restrict__
end
,
const
size_t
num_nodes
)
{
scalar_t
*
__restrict__
end
,
size_t
num_nodes
)
{
const
size_t
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(
ptrdiff_t
i
=
index
;
i
<
num_nodes
;
i
+=
stride
)
{
for
(
ptrdiff_t
i
=
index
;
i
<
num_nodes
;
i
+=
stride
)
{
...
@@ -25,8 +25,8 @@ __global__ void grid_cuda_kernel(
...
@@ -25,8 +25,8 @@ __global__ void grid_cuda_kernel(
at
::
Tensor
grid_cuda
(
at
::
Tensor
pos
,
at
::
Tensor
size
,
at
::
Tensor
start
,
at
::
Tensor
grid_cuda
(
at
::
Tensor
pos
,
at
::
Tensor
size
,
at
::
Tensor
start
,
at
::
Tensor
end
)
{
at
::
Tensor
end
)
{
const
auto
num_nodes
=
pos
.
size
(
0
);
auto
num_nodes
=
pos
.
size
(
0
);
auto
cluster
=
at
::
empty
(
pos
.
type
().
to
Scalar
Type
(
at
::
kLong
),
{
num_nodes
});
auto
cluster
=
at
::
empty
(
pos
.
type
().
toType
(
at
::
kLong
),
{
num_nodes
});
AT_DISPATCH_ALL_TYPES
(
pos
.
type
(),
"grid_cuda_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
pos
.
type
(),
"grid_cuda_kernel"
,
[
&
]
{
grid_cuda_kernel
<
scalar_t
><<<
BLOCKS
(
num_nodes
),
THREADS
>>>
(
grid_cuda_kernel
<
scalar_t
><<<
BLOCKS
(
num_nodes
),
THREADS
>>>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment