Commit c67425b0 authored by quyuanhao123's avatar quyuanhao123
Browse files

Initial commit

parents
Pipeline #190 failed with stages
in 0 seconds
/***********************************************************************
* Software License Agreement (BSD License)
*
* Copyright 2011-16 Jose Luis Blanco (joseluisblancoc@gmail.com).
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*************************************************************************/
#pragma once
#include "nanoflann.hpp"
#include <vector>
// ===== This example shows how to use nanoflann with these types of containers:
// =======
// typedef std::vector<std::vector<double> > my_vector_of_vectors_t;
// typedef std::vector<Eigen::VectorXd> my_vector_of_vectors_t; // This
// requires #include <Eigen/Dense>
// =====================================================================================
/** A simple vector-of-vectors adaptor for nanoflann, without duplicating the
* storage. The i'th vector represents a point in the state space.
*
* \tparam DIM If set to >0, it specifies a compile-time fixed dimensionality
* for the points in the data set, allowing more compiler optimizations. \tparam
* num_t The type of the point coordinates (typically, double or float). \tparam
* Distance The distance metric to use: nanoflann::metric_L1,
* nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. \tparam IndexType The
* type for indices in the KD-tree index (typically, size_t of int)
*/
template <class VectorOfVectorsType, typename num_t = double, int DIM = -1,
class Distance = nanoflann::metric_L2, typename IndexType = size_t>
struct KDTreeVectorOfVectorsAdaptor {
typedef KDTreeVectorOfVectorsAdaptor<VectorOfVectorsType, num_t, DIM,
Distance>
self_t;
typedef
typename Distance::template traits<num_t, self_t>::distance_t metric_t;
typedef nanoflann::KDTreeSingleIndexAdaptor<metric_t, self_t, DIM, IndexType>
index_t;
index_t *index; //! The kd-tree index for the user to call its methods as
//! usual with any other FLANN index.
/// Constructor: takes a const ref to the vector of vectors object with the
/// data points
KDTreeVectorOfVectorsAdaptor(const size_t /* dimensionality */,
const VectorOfVectorsType &mat,
const int leaf_max_size = 10)
: m_data(mat) {
assert(mat.size() != 0 && mat[0].size() != 0);
const size_t dims = mat[0].size();
if (DIM > 0 && static_cast<int>(dims) != DIM)
throw std::runtime_error(
"Data set dimensionality does not match the 'DIM' template argument");
index =
new index_t(static_cast<int>(dims), *this /* adaptor */,
nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));
index->buildIndex();
}
~KDTreeVectorOfVectorsAdaptor() { delete index; }
const VectorOfVectorsType &m_data;
/** Query for the \a num_closest closest points to a given point (entered as
* query_point[0:dim-1]). Note that this is a short-cut method for
* index->findNeighbors(). The user can also call index->... methods as
* desired. \note nChecks_IGNORED is ignored but kept for compatibility with
* the original FLANN interface.
*/
inline void query(const num_t *query_point, const size_t num_closest,
IndexType *out_indices, num_t *out_distances_sq,
const int nChecks_IGNORED = 10) const {
nanoflann::KNNResultSet<num_t, IndexType> resultSet(num_closest);
resultSet.init(out_indices, out_distances_sq);
index->findNeighbors(resultSet, query_point, nanoflann::SearchParams());
}
/** @name Interface expected by KDTreeSingleIndexAdaptor
* @{ */
const self_t &derived() const { return *this; }
self_t &derived() { return *this; }
// Must return the number of data points
inline size_t kdtree_get_point_count() const { return m_data.size(); }
// Returns the dim'th component of the idx'th point in the class:
inline num_t kdtree_get_pt(const size_t idx, const size_t dim) const {
return m_data[idx][dim];
}
// Optional bounding-box computation: return false to default to a standard
// bbox computation loop.
// Return true if the BBOX was already computed by the class and returned in
// "bb" so it can be avoided to redo it again. Look at bb.size() to find out
// the expected dimensionality (e.g. 2 or 3 for point clouds)
template <class BBOX> bool kdtree_get_bbox(BBOX & /*bb*/) const {
return false;
}
/** @} */
}; // end of KDTreeVectorOfVectorsAdaptor
This diff is collapsed.
#include <Python.h>
#include <torch/script.h>
#include "cpu/fps_cpu.h"
#ifdef WITH_HIP
#include "hip/fps_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__fps_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__fps_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
if (src.device().is_cuda()) {
#ifdef WITH_HIP
return fps_cuda(src, ptr, ratio, random_start);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return fps_cpu(src, ptr, ratio, random_start);
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::fps", &fps);
#include <Python.h>
#include <torch/script.h>
#include "cpu/graclus_cpu.h"
#ifdef WITH_HIP
#include "hip/graclus_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__graclus_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__graclus_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP
return graclus_cuda(rowptr, col, optional_weight);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return graclus_cpu(rowptr, col, optional_weight);
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::graclus", &graclus);
#include <Python.h>
#include <torch/script.h>
#include "cpu/grid_cpu.h"
#ifdef WITH_HIP
#include "hip/grid_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__grid_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__grid_cpu(void) { return NULL; }
#endif
#endif
torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
if (pos.device().is_cuda()) {
#ifdef WITH_HIP
return grid_cuda(pos, size, optional_start, optional_end);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return grid_cpu(pos, size, optional_start, optional_end);
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::grid", &grid);
#pragma once
#include <torch/extension.h>
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start);
#include "hip/hip_runtime.h"
#include "fps_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 256
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
const int64_t *out_ptr, const int64_t *start,
scalar_t *dist, int64_t *out, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t batch_idx = blockIdx.x;
const int64_t start_idx = ptr[batch_idx];
const int64_t end_idx = ptr[batch_idx + 1];
__shared__ scalar_t best_dist[THREADS];
__shared__ int64_t best_dist_idx[THREADS];
if (thread_idx == 0) {
out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
}
for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
__syncthreads();
int64_t old = out[m - 1];
scalar_t best = (scalar_t)-1.;
int64_t best_idx = 0;
for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
scalar_t tmp, dd = (scalar_t)0.;
for (int64_t d = 0; d < dim; d++) {
tmp = src[dim * old + d] - src[dim * n + d];
dd += tmp * tmp;
}
dd = min(dist[n], dd);
dist[n] = dd;
if (dd > best) {
best = dd;
best_idx = n;
}
}
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t i = 1; i < THREADS; i *= 2) {
__syncthreads();
if ((thread_idx + i) < THREADS &&
best_dist[thread_idx] < best_dist[thread_idx + i]) {
best_dist[thread_idx] = best_dist[thread_idx + i];
best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
}
}
__syncthreads();
if (thread_idx == 0) {
out[m] = best_dist_idx[0];
}
}
}
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start) {
CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_INPUT(ptr.dim() == 1);
hipSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
torch::Tensor start;
if (random_start) {
start = torch::rand(batch_size, src.options());
start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
} else {
start = torch::zeros(batch_size, ptr.options());
}
auto dist = torch::full(src.size(0), 5e4, src.options());
auto out_size = (int64_t *)malloc(sizeof(int64_t));
hipMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
hipMemcpyDeviceToHost);
auto out = torch::empty(out_size[0], out_ptr.options());
auto stream = at::cuda::getCurrentCUDAStream();
auto scalar_type = src.scalar_type();
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), out.data_ptr<int64_t>(), src.size(1));
});
return out;
}
#include "hip/hip_runtime.h"
#include "fps_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 256
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
const int64_t *out_ptr, const int64_t *start,
scalar_t *dist, int64_t *out, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t batch_idx = blockIdx.x;
const int64_t start_idx = ptr[batch_idx];
const int64_t end_idx = ptr[batch_idx + 1];
__shared__ scalar_t best_dist[THREADS];
__shared__ int64_t best_dist_idx[THREADS];
if (thread_idx == 0) {
out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
}
for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
__syncthreads();
int64_t old = out[m - 1];
scalar_t best = (scalar_t)-1.;
int64_t best_idx = 0;
for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
scalar_t tmp, dd = (scalar_t)0.;
for (int64_t d = 0; d < dim; d++) {
tmp = src[dim * old + d] - src[dim * n + d];
dd += tmp * tmp;
}
dd = min(dist[n], dd);
dist[n] = dd;
if (dd > best) {
best = dd;
best_idx = n;
}
}
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t i = 1; i < THREADS; i *= 2) {
__syncthreads();
if ((thread_idx + i) < THREADS &&
best_dist[thread_idx] < best_dist[thread_idx + i]) {
best_dist[thread_idx] = best_dist[thread_idx + i];
best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
}
}
__syncthreads();
if (thread_idx == 0) {
out[m] = best_dist_idx[0];
}
}
}
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start) {
CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_INPUT(ptr.dim() == 1);
hipSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
torch::Tensor start;
if (random_start) {
start = torch::rand(batch_size, src.options());
start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
} else {
start = torch::zeros(batch_size, ptr.options());
}
auto dist = torch::full(src.size(0), 5e4, src.options());
auto out_size = (int64_t *)malloc(sizeof(int64_t));
hipMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
hipMemcpyDeviceToHost);
auto out = torch::empty(out_size[0], out_ptr.options());
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto scalar_type = src.scalar_type();
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
hipLaunchKernelGGL(( fps_kernel<scalar_t>), dim3(batch_size), dim3(THREADS), 0, stream,
src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), out.data_ptr<int64_t>(), src.size(1));
});
return out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight);
#include "hip/hip_runtime.h"
#include "graclus_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_P 0.53406
__device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; }
__global__ void colorize_kernel(int64_t *out, const float *bernoulli,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] < 0) {
out[thread_idx] = (int64_t)bernoulli[thread_idx] - 2;
done_d = false;
}
}
}
bool colorize(torch::Tensor out) {
auto stream = at::cuda::getCurrentCUDAStream();
init_done_kernel<<<1, 1, 0, stream>>>();
auto numel = out.size(0);
auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
auto bernoulli = props.bernoulli();
colorize_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
bool done_h;
hipMemcpyFromSymbol(&done_h, HIP_SYMBOL(done_d), sizeof(done_h), 0,
hipMemcpyDeviceToHost);
return done_h;
}
__global__ void propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -1)
return; // Only vist blue nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -2) {
proposal[thread_idx] = v; // Propose to first red neighbor.
break;
}
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
template <typename scalar_t>
__global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -1)
return; // Only vist blue nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
// Find maximum weighted red neighbor.
if (out[v] == -2 && weight[i] >= w_max) {
v_max = v;
w_max = weight[i];
}
}
proposal[thread_idx] = v_max; // Propose.
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
if (!optional_weight.has_value()) {
propose_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
auto scalar_type = weight.scalar_type();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
weighted_propose_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
__global__ void respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -2)
return; // Only vist red nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == thread_idx) {
// Match first blue neighbhor v which proposed to u.
out[thread_idx] = min(thread_idx, v);
out[v] = min(thread_idx, v);
break;
}
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
template <typename scalar_t>
__global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -2)
return; // Only vist red nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == thread_idx && weight[i] >= w_max) {
// Find maximum weighted blue neighbhor v which proposed to u.
v_max = v;
w_max = weight[i];
}
}
if (v_max >= 0) {
out[thread_idx] = min(thread_idx, v_max); // Match neighbors.
out[v_max] = min(thread_idx, v_max);
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
if (!optional_weight.has_value()) {
respond_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
auto scalar_type = weight.scalar_type();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
weighted_respond_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
if (optional_weight.has_value()) {
CHECK_CUDA(optional_weight.value());
CHECK_INPUT(optional_weight.value().dim() == 1);
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
hipSetDevice(rowptr.get_device());
int64_t num_nodes = rowptr.numel() - 1;
auto out = torch::full(num_nodes, -1, rowptr.options());
auto proposal = torch::full(num_nodes, -1, rowptr.options());
while (!colorize(out)) {
propose(out, proposal, rowptr, col, optional_weight);
respond(out, proposal, rowptr, col, optional_weight);
}
return out;
}
#include "hip/hip_runtime.h"
#include "graclus_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_P 0.53406
__device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; }
__global__ void colorize_kernel(int64_t *out, const float *bernoulli,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] < 0) {
out[thread_idx] = (int64_t)bernoulli[thread_idx] - 2;
done_d = false;
}
}
}
bool colorize(torch::Tensor out) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( init_done_kernel), dim3(1), dim3(1), 0, stream, );
auto numel = out.size(0);
auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
auto bernoulli = props.bernoulli();
hipLaunchKernelGGL(( colorize_kernel), dim3(BLOCKS(numel)), dim3(THREADS), 0, stream,
out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
bool done_h;
hipMemcpyFromSymbol(&done_h, HIP_SYMBOL(done_d), sizeof(done_h), 0,
hipMemcpyDeviceToHost);
return done_h;
}
__global__ void propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -1)
return; // Only vist blue nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -2) {
proposal[thread_idx] = v; // Propose to first red neighbor.
break;
}
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
template <typename scalar_t>
__global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -1)
return; // Only vist blue nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
// Find maximum weighted red neighbor.
if (out[v] == -2 && weight[i] >= w_max) {
v_max = v;
w_max = weight[i];
}
}
proposal[thread_idx] = v_max; // Propose.
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (!optional_weight.has_value()) {
hipLaunchKernelGGL(( propose_kernel), dim3(BLOCKS(out.numel())), dim3(THREADS), 0, stream,
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
auto scalar_type = weight.scalar_type();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
hipLaunchKernelGGL(( weighted_propose_kernel<scalar_t>)
, dim3(BLOCKS(out.numel())), dim3(THREADS), 0, stream,
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
__global__ void respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -2)
return; // Only vist red nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == thread_idx) {
// Match first blue neighbhor v which proposed to u.
out[thread_idx] = min(thread_idx, v);
out[v] = min(thread_idx, v);
break;
}
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
template <typename scalar_t>
__global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -2)
return; // Only vist red nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == thread_idx && weight[i] >= w_max) {
// Find maximum weighted blue neighbhor v which proposed to u.
v_max = v;
w_max = weight[i];
}
}
if (v_max >= 0) {
out[thread_idx] = min(thread_idx, v_max); // Match neighbors.
out[v_max] = min(thread_idx, v_max);
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (!optional_weight.has_value()) {
hipLaunchKernelGGL(( respond_kernel), dim3(BLOCKS(out.numel())), dim3(THREADS), 0, stream,
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
auto scalar_type = weight.scalar_type();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
hipLaunchKernelGGL(( weighted_respond_kernel<scalar_t>)
, dim3(BLOCKS(out.numel())), dim3(THREADS), 0, stream,
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
if (optional_weight.has_value()) {
CHECK_CUDA(optional_weight.value());
CHECK_INPUT(optional_weight.value().dim() == 1);
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
hipSetDevice(rowptr.get_device());
int64_t num_nodes = rowptr.numel() - 1;
auto out = torch::full(num_nodes, -1, rowptr.options());
auto proposal = torch::full(num_nodes, -1, rowptr.options());
while (!colorize(out)) {
propose(out, proposal, rowptr, col, optional_weight);
respond(out, proposal, rowptr, col, optional_weight);
}
return out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end);
#include "hip/hip_runtime.h"
#include "grid_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
const scalar_t *start, const scalar_t *end,
int64_t *out, int64_t D, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t c = 0, k = 1;
for (int64_t d = 0; d < D; d++) {
scalar_t p = pos[thread_idx * D + d] - start[d];
c += (int64_t)(p / size[d]) * k;
k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
}
out[thread_idx] = c;
}
}
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
CHECK_CUDA(pos);
CHECK_CUDA(size);
hipSetDevice(pos.get_device());
if (optional_start.has_value())
CHECK_CUDA(optional_start.value());
if (optional_start.has_value())
CHECK_CUDA(optional_start.value());
pos = pos.view({pos.size(0), -1}).contiguous();
size = size.contiguous();
CHECK_INPUT(size.numel() == pos.size(1));
if (!optional_start.has_value())
optional_start = std::get<0>(pos.min(0));
else {
optional_start = optional_start.value().contiguous();
CHECK_INPUT(optional_start.value().numel() == pos.size(1));
}
if (!optional_end.has_value())
optional_end = std::get<0>(pos.max(0));
else {
optional_start = optional_start.value().contiguous();
CHECK_INPUT(optional_start.value().numel() == pos.size(1));
}
auto start = optional_start.value();
auto end = optional_end.value();
auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] {
grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
out.data_ptr<int64_t>(), pos.size(1), out.numel());
});
return out;
}
#include "hip/hip_runtime.h"
#include "grid_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
const scalar_t *start, const scalar_t *end,
int64_t *out, int64_t D, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t c = 0, k = 1;
for (int64_t d = 0; d < D; d++) {
scalar_t p = pos[thread_idx * D + d] - start[d];
c += (int64_t)(p / size[d]) * k;
k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
}
out[thread_idx] = c;
}
}
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
CHECK_CUDA(pos);
CHECK_CUDA(size);
hipSetDevice(pos.get_device());
if (optional_start.has_value())
CHECK_CUDA(optional_start.value());
if (optional_start.has_value())
CHECK_CUDA(optional_start.value());
pos = pos.view({pos.size(0), -1}).contiguous();
size = size.contiguous();
CHECK_INPUT(size.numel() == pos.size(1));
if (!optional_start.has_value())
optional_start = std::get<0>(pos.min(0));
else {
optional_start = optional_start.value().contiguous();
CHECK_INPUT(optional_start.value().numel() == pos.size(1));
}
if (!optional_end.has_value())
optional_end = std::get<0>(pos.max(0));
else {
optional_start = optional_start.value().contiguous();
CHECK_INPUT(optional_start.value().numel() == pos.size(1));
}
auto start = optional_start.value();
auto end = optional_end.value();
auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] {
hipLaunchKernelGGL(( grid_kernel<scalar_t>), dim3(BLOCKS(out.numel())), dim3(THREADS), 0, stream,
pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
out.data_ptr<int64_t>(), pos.size(1), out.numel());
});
return out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, int64_t k,
bool cosine);
#include "hip/hip_runtime.h"
#include "radius_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 256
template <typename scalar_t> struct Cosine {
static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
int64_t n_a, int64_t n_b,
int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[n_a * size + i] * b[n_b * size + i];
}
return result;
}
static inline __device__ scalar_t norm(const scalar_t *a, int64_t n_a,
int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[n_a * size + i] * a[n_a * size + i];
}
return sqrt(result);
}
};
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
int64_t *__restrict__ row, int64_t *__restrict__ col,
const int64_t k, const int64_t n, const int64_t m, const int64_t dim,
const int64_t num_examples, const bool cosine) {
const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
if (n_y >= m)
return;
const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
scalar_t best_dist[100];
int64_t best_idx[100];
for (int e = 0; e < k; e++) {
best_dist[e] = 5e4;
best_idx[e] = -1;
}
for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
scalar_t tmp_dist = 0;
if (cosine) {
tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
(Cosine<scalar_t>::norm(x, n_x, dim) *
Cosine<scalar_t>::norm(y, n_y, dim));
tmp_dist = 1. - tmp_dist;
} else {
for (int64_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
}
for (int64_t e1 = 0; e1 < k; e1++) {
if (best_dist[e1] > tmp_dist) {
for (int64_t e2 = k - 1; e2 > e1; e2--) {
best_dist[e2] = best_dist[e2 - 1];
best_idx[e2] = best_idx[e2 - 1];
}
best_dist[e1] = tmp_dist;
best_idx[e1] = n_x;
break;
}
}
}
for (int64_t e = 0; e < k; e++) {
row[n_y * k + e] = n_y;
col[n_y * k + e] = best_idx[e];
}
}
torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, const int64_t k,
const bool cosine) {
CHECK_CUDA(x);
CHECK_CONTIGUOUS(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CUDA(y);
CHECK_CONTIGUOUS(y);
CHECK_INPUT(y.dim() == 2);
CHECK_INPUT(x.size(1) == y.size(1));
AT_ASSERTM(k <= 100, "`k` needs to smaller than or equal to 100");
if (ptr_x.has_value()) {
CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
} else
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
x.options().dtype(torch::kLong));
if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
} else
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
y.options().dtype(torch::kLong));
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
hipSetDevice(x.get_device());
auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
auto stream = at::cuda::getCurrentCUDAStream();
auto scalar_type = x.scalar_type();
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), k, x.size(0),
y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
});
auto mask = col != -1;
return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
#include "hip/hip_runtime.h"
#include "radius_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 256
template <typename scalar_t> struct Cosine {
static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
int64_t n_a, int64_t n_b,
int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[n_a * size + i] * b[n_b * size + i];
}
return result;
}
static inline __device__ scalar_t norm(const scalar_t *a, int64_t n_a,
int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[n_a * size + i] * a[n_a * size + i];
}
return sqrt(result);
}
};
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
int64_t *__restrict__ row, int64_t *__restrict__ col,
const int64_t k, const int64_t n, const int64_t m, const int64_t dim,
const int64_t num_examples, const bool cosine) {
const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
if (n_y >= m)
return;
const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
scalar_t best_dist[100];
int64_t best_idx[100];
for (int e = 0; e < k; e++) {
best_dist[e] = 5e4;
best_idx[e] = -1;
}
for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
scalar_t tmp_dist = 0;
if (cosine) {
tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
(Cosine<scalar_t>::norm(x, n_x, dim) *
Cosine<scalar_t>::norm(y, n_y, dim));
tmp_dist = 1. - tmp_dist;
} else {
for (int64_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
}
for (int64_t e1 = 0; e1 < k; e1++) {
if (best_dist[e1] > tmp_dist) {
for (int64_t e2 = k - 1; e2 > e1; e2--) {
best_dist[e2] = best_dist[e2 - 1];
best_idx[e2] = best_idx[e2 - 1];
}
best_dist[e1] = tmp_dist;
best_idx[e1] = n_x;
break;
}
}
}
for (int64_t e = 0; e < k; e++) {
row[n_y * k + e] = n_y;
col[n_y * k + e] = best_idx[e];
}
}
torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
torch::optional<torch::Tensor> ptr_x,
torch::optional<torch::Tensor> ptr_y, const int64_t k,
const bool cosine) {
CHECK_CUDA(x);
CHECK_CONTIGUOUS(x);
CHECK_INPUT(x.dim() == 2);
CHECK_CUDA(y);
CHECK_CONTIGUOUS(y);
CHECK_INPUT(y.dim() == 2);
CHECK_INPUT(x.size(1) == y.size(1));
AT_ASSERTM(k <= 100, "`k` needs to smaller than or equal to 100");
if (ptr_x.has_value()) {
CHECK_CUDA(ptr_x.value());
CHECK_INPUT(ptr_x.value().dim() == 1);
} else
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
x.options().dtype(torch::kLong));
if (ptr_y.has_value()) {
CHECK_CUDA(ptr_y.value());
CHECK_INPUT(ptr_y.value().dim() == 1);
} else
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
y.options().dtype(torch::kLong));
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
hipSetDevice(x.get_device());
auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto scalar_type = x.scalar_type();
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
hipLaunchKernelGGL(( knn_kernel<scalar_t>), dim3(BLOCKS), dim3(THREADS), 0, stream,
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), k, x.size(0),
y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
});
auto mask = col != -1;
return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
#pragma once
#include <torch/extension.h>
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
torch::Tensor ptr_x, torch::Tensor ptr_y);
#include "hip/hip_runtime.h"
#include "nearest_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *out, int64_t batch_size, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t n_x = blockIdx.x;
int64_t batch_idx;
for (int64_t b = 0; b < batch_size; b++) {
if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
batch_idx = b;
break;
}
}
const int64_t y_start_idx = ptr_y[batch_idx];
const int64_t y_end_idx = ptr_y[batch_idx + 1];
__shared__ scalar_t best_dist[THREADS];
__shared__ int64_t best_dist_idx[THREADS];
scalar_t best = 1e38;
int64_t best_idx = 0;
for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
n_y += THREADS) {
scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) {
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
if (dist < best) {
best = dist;
best_idx = n_y;
}
}
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t u = 0; (1 << u) < THREADS; u++) {
__syncthreads();
if (thread_idx < (THREADS >> (u + 1))) {
int64_t idx_1 = (thread_idx * 2) << u;
int64_t idx_2 = (thread_idx * 2 + 1) << u;
if (best_dist[idx_1] > best_dist[idx_2]) {
best_dist[idx_1] = best_dist[idx_2];
best_dist_idx[idx_1] = best_dist_idx[idx_2];
}
}
}
__syncthreads();
if (thread_idx == 0) {
out[n_x] = best_dist_idx[0];
}
}
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
torch::Tensor ptr_x, torch::Tensor ptr_y) {
CHECK_CUDA(x);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
hipSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
auto out = torch::empty({x.size(0)}, ptr_x.options());
auto stream = at::cuda::getCurrentCUDAStream();
auto scalar_type = x.scalar_type();
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
out.data_ptr<int64_t>(), ptr_x.size(0) - 1, x.size(1));
});
return out;
}
#include "hip/hip_runtime.h"
#include "nearest_hip.h"
#include <ATen/hip/HIPContext.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *out, int64_t batch_size, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t n_x = blockIdx.x;
int64_t batch_idx;
for (int64_t b = 0; b < batch_size; b++) {
if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
batch_idx = b;
break;
}
}
const int64_t y_start_idx = ptr_y[batch_idx];
const int64_t y_end_idx = ptr_y[batch_idx + 1];
__shared__ scalar_t best_dist[THREADS];
__shared__ int64_t best_dist_idx[THREADS];
scalar_t best = 1e38;
int64_t best_idx = 0;
for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
n_y += THREADS) {
scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) {
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
if (dist < best) {
best = dist;
best_idx = n_y;
}
}
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t u = 0; (1 << u) < THREADS; u++) {
__syncthreads();
if (thread_idx < (THREADS >> (u + 1))) {
int64_t idx_1 = (thread_idx * 2) << u;
int64_t idx_2 = (thread_idx * 2 + 1) << u;
if (best_dist[idx_1] > best_dist[idx_2]) {
best_dist[idx_1] = best_dist[idx_2];
best_dist_idx[idx_1] = best_dist_idx[idx_2];
}
}
}
__syncthreads();
if (thread_idx == 0) {
out[n_x] = best_dist_idx[0];
}
}
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
torch::Tensor ptr_x, torch::Tensor ptr_y) {
CHECK_CUDA(x);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
hipSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
auto out = torch::empty({x.size(0)}, ptr_x.options());
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto scalar_type = x.scalar_type();
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
hipLaunchKernelGGL(( nearest_kernel<scalar_t>), dim3(x.size(0)), dim3(THREADS), 0, stream,
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
out.data_ptr<int64_t>(), ptr_x.size(0) - 1, x.size(1));
});
return out;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment