Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
685b3770
Commit
685b3770
authored
Nov 13, 2018
by
rusty1s
Browse files
template for dim
parent
5ab10c54
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
119 additions
and
26 deletions
+119
-26
cuda/fps_kernel.cu
cuda/fps_kernel.cu
+119
-26
No files found.
cuda/fps_kernel.cu
View file @
685b3770
...
@@ -5,29 +5,74 @@
...
@@ -5,29 +5,74 @@
#define THREADS 1024

// Primary template for the per-iteration distance update of farthest point
// sampling. The compile-time parameter `Dim` selects a specialization with
// unrolled coordinate arithmetic (Dim = 1, 2, 3); Dim = -1 selects the
// generic fallback that loops over the runtime `dim`. The primary template
// is intentionally empty -- only the specializations are ever instantiated.
template <typename scalar_t, int64_t Dim> struct Dist {};
// Distance update for 1-dimensional points.
// Each thread strides over the points n in [start_idx, end_idx) assigned to
// it, shrinks dist[n] with the squared distance to the previously selected
// point `old`, and tracks the thread-local farthest surviving point in
// (*best, *best_idx). `tmp_dist` and `dim` are unused in this specialization;
// they are kept so every Dist specialization shares one calling convention.
template <typename scalar_t> struct Dist<scalar_t, 1> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      // BUGFIX: the stride for 1-dimensional points is 1, not 3. The
      // original `x[old * 3 + 0] - x[n * 3 + 0]` reads past the end of a
      // single-column `x` and compares the wrong coordinates.
      scalar_t a = x[old] - x[n];
      scalar_t d = a * a;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
// Distance update for 2-dimensional points.
// Each thread strides over the points n in [start_idx, end_idx) assigned to
// it, shrinks dist[n] with the squared Euclidean distance to the previously
// selected point `old`, and tracks the thread-local farthest surviving point
// in (*best, *best_idx). `tmp_dist` and `dim` are unused here; they exist so
// all Dist specializations share one calling convention.
template <typename scalar_t> struct Dist<scalar_t, 2> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      // BUGFIX: the row stride for 2-dimensional points is 2, not 3. The
      // original used `* 3`, which reads out of bounds on a two-column `x`
      // and mixes coordinates of neighboring points.
      scalar_t a = x[old * 2 + 0] - x[n * 2 + 0];
      scalar_t b = x[old * 2 + 1] - x[n * 2 + 1];
      scalar_t d = a * a + b * b;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
// Distance update for 3-dimensional points (stride 3 per point).
// Each thread strides over the points n in [start_idx, end_idx) assigned to
// it, shrinks dist[n] with the squared Euclidean distance to the previously
// selected point `old`, and records the thread-local farthest surviving
// point in (*best, *best_idx). `tmp_dist` and `dim` are unused here; they
// are kept so every Dist specialization shares one calling convention.
template <typename scalar_t> struct Dist<scalar_t, 3> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      scalar_t a = x[old * 3 + 0] - x[n * 3 + 0];
      scalar_t b = x[old * 3 + 1] - x[n * 3 + 1];
      scalar_t c = x[old * 3 + 2] - x[n * 3 + 2];
      scalar_t d = a * a + b * b + c * c;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
// Generic distance update for arbitrary runtime dimensionality (Dim == -1).
// Phase 1 zeroes the per-point accumulators tmp_dist; phase 2 has each
// thread accumulate one coordinate's squared difference into the owning
// point's slot (atomicAdd, since several threads may hit the same point);
// phase 3 folds tmp_dist into dist and tracks the thread-local farthest
// point in (*best, *best_idx).
// NOTE(review): compute() is called by every thread of the block from
// fps_kernel, so the __syncthreads() barriers below are reached uniformly.
// atomicAdd on double requires SM60+ -- confirm the build's minimum arch.
template <typename scalar_t> struct Dist<scalar_t, -1> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    // Phase 1: reset the squared-distance accumulators for this batch.
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      tmp_dist[n] = 0;
    }
    __syncthreads();
    // Phase 2: one flat index i per coordinate; i / dim is the point,
    // i % dim the coordinate within it.
    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim;
         i += THREADS) {
      scalar_t d = x[(old * dim) + (i % dim)] - x[i];
      atomicAdd(&tmp_dist[i / dim], d * d);
    }
    __syncthreads();
    // Phase 3: fold into the running minimum distance and pick this
    // thread's farthest candidate.
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      dist[n] = min(dist[n], tmp_dist[n]);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
template
<
typename
scalar_t
,
int64_t
Dim
>
__global__
void
fps_kernel
(
scalar_t
*
__restrict__
x
,
int64_t
*
__restrict__
cum_deg
,
int64_t
*
__restrict__
cum_k
,
int64_t
*
__restrict__
start
,
scalar_t
*
__restrict__
dist
,
scalar_t
*
__restrict__
tmp_dist
,
int64_t
*
__restrict__
out
,
size_t
dim
)
{
const
ptrdiff_t
batch_idx
=
blockIdx
.
x
;
const
ptrdiff_t
idx
=
threadIdx
.
x
;
const
ptrdiff_t
start_idx
=
cum_deg
[
batch_idx
];
const
ptrdiff_t
end_idx
=
cum_deg
[
batch_idx
+
1
];
__shared__
scalar_t
best_dist
[
THREADS
];
__shared__
int64_t
best_dist_idx
[
THREADS
];
if
(
idx
==
0
)
{
out
[
cum_k
[
batch_idx
]]
=
start_idx
+
start
[
batch_idx
];
}
for
(
ptrdiff_t
m
=
cum_k
[
batch_idx
]
+
1
;
m
<
cum_k
[
batch_idx
+
1
];
m
++
)
{
scalar_t
best
=
-
1
;
ptrdiff_t
best_idx
=
0
;
Dist
<
scalar_t
,
Dim
>::
compute
(
idx
,
start_idx
,
end_idx
,
out
[
m
-
1
],
&
best
,
&
best_idx
,
x
,
dist
,
tmp_dist
,
dim
);
best_dist
[
idx
]
=
best
;
best_dist
[
idx
]
=
best
;
best_dist_idx
[
idx
]
=
best_idx
;
best_dist_idx
[
idx
]
=
best_idx
;
...
@@ -64,10 +138,29 @@ fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
...
@@ -64,10 +138,29 @@ fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
}
}
__syncthreads
();
__syncthreads
();
out
[
m
]
=
best_dist_idx
[
0
];
if
(
idx
==
0
)
{
out
[
m
]
=
best_dist_idx
[
0
];
}
}
}
}
}
// Dispatches the fps_kernel launch on the runtime point dimensionality DIM:
// DIM in {1, 2, 3} selects a specialization with unrolled coordinate math,
// anything else falls back to the generic Dim = -1 implementation. DIM is
// also appended as the kernel's trailing `dim` argument. Expands to an
// immediately-invoked lambda; `scalar_t`, `batch_size`, and `THREADS` must
// be in scope at the expansion site. (Comments cannot go inside the macro
// body -- a `//` comment would swallow the line-continuation backslash.)
#define FPS_KERNEL(DIM, ...) \
[&] { \
switch (DIM) { \
case 1: \
fps_kernel<scalar_t, 1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
break; \
case 2: \
fps_kernel<scalar_t, 2><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
break; \
case 3: \
fps_kernel<scalar_t, 3><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
break; \
default: \
fps_kernel<scalar_t, -1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
} \
}()
at
::
Tensor
fps_cuda
(
at
::
Tensor
x
,
at
::
Tensor
batch
,
float
ratio
,
bool
random
)
{
at
::
Tensor
fps_cuda
(
at
::
Tensor
x
,
at
::
Tensor
batch
,
float
ratio
,
bool
random
)
{
auto
batch_sizes
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
auto
batch_sizes
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
cudaMemcpy
(
batch_sizes
,
batch
[
-
1
].
data
<
int64_t
>
(),
sizeof
(
int64_t
),
cudaMemcpy
(
batch_sizes
,
batch
[
-
1
].
data
<
int64_t
>
(),
sizeof
(
int64_t
),
...
@@ -96,10 +189,10 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
...
@@ -96,10 +189,10 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
auto
out
=
at
::
empty
(
k_sum
[
0
],
k
.
options
());
auto
out
=
at
::
empty
(
k_sum
[
0
],
k
.
options
());
AT_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"fps_kernel"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"fps_kernel"
,
[
&
]
{
fps_kernel
<
scalar_t
><<<
batch_size
,
THREADS
>>>
(
FPS_KERNEL
(
x
.
size
(
1
),
x
.
data
<
scalar_t
>
(),
cum_deg
.
data
<
int64_t
>
(),
x
.
data
<
scalar_t
>
(),
cum_
deg
.
data
<
int64_t
>
(),
cum_k
.
data
<
int64_t
>
(),
cum_
k
.
data
<
int64_t
>
(),
start
.
data
<
int64_t
>
(),
start
.
data
<
int64_t
>
(),
dist
.
data
<
scalar_t
>
(),
tmp_dist
.
data
<
scalar_t
>
(),
dist
.
data
<
scalar_t
>
(),
tmp_dist
.
data
<
scalar_t
>
(),
out
.
data
<
int64_t
>
()
,
x
.
size
(
1
)
);
out
.
data
<
int64_t
>
());
});
});
return
out
;
return
out
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment