Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
787eaef6
Commit
787eaef6
authored
Mar 09, 2020
by
rusty1s
Browse files
new radius implementation
parent
1960e391
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
101 deletions
+83
-101
csrc/cuda/nearest_cuda.cu
csrc/cuda/nearest_cuda.cu
+2
-1
csrc/cuda/radius.cpp
csrc/cuda/radius.cpp
+0
-24
csrc/cuda/radius_cuda.cu
csrc/cuda/radius_cuda.cu
+74
-0
csrc/cuda/radius_cuda.h
csrc/cuda/radius_cuda.h
+7
-0
csrc/cuda/radius_kernel.cu
csrc/cuda/radius_kernel.cu
+0
-76
No files found.
csrc/cuda/nearest_cuda.cu
View file @
787eaef6
...
...
@@ -25,7 +25,8 @@ __global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
scalar_t
best
=
1e38
;
int64_t
best_idx
=
0
;
for
(
int64_t
n_y
=
y_start_idx
+
threadIdx
.
x
;
n_y
<
end_idx
;
n_y
+=
THREADS
)
{
for
(
int64_t
n_y
=
y_start_idx
+
threadIdx
.
x
;
n_y
<
y_end_idx
;
n_y
+=
THREADS
)
{
scalar_t
dist
=
0
;
for
(
int64_t
d
=
0
;
d
<
dim
;
d
++
)
{
dist
+=
(
x
[
n_x
*
dim
+
d
]
-
y
[
n_y
*
dim
+
d
])
*
...
...
csrc/cuda/radius.cpp
deleted
100644 → 0
View file @
1960e391
#include <torch/extension.h>

// Assert that tensor `x` lives on a CUDA device.
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
// Assert that tensor `x` is contiguous in memory; the CUDA implementation
// walks raw data pointers and relies on this layout.
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");

// Forward declaration of the CUDA implementation (compiled in a .cu TU).
// NOTE(review): the implementation stacks the surviving (row, col) index
// pairs into a single 2 x num_edges int64 tensor — confirm against the
// kernel file if this contract changes.
at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
                       at::Tensor batch_x, at::Tensor batch_y,
                       size_t max_num_neighbors);
// Python-facing entry point: validates device/layout preconditions, then
// defers all real work to the CUDA implementation.
//
// x, y               -- point sets (CUDA, contiguous).
// radius             -- search radius.
// batch_x, batch_y   -- per-point batch assignment vectors (CUDA).
// max_num_neighbors  -- cap on matches reported per query point.
// Returns whatever radius_cuda produces (an edge-index tensor).
at::Tensor radius(at::Tensor x, at::Tensor y, float radius, at::Tensor batch_x,
                  at::Tensor batch_y, size_t max_num_neighbors) {
  // x and y are later consumed through raw data pointers, so each must be
  // both CUDA-resident and contiguous; the batch vectors only need the
  // device check. Keep the assertion order stable.
  CHECK_CUDA(x);
  IS_CONTIGUOUS(x);
  CHECK_CUDA(y);
  IS_CONTIGUOUS(y);
  CHECK_CUDA(batch_x);
  CHECK_CUDA(batch_y);

  auto out = radius_cuda(x, y, radius, batch_x, batch_y, max_num_neighbors);
  return out;
}
// Register the extension module: exposes `radius` to Python under the
// name given at build time via TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("radius", &radius, "Radius (CUDA)");
}
csrc/cuda/radius_cuda.cu
0 → 100644
View file @
787eaef6
#include "radius_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
// Fixed-radius neighbor search, one CUDA block per batch element.
//
// ptr_x / ptr_y hold CSR-style offsets: points of batch b occupy index
// ranges [ptr_x[b], ptr_x[b+1]) in x and [ptr_y[b], ptr_y[b+1]) in y
// (established by the ptr_[batch_idx]/ptr_[batch_idx + 1] reads below).
// Threads of the block stride over the y range (threadIdx.x + k*THREADS);
// for each query point n_y, up to max_num_neighbors points n_x within
// `radius` are recorded at row/col[n_y * max_num_neighbors + count].
// Unwritten slots keep whatever fill value the caller initialized
// (the host side fills with -1 and filters afterwards).
template <typename scalar_t>
__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
                              const int64_t *ptr_x, const int64_t *ptr_y,
                              int64_t *row, int64_t *col, scalar_t radius,
                              int64_t max_num_neighbors, int64_t dim) {
  const int64_t batch_idx = blockIdx.x;
  // const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t x_start_idx = ptr_x[batch_idx];
  const ptrdiff_t x_end_idx = ptr_x[batch_idx + 1];
  const ptrdiff_t y_start_idx = ptr_y[batch_idx];
  const ptrdiff_t y_end_idx = ptr_y[batch_idx + 1];

  // Compare squared distances against radius^2 instead of taking a sqrt
  // per candidate pair. Equivalent for radius >= 0 (sqrt is monotonic),
  // and removes a transcendental from the innermost loop.
  const scalar_t r2 = radius * radius;

  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
       n_y += THREADS) {
    int64_t count = 0;
    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
      scalar_t dist = 0;
      for (int64_t d = 0; d < dim; d++) {
        const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
        dist += diff * diff;
      }
      if (dist <= r2) {
        row[n_y * max_num_neighbors + count] = n_y;
        col[n_y * max_num_neighbors + count] = n_x;
        count++;
      }
      // `count` is incremented only after a write, so at most
      // max_num_neighbors slots are ever written per query point.
      if (count >= max_num_neighbors) {
        break;
      }
    }
  }
}
// Host driver for the radius search.
//
// x, y          -- CUDA point tensors; flattened to 2-D and made contiguous
//                  here before their raw pointers are handed to the kernel.
// ptr_x, ptr_y  -- CSR batch offsets (int64, CUDA); ptr_x.size(0) - 1 gives
//                  the number of batches and hence the grid size.
// r             -- search radius.
// Returns a 2 x num_edges int64 tensor stacking [row; col], with the -1
// placeholder entries (unused neighbor slots) masked out.
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::Tensor ptr_x, torch::Tensor ptr_y, double r,
                          int64_t max_num_neighbors) {
  CHECK_CUDA(x);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);
  CHECK_CUDA(ptr_y);
  cudaSetDevice(x.get_device());

  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  // One slot block of size max_num_neighbors per query point, pre-filled
  // with -1 so unused slots can be filtered out below.
  auto row = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
  auto col = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
    // BUG FIX: the launch previously passed `radius`, an identifier that
    // does not exist in this scope -- the parameter is named `r`. Pass it
    // converted to the dispatched scalar type.
    radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), scalar_t(r),
        max_num_neighbors, x.size(1));
  });

  auto mask = row != -1;
  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
csrc/cuda/radius_cuda.h
0 → 100644
View file @
787eaef6
#pragma once

#include <torch/extension.h>

// Fixed-radius neighbor search (CUDA). ptr_x / ptr_y are CSR-style batch
// offset tensors; `r` is the search radius and `max_num_neighbors` caps the
// matches reported per query point. Implemented in radius_cuda.cu.
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::Tensor ptr_x, torch::Tensor ptr_y, double r,
                          int64_t max_num_neighbors);
csrc/cuda/radius_kernel.cu
deleted
100644 → 0
View file @
1960e391
#include <ATen/ATen.h>
#include "compat.cuh"
#include "utils.cuh"
#define THREADS 1024
// Legacy radius-search kernel (superseded by radius_cuda.cu): one block per
// batch element; batch_x / batch_y here already hold prefix-sum offsets, so
// points of batch b live in [batch_x[b], batch_x[b+1]) and
// [batch_y[b], batch_y[b+1]). Threads stride over the y range; for each
// query point n_y, up to max_num_neighbors in-radius points are written at
// row/col[n_y * max_num_neighbors + count] (remaining slots keep the
// caller's -1 fill value).
template <typename scalar_t>
__global__ void
radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
              const int64_t *__restrict__ batch_x,
              const int64_t *__restrict__ batch_y, int64_t *__restrict__ row,
              int64_t *__restrict__ col, scalar_t radius,
              size_t max_num_neighbors, size_t dim) {
  const ptrdiff_t batch_idx = blockIdx.x;
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx_x = batch_x[batch_idx];
  const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];
  const ptrdiff_t start_idx_y = batch_y[batch_idx];
  const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];

  for (ptrdiff_t n_y = start_idx_y + idx; n_y < end_idx_y; n_y += THREADS) {
    size_t count = 0;
    for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
      // Squared Euclidean distance over all `dim` features.
      scalar_t dist = 0;
      for (ptrdiff_t d = 0; d < dim; d++) {
        dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
                (x[n_x * dim + d] - y[n_y * dim + d]);
      }
      // NOTE(review): the sqrt is avoidable -- comparing the squared sum
      // against radius * radius would be equivalent for radius >= 0.
      dist = sqrt(dist);
      if (dist <= radius) {
        row[n_y * max_num_neighbors + count] = n_y;
        col[n_y * max_num_neighbors + count] = n_x;
        count++;
      }
      // Stop once this query point's neighbor slots are exhausted.
      if (count >= max_num_neighbors) {
        break;
      }
    }
  }
}
// Legacy host driver (superseded by radius_cuda.cu).
//
// batch_x / batch_y arrive as per-point batch-id vectors; they are converted
// below into prefix-sum offset tensors (degree + cat(zeros, cumsum)) before
// being handed to the kernel. Returns a 2 x num_edges int64 tensor stacking
// [row; col] with the -1 placeholder slots masked out.
at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
                       at::Tensor batch_x, at::Tensor batch_y,
                       size_t max_num_neighbors) {
  cudaSetDevice(x.get_device());

  // Copy the last batch id back to the host to derive the batch count.
  // BUG FIX: the original malloc'ed a one-element buffer for this copy and
  // never freed it, leaking 8 bytes per call -- a stack variable suffices.
  int64_t last_batch;
  cudaMemcpy(&last_batch, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = last_batch + 1;

  // Per-point batch ids -> prefix-sum offsets with a leading zero.
  batch_x = degree(batch_x, batch_size);
  batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
  batch_y = degree(batch_y, batch_size);
  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

  // One slot block of size max_num_neighbors per query point, pre-filled
  // with -1 so unused slots can be filtered out below.
  auto row = at::full(y.size(0) * max_num_neighbors, -1, batch_y.options());
  auto col = at::full(y.size(0) * max_num_neighbors, -1, batch_y.options());

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
    radius_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), radius,
        max_num_neighbors, x.size(1));
  });

  auto mask = row != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment