"launcher/vscode:/vscode.git/clone" did not exist on "2335459556d50d5dd5e39dc41e0ee16366d655dc"
knn_cpu.cpp 3.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include "radius_cpu.h"
#include <algorithm>
#include "utils.h"
#include <cstdint>


/**
 * k-nearest-neighbour search on CPU tensors.
 *
 * Copies `query` and `support` into flat std::vectors, hands them to
 * nanoflann_neighbors, and repackages the flat index buffer that call fills
 * into a (2, N) int64 tensor.
 *
 * @param support    2-D CPU tensor, read as support.size(0) * support.size(1)
 *                   values of the same scalar type as `query`.
 * @param query      2-D CPU tensor; its scalar type selects the dispatched
 *                   instantiation and query.size(1) is passed on as `dim`.
 * @param k          Forwarded unchanged to nanoflann_neighbors.
 * @param n_threads  Forwarded unchanged to nanoflann_neighbors.
 * @return A (2, N) kLong tensor that owns its memory: the index buffer viewed
 *         as N pairs, transposed, with the two rows swapped.
 *         NOTE(review): which row holds query vs. support indices depends on
 *         nanoflann_neighbors, which is not visible here - confirm.
 */
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query, 
			 int64_t k, int64_t n_threads){

	CHECK_CPU(query);
	CHECK_CPU(support);

	// The index buffer is reinterpreted as int64 below, so the widths must match.
	static_assert(sizeof(size_t) == sizeof(int64_t),
		      "neighbor indices are viewed as a kLong tensor");

	// The raw data pointers are read as dense row-major buffers.
	query = query.contiguous();
	support = support.contiguous();

	// Automatic storage so the buffer is released on every exit path; the
	// original heap allocation was never freed. A named pointer is kept
	// because the callee receives the buffer through a pointer argument.
	std::vector<size_t> neighbors_storage;
	std::vector<size_t>* neighbors_indices = &neighbors_storage;

	AT_DISPATCH_ALL_TYPES(query.scalar_type(), "knn_cpu", [&] {

	const auto data_q = query.data_ptr<scalar_t>();
	const auto data_s = support.data_ptr<scalar_t>();
	std::vector<scalar_t> queries_stl(data_q,
					  data_q + query.size(0)*query.size(1));
	std::vector<scalar_t> supports_stl(data_s,
					   data_s + support.size(0)*support.size(1));

	const int dim = static_cast<int>(query.size(1));

	// The return value was stored but never used by the original code.
	nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices, 0, dim, 0, n_threads, k, 0);

	});

	const auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
	const int64_t n_pairs = static_cast<int64_t>(neighbors_indices->size()/2);

	// Non-owning view over the buffer; it must not outlive this function.
	torch::Tensor out = torch::from_blob(neighbors_indices->data(), {n_pairs, 2}, options).t();

	// Copy into freshly allocated memory while swapping the two rows, so the
	// returned tensor does not reference the local buffer.
	auto result = torch::zeros_like(out);
	auto index = torch::tensor({1,0});
	result.index_copy_(0, index, out);

	return result;
}


/**
 * Counts how many entries of `batch` carry each batch id.
 *
 * `batch` is expected to be sorted in non-decreasing order, so that its first
 * and last elements are the smallest and largest ids. On return `res` has
 * batch.back() - batch.front() + 1 entries and res[id - batch.front()] is the
 * number of occurrences of `id`; ids inside the range that never occur get 0.
 *
 * @param batch  Sorted batch id of every element; may be empty.
 * @param res    Output; any previous content is discarded. Left empty when
 *               `batch` is empty.
 */
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){

	res.clear();
	// Guard: front()/back() on an empty vector is undefined behaviour.
	if(batch.empty())
		return;

	const long first = batch.front();
	// assign (not resize) so that a reused `res` cannot keep stale counts.
	res.assign(static_cast<std::vector<long>::size_type>(batch.back() - first + 1), 0);

	for(const long id : batch)
		res[id - first]++;
}

/**
 * Batched k-nearest-neighbour search on CPU tensors.
 *
 * Converts the two batch-id tensors into per-batch element counts with
 * get_size_batch, copies `query` and `support` into flat std::vectors, calls
 * batch_nanoflann_neighbors, and repackages the flat index buffer that call
 * fills into a (2, N) int64 tensor.
 *
 * @param support        2-D CPU tensor, read as support.size(0) *
 *                       support.size(1) values of the same scalar type as
 *                       `query`.
 * @param query          2-D CPU tensor; its scalar type selects the dispatched
 *                       instantiation and query.size(1) is passed on as `dim`.
 * @param support_batch  int64 CPU tensor of batch ids, checked to be sorted.
 * @param query_batch    int64 CPU tensor of batch ids, checked to be sorted.
 * @param k              Forwarded unchanged to batch_nanoflann_neighbors.
 * @return A (2, N) kLong tensor that owns its memory: the index buffer viewed
 *         as N pairs, transposed, with the two rows swapped. An empty (2, 0)
 *         tensor is returned when either batch tensor has no elements.
 *         NOTE(review): which row holds query vs. support indices depends on
 *         batch_nanoflann_neighbors, which is not visible here - confirm.
 */
torch::Tensor batch_knn_cpu(torch::Tensor support,
			       torch::Tensor query,
			       torch::Tensor support_batch,
			       torch::Tensor query_batch,
			       int64_t k) {

	CHECK_CPU(query);
	CHECK_CPU(support);
	CHECK_CPU(query_batch);
	CHECK_CPU(support_batch);

	// The index buffer is reinterpreted as int64 below, so the widths must match.
	static_assert(sizeof(size_t) == sizeof(int64_t),
		      "neighbor indices are viewed as a kLong tensor");

	const auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);

	// With no batch ids there is nothing to search, and the per-batch size
	// computation below reads the first and last id.
	if(query_batch.numel() == 0 || support_batch.numel() == 0)
		return torch::empty({2, 0}, options);

	// The raw data pointers are read as dense buffers.
	query = query.contiguous();
	support = support.contiguous();
	query_batch = query_batch.contiguous();
	support_batch = support_batch.contiguous();

	const auto data_qb = query_batch.data_ptr<int64_t>();
	const auto data_sb = support_batch.data_ptr<int64_t>();

	std::vector<long> query_batch_stl(data_qb, data_qb+query_batch.size(0));
	std::vector<long> size_query_batch_stl;
	CHECK_INPUT(std::is_sorted(query_batch_stl.begin(),query_batch_stl.end()));
	get_size_batch(query_batch_stl, size_query_batch_stl);

	std::vector<long> support_batch_stl(data_sb, data_sb+support_batch.size(0));
	std::vector<long> size_support_batch_stl;
	CHECK_INPUT(std::is_sorted(support_batch_stl.begin(),support_batch_stl.end()));
	get_size_batch(support_batch_stl, size_support_batch_stl);

	// Automatic storage so the buffer is released on every exit path; the
	// original heap allocation was never freed. A named pointer is kept
	// because the callee receives the buffer through a pointer argument.
	std::vector<size_t> neighbors_storage;
	std::vector<size_t>* neighbors_indices = &neighbors_storage;

	AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_knn_cpu", [&] {
	const auto data_q = query.data_ptr<scalar_t>();
	const auto data_s = support.data_ptr<scalar_t>();
	std::vector<scalar_t> queries_stl(data_q,
					  data_q + query.size(0)*query.size(1));
	std::vector<scalar_t> supports_stl(data_s,
					   data_s + support.size(0)*support.size(1));

	const int dim = static_cast<int>(query.size(1));

	// The return value was stored but never used by the original code.
	batch_nanoflann_neighbors<scalar_t>(queries_stl,
					    supports_stl,
					    size_query_batch_stl,
					    size_support_batch_stl,
					    neighbors_indices,
					    0,
					    dim,
					    0,
					    k, 0);
	});

	const int64_t n_pairs = static_cast<int64_t>(neighbors_indices->size()/2);

	// Non-owning view over the buffer; it must not outlive this function.
	torch::Tensor out = torch::from_blob(neighbors_indices->data(), {n_pairs, 2}, options).t();

	// Copy into freshly allocated memory while swapping the two rows, so the
	// returned tensor does not reference the local buffer.
	auto result = torch::zeros_like(out);
	auto index = torch::tensor({1,0});
	result.index_copy_(0, index, out);

	return result;
}