radius_cpu.cpp 3.32 KB
Newer Older
Alexander Liao's avatar
Alexander Liao committed
1
#include "radius_cpu.h"

#include <algorithm>
#include <vector>

#include "utils.h"

5
/// Runs a radius neighbour search between two CPU tensors.
///
/// Both tensors are flattened into std::vector buffers of
/// size(0) * size(1) elements and handed to nanoflann_neighbors together
/// with `radius`, the point dimension (query.size(1)) and `max_num`.
/// The flat index list it fills is then exposed as a kLong tensor.
///
/// @param query    CPU tensor; read as a 2-D, row-major array.
/// @param support  CPU tensor with the same scalar type as `query`
///                 (its data pointer is requested with the dispatched type).
/// @param radius   Forwarded unchanged to nanoflann_neighbors.
/// @param max_num  Forwarded unchanged to nanoflann_neighbors.
/// @return A freshly allocated kLong tensor of shape
///         (2, neighbors_indices.size() / 2).
///         NOTE(review): presumably row 0 / row 1 hold the two sides of each
///         neighbour pair -- confirm against nanoflann_neighbors.
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
			 float radius, int max_num){

	CHECK_CPU(query);
	CHECK_CPU(support);

	// The raw-pointer copies below read size(0) * size(1) consecutive
	// elements, which is only valid for densely packed storage.
	query = query.contiguous();
	support = support.contiguous();

	std::vector<long> neighbors_indices;
	auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);

	AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {

		// NOTE(review): DATA_PTR is not defined in this file; presumably a
		// compatibility macro from utils.h -- confirm.
		auto data_q = query.DATA_PTR<scalar_t>();
		auto data_s = support.DATA_PTR<scalar_t>();
		std::vector<scalar_t> queries_stl(data_q,
						  data_q + query.size(0)*query.size(1));
		std::vector<scalar_t> supports_stl(data_s,
						   data_s + support.size(0)*support.size(1));

		const int dim = static_cast<int>(torch::size(query, 1));

		// The value returned by nanoflann_neighbors was never consumed
		// here, so it is intentionally discarded.
		nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices,
					      radius, dim, max_num);

	});

	const long long tsize = static_cast<long long>(neighbors_indices.size()/2);

	// from_blob only borrows the vector's storage; clone() makes an owning
	// copy so the result stays valid after neighbors_indices is destroyed.
	torch::Tensor borrowed = torch::from_blob(neighbors_indices.data(), {tsize, 2}, options);
	return borrowed.t().clone();
}

46

Alexander Liao's avatar
Alexander Liao committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/// Counts how many entries of `batch` carry each batch id.
///
/// `batch` must be sorted in non-decreasing order (all entries of one id are
/// consecutive); ids are taken relative to the first entry.
///
/// @param batch  Sorted batch id of every point.
/// @param res    Output. Overwritten with batch.back() - batch.front() + 1
///               counters, where res[k] is the number of entries equal to
///               batch.front() + k (0 for ids that never occur). Left empty
///               when `batch` is empty.
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){

	res.clear();

	// Nothing to count; also avoids indexing into an empty vector.
	if(batch.empty())
		return;

	const long first = batch.front();

	// assign (rather than resize) so stale values already in `res` can
	// never leak into counters of ids that are absent from `batch`.
	res.assign(batch.back() - first + 1, 0);

	for(const long id : batch)
		++res[id - first];
}

/// Batched variant of the CPU radius neighbour search.
///
/// The batch-id tensors are converted to per-batch element counts with
/// get_size_batch, the point tensors are flattened into std::vector buffers
/// of size(0) * size(1) elements, and everything is handed to
/// batch_nanoflann_neighbors together with `radius`, the point dimension
/// (query.size(1)) and `max_num`.
///
/// @param query          CPU tensor; read as a 2-D, row-major array.
/// @param support        CPU tensor with the same scalar type as `query`.
/// @param query_batch    CPU tensor read through data_ptr<long>(); one batch
///                       id per entry, size(0) entries.
///                       NOTE(review): get_size_batch appears to expect the
///                       ids sorted and the tensor non-empty -- confirm callers.
/// @param support_batch  Same as `query_batch`, for `support`.
/// @param radius         Forwarded unchanged to batch_nanoflann_neighbors.
/// @param max_num        Forwarded unchanged to batch_nanoflann_neighbors.
/// @return A freshly allocated kLong tensor of shape
///         (2, neighbors_indices.size() / 2).
torch::Tensor batch_radius_cpu(torch::Tensor query,
			       torch::Tensor support,
			       torch::Tensor query_batch,
			       torch::Tensor support_batch,
			       float radius, int max_num) {

	// Same device validation as radius_cpu: every tensor is dereferenced
	// through a raw host pointer below.
	CHECK_CPU(query);
	CHECK_CPU(support);
	CHECK_CPU(query_batch);
	CHECK_CPU(support_batch);

	// Raw-pointer reads assume densely packed storage.
	query = query.contiguous();
	support = support.contiguous();
	query_batch = query_batch.contiguous();
	support_batch = support_batch.contiguous();

	auto data_qb = query_batch.data_ptr<long>();
	auto data_sb = support_batch.data_ptr<long>();

	std::vector<long> query_batch_stl(data_qb, data_qb + query_batch.size(0));
	std::vector<long> size_query_batch_stl;
	get_size_batch(query_batch_stl, size_query_batch_stl);

	std::vector<long> support_batch_stl(data_sb, data_sb + support_batch.size(0));
	std::vector<long> size_support_batch_stl;
	get_size_batch(support_batch_stl, size_support_batch_stl);

	std::vector<long> neighbors_indices;
	auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);

	AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_search", [&] {
		auto data_q = query.data_ptr<scalar_t>();
		auto data_s = support.data_ptr<scalar_t>();
		std::vector<scalar_t> queries_stl(data_q,
						  data_q + query.size(0)*query.size(1));
		std::vector<scalar_t> supports_stl(data_s,
						   data_s + support.size(0)*support.size(1));

		const int dim = static_cast<int>(torch::size(query, 1));

		// The value returned by batch_nanoflann_neighbors was never
		// consumed here, so it is intentionally discarded.
		batch_nanoflann_neighbors<scalar_t>(queries_stl,
						    supports_stl,
						    size_query_batch_stl,
						    size_support_batch_stl,
						    neighbors_indices,
						    radius,
						    dim,
						    max_num);
	});

	const long long tsize = static_cast<long long>(neighbors_indices.size()/2);

	// from_blob only borrows the vector's storage; clone() makes an owning
	// copy so the result stays valid after neighbors_indices is destroyed.
	torch::Tensor borrowed = torch::from_blob(neighbors_indices.data(), {tsize, 2}, options);
	return borrowed.t().clone();
}