OpenDAS / torch-cluster · Commits

Commit ba9f2ed2, authored Mar 09, 2020 by rusty1s
added grid cuda implementation
Parent: 26f5fa37

Showing 5 changed files with 88 additions and 63 deletions (+88 −63)
csrc/cuda/grid.cpp        +0  −20  (deleted)
csrc/cuda/grid_cuda.cu    +73 −0   (new)
csrc/cuda/grid_cuda.h     +7  −0   (new)
csrc/cuda/grid_kernel.cu  +0  −43  (deleted)
csrc/cuda/utils.cuh       +8  −0
csrc/cuda/grid.cpp  (deleted, file mode 100644 → 0)

#include <torch/extension.h>

#define CHECK_CUDA(x)                                                         \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")

at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end);

at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  CHECK_CUDA(start);
  CHECK_CUDA(end);
  return grid_cuda(pos, size, start, end);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("grid", &grid, "Grid (CUDA)");
}
csrc/cuda/grid_cuda.cu  (new file, 0 → 100644)

#include "grid_cuda.h"

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

template <typename scalar_t>
__global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
                            const scalar_t *start, const scalar_t *end,
                            int64_t *out, int64_t N, int64_t D,
                            int64_t numel) {
  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_idx < numel) {
    // Flattened grid cell index of point `thread_idx`: `c` accumulates the
    // per-dimension cell coordinate, `k` is the running product of the
    // number of cells along the dimensions processed so far.
    int64_t c = 0, k = 1;
    for (int64_t d = 0; d < D; d++) {
      scalar_t p = pos[thread_idx * D + d] - start[d];
      c += (int64_t)(p / size[d]) * k;
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    out[thread_idx] = c;
  }
}

torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
                        torch::optional<torch::Tensor> optional_start,
                        torch::optional<torch::Tensor> optional_end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  cudaSetDevice(pos.get_device());
  if (optional_start.has_value())
    CHECK_CUDA(optional_start.value());
  if (optional_end.has_value())
    CHECK_CUDA(optional_end.value());

  pos = pos.view({pos.size(0), -1}).contiguous();
  size = size.contiguous();
  CHECK_INPUT(size.numel() == pos.size(1));

  // Fall back to the per-dimension minimum/maximum of `pos` when no
  // explicit grid start/end is supplied.
  if (!optional_start.has_value())
    optional_start = std::get<0>(pos.min(0));
  else {
    optional_start = optional_start.value().contiguous();
    CHECK_INPUT(optional_start.value().numel() == pos.size(1));
  }

  if (!optional_end.has_value())
    optional_end = std::get<0>(pos.max(0));
  else {
    optional_end = optional_end.value().contiguous();
    CHECK_INPUT(optional_end.value().numel() == pos.size(1));
  }

  auto start = optional_start.value();
  auto end = optional_end.value();

  auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));

  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
        pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
        start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
        out.data_ptr<int64_t>(), pos.size(0), pos.size(1), out.numel());
  });

  return out;
}
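To see what the kernel computes, here is a small host-side C++ sketch (illustration only, not part of the commit) that mirrors the same index arithmetic for a single point: each coordinate is shifted by `start`, quantized by `size`, and the per-dimension cell coordinates are combined into one flat cluster index via a running product of the per-dimension cell counts. The function and variable names are hypothetical.

// Host-side sketch of grid_kernel's index math for one point (illustration
// only; `flat_cell_index` is a made-up helper, not part of torch-cluster).
#include <cstdint>
#include <cstdio>

int64_t flat_cell_index(const float *pos, const float *size,
                        const float *start, const float *end, int64_t D) {
  int64_t c = 0, k = 1;
  for (int64_t d = 0; d < D; d++) {
    float p = pos[d] - start[d];                        // shift into grid frame
    c += (int64_t)(p / size[d]) * k;                    // cell coord * stride
    k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;  // cells along dim d
  }
  return c;
}

int main() {
  // 2D grid over [0, 1] x [0, 1] with cell size 0.25 -> 5 x 5 cells
  // (the "+ 1" makes points on the upper boundary land in a valid cell).
  float pos[] = {0.3f, 0.6f}, size[] = {0.25f, 0.25f};
  float start[] = {0.0f, 0.0f}, end[] = {1.0f, 1.0f};
  // Cell coordinates are (1, 2), so the flat index is 1 + 2 * 5 = 11.
  printf("%lld\n", (long long)flat_cell_index(pos, size, start, end, 2));
  return 0;
}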
csrc/cuda/grid_cuda.h  (new file, 0 → 100644)

#pragma once

#include <torch/extension.h>

torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
                        torch::optional<torch::Tensor> optional_start,
                        torch::optional<torch::Tensor> optional_end);
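A minimal usage sketch of this API, assuming the extension above is compiled and linked (the tensor values are made up for illustration). Passing `torch::nullopt` for the optional bounds lets the implementation derive them from `pos.min(0)` and `pos.max(0)`, as seen in grid_cuda.cu above.

// Usage sketch (illustration only; assumes the files above are built).
#include "grid_cuda.h"

void example() {
  // 100 random 3D points on the GPU, clustered into voxels of edge 0.25.
  auto pos = torch::rand({100, 3}, torch::device(torch::kCUDA));
  auto size = torch::full({3}, 0.25, pos.options());
  // nullopt bounds -> derived from the per-dimension min/max of `pos`.
  auto cluster = grid_cuda(pos, size, torch::nullopt, torch::nullopt);
  // `cluster` holds one int64 grid cell id per input point.
}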
csrc/cuda/grid_kernel.cu  (deleted, file mode 100644 → 0)

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

template <typename scalar_t>
__global__ void
grid_kernel(int64_t *cluster,
            at::cuda::detail::TensorInfo<scalar_t, int64_t> pos,
            scalar_t *__restrict__ size, scalar_t *__restrict__ start,
            scalar_t *__restrict__ end, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t c = 0, k = 1;
    for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
      scalar_t p =
          pos.data[i * pos.strides[0] + d * pos.strides[1]] - start[d];
      c += (int64_t)(p / size[d]) * k;
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}

at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end) {
  cudaSetDevice(pos.get_device());
  auto cluster = at::empty(pos.size(0), pos.options().dtype(at::kLong));

  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
        cluster.DATA_PTR<int64_t>(),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(pos),
        size.DATA_PTR<scalar_t>(), start.DATA_PTR<scalar_t>(),
        end.DATA_PTR<scalar_t>(), cluster.numel());
  });

  return cluster;
}
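The deleted kernel iterated with the grid-stride loop idiom over a strided TensorInfo view, whereas the replacement in grid_cuda.cu launches one thread per point over contiguous data and guards with a bounds check. A standalone sketch of the grid-stride pattern, for contrast (illustration only; `scale` is a made-up example kernel):

// Grid-stride loop idiom as used by the old kernel (illustration only).
__global__ void scale(float *x, float factor, size_t n) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // Each thread handles elements index, index + stride, index + 2*stride, ...
  // so any launch configuration covers any n. The new kernel instead maps
  // one thread to one element and exits early when thread_idx >= numel.
  for (size_t i = index; i < n; i += stride)
    x[i] *= factor;
}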
csrc/cuda/utils.cuh  (modified)

#pragma once

#include <torch/extension.h>

#define CHECK_CUDA(x)                                                         \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")

////////////////////////////////////////////////////////////////////////

#include <ATen/ATen.h>

std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, ...
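Note the asymmetry between the two new macros: CHECK_CUDA takes a tensor and inspects its device, while CHECK_INPUT takes an already-evaluated boolean condition. A tiny sketch (illustration only; `check_example` is a made-up function):

// How the two macros differ in what they accept (illustration only).
void check_example(torch::Tensor pos, torch::Tensor size) {
  CHECK_CUDA(pos);                          // asserts pos.device().is_cuda()
  CHECK_INPUT(size.numel() == pos.size(1)); // asserts an arbitrary condition
}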