OpenDAS / torch-cluster / Commits

Commit 06d9038f
authored Mar 12, 2020 by rusty1s

fps kernel

parent 7587ce48
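This commit moves the CUDA implementation of farthest point sampling (FPS) from csrc/cuda/fps_kernel.cu into csrc/cuda/fps_cuda.cu. FPS starts each example from a seed index and then repeatedly adds the point whose distance to the already selected points is largest, until ceil(ratio * num_points) points have been chosen per example. For orientation only, a sequential host-side reference of the same algorithm might look like the sketch below (hypothetical helper, not part of this commit; assumes a row-major n x dim coordinate array):

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical reference implementation of FPS for a single example.
// x is row-major with n rows of dim coordinates; k indices are returned,
// beginning with the given start index.
std::vector<int64_t> fps_reference(const std::vector<float> &x, int64_t n,
                                   int64_t dim, int64_t k, int64_t start) {
  // dist[i] tracks the squared distance from point i to the nearest
  // already selected point (initially "infinity").
  std::vector<float> dist(n, std::numeric_limits<float>::max());
  std::vector<int64_t> out = {start};

  while ((int64_t)out.size() < k) {
    int64_t old = out.back();
    int64_t best_idx = 0;
    float best = -1;
    for (int64_t i = 0; i < n; i++) {
      float d = 0;
      for (int64_t c = 0; c < dim; c++) {
        float diff = x[i * dim + c] - x[old * dim + c];
        d += diff * diff;
      }
      dist[i] = std::min(dist[i], d);
      // The next sample is the point farthest from the selected set.
      if (dist[i] > best) {
        best = dist[i];
        best_idx = i;
      }
    }
    out.push_back(best_idx);
  }
  return out;
}

Both CUDA files below follow the same structure per example: the distance update is spread over THREADS threads, and a shared-memory reduction picks the farthest point.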
Showing 2 changed files with 96 additions and 204 deletions:
  csrc/cuda/fps_cuda.cu    +96 -4
  csrc/cuda/fps_kernel.cu  +0 -200

csrc/cuda/fps_cuda.cu
#include "fps_cuda.h"
#include "fps_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#include "utils.cuh"
#define THREADS 1024
inline
torch
::
Tensor
get_dist
(
torch
::
Tensor
x
,
int64_t
idx
)
{
inline
torch
::
Tensor
get_dist
(
torch
::
Tensor
x
,
int64_t
idx
)
{
return
(
x
-
x
[
idx
]).
norm
(
2
,
1
);
return
(
x
-
x
[
idx
]).
norm
(
2
,
1
);
}
}
template <typename scalar_t> struct Dist {
  static inline __device__ void
  compute(int64_t idx, int64_t start_idx, int64_t end_idx, int64_t old,
          scalar_t *best, int64_t *best_idx, const scalar_t *x, scalar_t *dist,
          scalar_t *tmp_dist, int64_t dim) {
    for (int64_t n = start_idx + idx; n < end_idx; n += THREADS) {
      tmp_dist[n] = 0;
    }
    __syncthreads();

    for (int64_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
      scalar_t d = x[(old * dim) + (i % dim)] - x[i];
      atomicAdd(&tmp_dist[i / dim], d * d);
    }
    __syncthreads();

    for (int64_t n = start_idx + idx; n < end_idx; n += THREADS) {
      dist[n] = min(dist[n], tmp_dist[n]);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *x, const int64_t *ptr,
                           const int64_t *out_ptr, const int64_t *start,
                           scalar_t *dist, scalar_t *tmp_dist, int64_t *out,
                           int64_t dim) {

  const int64_t batch_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;

  const int64_t start_idx = ptr[batch_idx];
  const int64_t end_idx = ptr[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  if (threadIdx.x == 0) {
    out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
  }

  for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
    scalar_t best = -1;
    int64_t best_idx = 0;

    __syncthreads();
    Dist<scalar_t>::compute(thread_idx, start_idx, end_idx, out[m - 1], &best,
                            &best_idx, x, dist, tmp_dist, dim);

    best_dist[thread_idx] = best;
    best_dist_idx[thread_idx] = best_idx;

    for (int64_t u = 0; (1 << u) < THREADS; u++) {
      __syncthreads();
      if (thread_idx < (THREADS >> (u + 1))) {
        int64_t idx1 = (thread_idx * 2) << u;
        int64_t idx2 = (thread_idx * 2 + 1) << u;
        if (best_dist[idx1] < best_dist[idx2]) {
          best_dist[idx1] = best_dist[idx2];
          best_dist_idx[idx1] = best_dist_idx[idx2];
        }
      }
    }

    __syncthreads();
    if (thread_idx == 0) {
      out[m] = best_dist_idx[0];
    }
  }
}
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
                       bool random_start) {
...
@@ -31,11 +112,22 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
     start = torch::zeros(batch_size, ptr.options());
   }
 
-  auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options());
+  auto dist = torch::full(src.size(0), 1e38, src.options());
+  auto tmp_dist = torch::empty(src.size(0), src.options());
+
+  auto out_size = (int64_t *)malloc(sizeof(int64_t));
+  cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
+             cudaMemcpyDeviceToHost);
+  auto out = at::empty(out_size[0], out_ptr.options());
-  auto ptr_data = ptr.data_ptr<int64_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
-  auto out_ptr_data = out_ptr.data_ptr<int64_t>();
+  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
-    auto out_data = out.data_ptr<int64_t>();
+    fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
+        src.data_ptr<scalar_t>(), rowptr.data_ptr<int64_t>(),
+        out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
+        dist.data_ptr<scalar_t>(), tmp_dist.data_ptr<scalar_t>(),
+        out.data_ptr<int64_t>(), src.size(1));
+  });
 
   return out;
 }
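In the generic path of Dist::compute above, the (num_points x dim) coordinate block of one example is walked as a flat array: thread t visits elements t, t + THREADS, t + 2 * THREADS, and so on, so i / dim recovers the point index and i % dim the coordinate, and squared differences are accumulated per point with atomicAdd. A standalone sketch of just this indexing scheme, with hypothetical names and not part of the commit:

// Hypothetical single-block kernel: squared distance from row `old` of a
// row-major (n x dim) array to every row, accumulated into out[0..n).
__global__ void sq_dist_to_row(const float *x, int64_t n, int64_t dim,
                               int64_t old, float *out) {
  // Reset one accumulator entry per point.
  for (int64_t r = threadIdx.x; r < n; r += blockDim.x)
    out[r] = 0;
  __syncthreads();

  // Strided pass over all n * dim coordinates; i / dim is the point index,
  // i % dim the coordinate within that point.
  for (int64_t i = threadIdx.x; i < n * dim; i += blockDim.x) {
    float d = x[old * dim + (i % dim)] - x[i];
    atomicAdd(&out[i / dim], d * d);
  }
}

With start_idx = 0 this corresponds to the first two loops of compute; the third loop there folds tmp_dist into the running per-point minimum and tracks the thread-local maximum.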
csrc/cuda/fps_kernel.cu
deleted file (100644 → 0)
#include <ATen/ATen.h>

#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024

template <typename scalar_t, int64_t Dim> struct Dist;
template <typename scalar_t> struct Dist<scalar_t, 1> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      scalar_t d = x[old] - x[n];
      dist[n] = min(dist[n], d * d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
template <typename scalar_t> struct Dist<scalar_t, 2> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      scalar_t a = x[2 * old + 0] - x[2 * n + 0];
      scalar_t b = x[2 * old + 1] - x[2 * n + 1];
      scalar_t d = a * a + b * b;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
template <typename scalar_t> struct Dist<scalar_t, 3> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      scalar_t a = x[3 * old + 0] - x[3 * n + 0];
      scalar_t b = x[3 * old + 1] - x[3 * n + 1];
      scalar_t c = x[3 * old + 2] - x[3 * n + 2];
      scalar_t d = a * a + b * b + c * c;
      dist[n] = min(dist[n], d);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
template <typename scalar_t> struct Dist<scalar_t, -1> {
  static __device__ void
  compute(ptrdiff_t idx, ptrdiff_t start_idx, ptrdiff_t end_idx, ptrdiff_t old,
          scalar_t *__restrict__ best, ptrdiff_t *__restrict__ best_idx,
          const scalar_t *__restrict__ x, scalar_t *__restrict__ dist,
          scalar_t *__restrict__ tmp_dist, size_t dim) {
    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      tmp_dist[n] = 0;
    }
    __syncthreads();

    for (ptrdiff_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
      scalar_t d = x[(old * dim) + (i % dim)] - x[i];
      atomicAdd(&tmp_dist[i / dim], d * d);
    }
    __syncthreads();

    for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
      dist[n] = min(dist[n], tmp_dist[n]);
      if (dist[n] > *best) {
        *best = dist[n];
        *best_idx = n;
      }
    }
  }
};
template <typename scalar_t, int64_t Dim>
__global__ void
fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
           const int64_t *__restrict__ cum_k, const int64_t *__restrict__ start,
           scalar_t *__restrict__ dist, scalar_t *__restrict__ tmp_dist,
           int64_t *__restrict__ out, size_t dim) {

  const ptrdiff_t batch_idx = blockIdx.x;
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx = cum_deg[batch_idx];
  const ptrdiff_t end_idx = cum_deg[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  if (idx == 0) {
    out[cum_k[batch_idx]] = start_idx + start[batch_idx];
  }

  for (ptrdiff_t m = cum_k[batch_idx] + 1; m < cum_k[batch_idx + 1]; m++) {
    scalar_t best = -1;
    ptrdiff_t best_idx = 0;

    __syncthreads();
    Dist<scalar_t, Dim>::compute(idx, start_idx, end_idx, out[m - 1], &best,
                                 &best_idx, x, dist, tmp_dist, dim);

    best_dist[idx] = best;
    best_dist_idx[idx] = best_idx;

    for (int64_t u = 0; (1 << u) < THREADS; u++) {
      __syncthreads();
      if (idx < (THREADS >> (u + 1))) {
        int64_t idx_1 = (idx * 2) << u;
        int64_t idx_2 = (idx * 2 + 1) << u;
        if (best_dist[idx_1] < best_dist[idx_2]) {
          best_dist[idx_1] = best_dist[idx_2];
          best_dist_idx[idx_1] = best_dist_idx[idx_2];
        }
      }
    }

    __syncthreads();
    if (idx == 0) {
      out[m] = best_dist_idx[0];
    }
  }
}
#define FPS_KERNEL(DIM, ...) \
[&] { \
switch (DIM) { \
case 1: \
fps_kernel<scalar_t, 1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
break; \
case 2: \
fps_kernel<scalar_t, 2><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
break; \
case 3: \
fps_kernel<scalar_t, 3><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
break; \
default: \
fps_kernel<scalar_t, -1><<<batch_size, THREADS>>>(__VA_ARGS__, DIM); \
} \
}()
at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  cudaSetDevice(x.get_device());

  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
  cudaMemcpy(batch_sizes, batch[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;

  auto deg = degree(batch, batch_size);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
  auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
  auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);

  at::Tensor start;
  if (random) {
    start = at::rand(batch_size, x.options());
    start = (start * deg.toType(at::kFloat)).toType(at::kLong);
  } else {
    start = at::zeros(batch_size, k.options());
  }

  auto dist = at::full(x.size(0), 1e38, x.options());
  auto tmp_dist = at::empty(x.size(0), x.options());

  auto k_sum = (int64_t *)malloc(sizeof(int64_t));
  cudaMemcpy(k_sum, cum_k[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = at::empty(k_sum[0], k.options());

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "fps_kernel", [&] {
    FPS_KERNEL(x.size(1), x.DATA_PTR<scalar_t>(), cum_deg.DATA_PTR<int64_t>(),
               cum_k.DATA_PTR<int64_t>(), start.DATA_PTR<int64_t>(),
               dist.DATA_PTR<scalar_t>(), tmp_dist.DATA_PTR<scalar_t>(),
               out.DATA_PTR<int64_t>());
  });

  return out;
}
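Both the new and the deleted kernel pick the next sample with the same shared-memory tree reduction: each of the THREADS threads writes its best candidate distance and index, and pairs of slots are merged over log2(THREADS) rounds until slot 0 holds the overall maximum. A minimal standalone sketch of that pattern, assuming a single block of THREADS threads and nonnegative values (squared distances), with a hypothetical kernel name; not part of the commit:

#define THREADS 1024

__global__ void block_argmax(const float *val, int64_t n, int64_t *result) {
  __shared__ float best_val[THREADS];
  __shared__ int64_t best_idx[THREADS];

  // Each thread scans a strided slice and keeps its private maximum.
  // Values are assumed nonnegative, so -1 is a safe sentinel.
  float best = -1;
  int64_t idx = 0;
  for (int64_t i = threadIdx.x; i < n; i += THREADS) {
    if (val[i] > best) {
      best = val[i];
      idx = i;
    }
  }
  best_val[threadIdx.x] = best;
  best_idx[threadIdx.x] = idx;

  // Tree reduction with the same indexing as fps_kernel: in round u,
  // slot (2t << u) absorbs slot ((2t + 1) << u) if the latter is larger.
  for (int64_t u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    if (threadIdx.x < (THREADS >> (u + 1))) {
      int64_t i1 = (threadIdx.x * 2) << u;
      int64_t i2 = (threadIdx.x * 2 + 1) << u;
      if (best_val[i1] < best_val[i2]) {
        best_val[i1] = best_val[i2];
        best_idx[i1] = best_idx[i2];
      }
    }
  }
  __syncthreads();

  if (threadIdx.x == 0)
    *result = best_idx[0];
}

fps_kernel runs this once per selected point m, so each sample costs roughly one strided pass over the example plus log2(THREADS) synchronization rounds.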