Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
5f98eee4
Commit
5f98eee4
authored
May 01, 2018
by
rusty1s
Browse files
cuda grid
parent
6b18f2d1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
76 additions
and
0 deletions
+76
-0
aten/cuda/cluster.cpp
aten/cuda/cluster.cpp
+20
-0
aten/cuda/cluster_kernel.cu
aten/cuda/cluster_kernel.cu
+46
-0
aten/cuda/setup.py
aten/cuda/setup.py
+10
-0
No files found.
aten/cuda/cluster.cpp
0 → 100644
View file @
5f98eee4
#include <torch/torch.h>

// Forward declaration of the CUDA launcher defined in cluster_kernel.cu.
// Computes one flattened grid-cell (cluster) id per row of `pos`, using
// per-dimension cell extents `size` over the interval [start, end].
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end);

// Assert that tensor `x` lives on a CUDA device before it reaches the kernel.
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
// Python-facing entry point: verify every argument is a CUDA tensor, then
// defer to the CUDA launcher. Raises via AT_ASSERT on a CPU tensor.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  // Check in declaration order so the first offending argument is reported.
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  CHECK_CUDA(start);
  CHECK_CUDA(end);
  return grid_cuda(pos, size, start, end);
}
// Register the extension module: exposes the `grid` wrapper to Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("grid", &grid, "Grid (CUDA)");
}
aten/cuda/cluster_kernel.cu
0 → 100644
View file @
5f98eee4
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
// One thread per point (grid-stride loop): compute a flattened grid-cell id
// for every row of `pos` and write it into `cluster`.
//
// cluster: output, one int64 id per point (length n).
// pos:     (n, dim) point coordinates, accessed via TensorInfo strides.
// size:    per-dimension cell extent.
// start:   per-dimension lower bound of the grid.
// end:     per-dimension upper bound of the grid.
// n:       number of points.
//
// For point i and dimension d the cell coordinate is
// floor((pos[i][d] - start[d]) / size[d]); coordinates are folded into one
// linear index with the running stride `k`, the product of the cell counts
// of all previously processed dimensions.
template <typename scalar_t>
__global__ void
grid_cuda_kernel(int64_t *cluster,
                 const at::cuda::detail::TensorInfo<scalar_t, int> pos,
                 const scalar_t *__restrict__ size,
                 const scalar_t *__restrict__ start,
                 const scalar_t *__restrict__ end, const size_t n) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < n; i += stride) {
    int64_t c = 0, k = 1;
    scalar_t tmp;
    for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
      tmp = (pos.data[i * pos.strides[0] + d * pos.strides[1]]) - start[d];
      c += (int64_t)(tmp / size[d]) * k;
      // BUGFIX: the stride for the next dimension must be the *product* of
      // the cell counts of all processed dimensions; it was previously
      // accumulated with `+=`, which makes distinct cells collide onto the
      // same id. The `+ 1` accounts for points lying exactly on the upper
      // boundary end[d] (the grid interval [start, end] is inclusive).
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}
// Host-side launcher for the grid clustering kernel.
//
// pos:   (num_nodes, dim) CUDA tensor of point coordinates.
// size:  per-dimension grid cell extents.
// start: per-dimension lower bound of the grid.
// end:   per-dimension upper bound of the grid.
// Returns a CUDA Long tensor with one flattened cluster id per point.
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end) {
  // Cast grid parameters to the dtype of `pos` so the kernel sees a single
  // scalar_t throughout.
  size = size.toType(pos.type());
  start = start.toType(pos.type());
  end = end.toType(pos.type());

  const auto num_nodes = pos.size(0);
  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});

  // One thread per point, 1024 threads per block, enough blocks to cover
  // all points (the kernel also grid-strides, so any launch size is safe).
  const int threads = 1024;
  const dim3 blocks((num_nodes + threads - 1) / threads);

  // BUGFIX: dispatch key was "unique" (copied from another operator);
  // renamed to "grid" so dtype-dispatch error messages name this op.
  AT_DISPATCH_ALL_TYPES(pos.type(), "grid", [&] {
    auto cluster_data = cluster.data<int64_t>();
    auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
    auto size_data = size.data<scalar_t>();
    auto start_data = start.data<scalar_t>();
    auto end_data = end.data<scalar_t>();
    grid_cuda_kernel<scalar_t><<<blocks, threads>>>(
        cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
  });
  return cluster;
}
aten/cuda/setup.py
0 → 100644
View file @
5f98eee4
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Translation units of the extension: the pybind11 wrapper plus the CUDA
# kernel file.
sources = ['cluster.cpp', 'cluster_kernel.cu']

setup(
    name='cluster_cuda',
    ext_modules=[CUDAExtension('cluster_cuda', sources)],
    # BuildExtension supplies the mixed C++/nvcc build steps.
    cmdclass={'build_ext': BuildExtension},
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment