"vscode:/vscode.git/clone" did not exist on "80e6aff6e75195b96dfbf16093049eed89793b63"
radius_cpu.cpp 3.36 KB
Newer Older
Alexander Liao's avatar
Alexander Liao committed
1
#include "radius_cpu.h"
2
#include <algorithm>
Alexander Liao's avatar
Alexander Liao committed
3
#include "utils.h"
4
5
#include <cstdint>

Alexander Liao's avatar
Alexander Liao committed
6

7
// Fixed-radius neighbour search on the CPU, delegated to nanoflann_neighbors().
//
// query / support: CPU point tensors, one point per row, with the same number
//                  of columns and the same dtype (data_ptr<scalar_t>() is
//                  taken on both with query's scalar type).
// radius, max_num, n_threads: forwarded unchanged to nanoflann_neighbors().
//
// Returns an int64 tensor of shape [2, E]. nanoflann_neighbors() fills a flat
// index buffer that is read here as consecutive pairs; pair k becomes column k.
// A trailing unpaired element (odd buffer length) is ignored.
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
			 double radius, int64_t max_num, int64_t n_threads){
	CHECK_CPU(query);
	CHECK_CPU(support);
	TORCH_CHECK(query.dim() == 2 && support.dim() == 2,
		    "radius_cpu: query and support must be 2-D tensors");
	TORCH_CHECK(query.size(1) == support.size(1),
		    "radius_cpu: query and support must have the same number of columns");

	// data_ptr() exposes raw storage, which is only row-major for contiguous
	// tensors; a transposed or sliced input would otherwise be read wrongly.
	query = query.contiguous();
	support = support.contiguous();

	// Automatic storage: the previous heap-allocated vector was never freed.
	std::vector<size_t> neighbors_indices;

	AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
		auto data_q = query.data_ptr<scalar_t>();
		auto data_s = support.data_ptr<scalar_t>();
		std::vector<scalar_t> queries_stl(data_q, data_q + query.numel());
		std::vector<scalar_t> supports_stl(data_s, data_s + support.numel());

		const int dim = static_cast<int>(query.size(1));
		nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, &neighbors_indices,
					      radius, dim, max_num, n_threads);
	});

	// Copy element-wise into an owned int64 tensor instead of reinterpreting
	// the size_t buffer with from_blob (which silently assumed that size_t and
	// int64 share a representation).
	const int64_t num_pairs = static_cast<int64_t>(neighbors_indices.size() / 2);
	auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
	torch::Tensor out = torch::empty({num_pairs, 2}, options);
	std::copy(neighbors_indices.begin(),
		  neighbors_indices.begin() + 2 * num_pairs,
		  out.data_ptr<int64_t>());

	// Same [2, E] view (strides {1, 2}) the former `.t().clone()` produced.
	return out.t();
}

void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
Alexander Liao's avatar
Alexander Liao committed
44
45
46
47

	res.resize(batch[batch.size()-1]-batch[0]+1, 0);
	long ind = batch[0];
	long incr = 1;
48
	for(unsigned long i=1; i < batch.size(); i++){
Alexander Liao's avatar
Alexander Liao committed
49
50
51
52
53
54
55
56
57
58

		if(batch[i] == ind)
			incr++;
		else{
			res[ind-batch[0]] = incr;
			incr =1;
			ind = batch[i];
		}
	}
	res[ind-batch[0]] = incr;
59
60
61
62
63
64
}

// Batched fixed-radius neighbour search on the CPU, delegated to
// batch_nanoflann_neighbors().
//
// query / support:             CPU point tensors, one point per row, with the
//                              same number of columns and the same dtype.
// query_batch / support_batch: CPU int64 tensors of batch ids. Their first
//                              size(0) entries are reduced to per-batch counts
//                              by get_size_batch(), which requires them sorted
//                              in non-decreasing order.
// radius, max_num:             forwarded unchanged to batch_nanoflann_neighbors().
//
// Returns an int64 tensor of shape [2, E] built from the consecutive index
// pairs produced by batch_nanoflann_neighbors(); pair k becomes column k.
torch::Tensor batch_radius_cpu(torch::Tensor query,
			       torch::Tensor support,
			       torch::Tensor query_batch,
			       torch::Tensor support_batch,
			       double radius, int64_t max_num) {
	CHECK_CPU(query);
	CHECK_CPU(support);
	CHECK_CPU(query_batch);
	CHECK_CPU(support_batch);
	TORCH_CHECK(query.dim() == 2 && support.dim() == 2,
		    "batch_radius_cpu: query and support must be 2-D tensors");
	TORCH_CHECK(query.size(1) == support.size(1),
		    "batch_radius_cpu: query and support must have the same number of columns");

	// data_ptr() exposes raw storage, which is only element-ordered for
	// contiguous tensors; sliced/transposed inputs would be read wrongly.
	query = query.contiguous();
	support = support.contiguous();
	query_batch = query_batch.contiguous();
	support_batch = support_batch.contiguous();

	// NOTE(review): int64_t -> long narrows on LLP64 platforms (e.g. Windows);
	// kept because get_size_batch / batch_nanoflann_neighbors take std::vector<long>.
	auto data_qb = query_batch.data_ptr<int64_t>();
	auto data_sb = support_batch.data_ptr<int64_t>();
	std::vector<long> query_batch_stl(data_qb, data_qb + query_batch.size(0));
	std::vector<long> support_batch_stl(data_sb, data_sb + support_batch.size(0));

	std::vector<long> size_query_batch_stl;
	std::vector<long> size_support_batch_stl;
	get_size_batch(query_batch_stl, size_query_batch_stl);
	get_size_batch(support_batch_stl, size_support_batch_stl);

	// Automatic storage: the previous heap-allocated vector was never freed.
	std::vector<size_t> neighbors_indices;

	AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
		auto data_q = query.data_ptr<scalar_t>();
		auto data_s = support.data_ptr<scalar_t>();
		std::vector<scalar_t> queries_stl(data_q, data_q + query.numel());
		std::vector<scalar_t> supports_stl(data_s, data_s + support.numel());

		const int dim = static_cast<int>(query.size(1));
		batch_nanoflann_neighbors<scalar_t>(queries_stl,
						    supports_stl,
						    size_query_batch_stl,
						    size_support_batch_stl,
						    &neighbors_indices,
						    radius,
						    dim,
						    max_num);
	});

	// Copy element-wise into an owned int64 tensor instead of reinterpreting
	// the size_t buffer with from_blob.
	const int64_t num_pairs = static_cast<int64_t>(neighbors_indices.size() / 2);
	auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
	torch::Tensor out = torch::empty({num_pairs, 2}, options);
	std::copy(neighbors_indices.begin(),
		  neighbors_indices.begin() + 2 * num_pairs,
		  out.data_ptr<int64_t>());

	// Same [2, E] view (strides {1, 2}) the former `.t().clone()` produced.
	return out.t();
}