Commit d7f704c5 authored by Alexander Liao's avatar Alexander Liao
Browse files

fixed C++ warning and python flake8 style

parent 1111319d
...@@ -15,8 +15,8 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support, ...@@ -15,8 +15,8 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] { AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
auto data_q = query.DATA_PTR<scalar_t>(); auto data_q = query.data_ptr<scalar_t>();
auto data_s = support.DATA_PTR<scalar_t>(); auto data_s = support.data_ptr<scalar_t>();
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q, std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
data_q + query.size(0)*query.size(1)); data_q + query.size(0)*query.size(1));
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s, std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
...@@ -34,13 +34,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support, ...@@ -34,13 +34,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options); out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
out = out.t(); out = out.t();
auto result = torch::zeros_like(out); return out.clone();
auto index = torch::tensor({0,1});
result.index_copy_(0, index, out);
return result;
} }
...@@ -49,7 +43,7 @@ void get_size_batch(const vector<long>& batch, vector<long>& res){ ...@@ -49,7 +43,7 @@ void get_size_batch(const vector<long>& batch, vector<long>& res){
res.resize(batch[batch.size()-1]-batch[0]+1, 0); res.resize(batch[batch.size()-1]-batch[0]+1, 0);
long ind = batch[0]; long ind = batch[0];
long incr = 1; long incr = 1;
for(int i=1; i < batch.size(); i++){ for(unsigned long i=1; i < batch.size(); i++){
if(batch[i] == ind) if(batch[i] == ind)
incr++; incr++;
...@@ -81,8 +75,7 @@ torch::Tensor batch_radius_cpu(torch::Tensor query, ...@@ -81,8 +75,7 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU); auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
int max_count = 0; int max_count = 0;
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_search", [&] {
auto data_q = query.data_ptr<scalar_t>(); auto data_q = query.data_ptr<scalar_t>();
auto data_s = support.data_ptr<scalar_t>(); auto data_s = support.data_ptr<scalar_t>();
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q, std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
......
...@@ -127,20 +127,18 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries, ...@@ -127,20 +127,18 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
// Initiate variables // Initiate variables
// ****************** // ******************
// indices // indices
int i0 = 0; size_t i0 = 0;
// Square radius // Square radius
const scalar_t r2 = static_cast<scalar_t>(radius*radius); const scalar_t r2 = static_cast<scalar_t>(radius*radius);
// Counting vector // Counting vector
int max_count = 0; size_t max_count = 0;
float d2;
// batch index // batch index
long b = 0; size_t b = 0;
long sum_qb = 0; size_t sum_qb = 0;
long sum_sb = 0; size_t sum_sb = 0;
float eps = 0.000001; float eps = 0.000001;
// Nanoflann related variables // Nanoflann related variables
...@@ -173,16 +171,9 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries, ...@@ -173,16 +171,9 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
for (auto& p0 : query_pcd.pts){ for (auto& p0 : query_pcd.pts){
// Check if we changed batch // Check if we changed batch
scalar_t query_pt[dim]; scalar_t* query_pt = new scalar_t[dim];
std::copy(p0.begin(), p0.end(), query_pt); std::copy(p0.begin(), p0.end(), query_pt);
/*
std::cout << "\n ========== \n";
for(int i=0; i < dim; i++)
std::cout << query_pt[i] << '\n';
std::cout << "\n ========== \n";
*/
if (i0 == sum_qb + q_batches[b]){ if (i0 == sum_qb + q_batches[b]){
sum_qb += q_batches[b]; sum_qb += q_batches[b];
sum_sb += s_batches[b]; sum_sb += s_batches[b];
...@@ -218,7 +209,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries, ...@@ -218,7 +209,7 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
} }
// Reserve the memory // Reserve the memory
int size = 0; // total number of edges size_t size = 0; // total number of edges
for (auto& inds_dists : all_inds_dists){ for (auto& inds_dists : all_inds_dists){
if(inds_dists.size() <= max_count) if(inds_dists.size() <= max_count)
size += inds_dists.size(); size += inds_dists.size();
...@@ -230,14 +221,14 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries, ...@@ -230,14 +221,14 @@ int batch_nanoflann_neighbors (vector<scalar_t>& queries,
sum_sb = 0; sum_sb = 0;
sum_qb = 0; sum_qb = 0;
b = 0; b = 0;
int u = 0; size_t u = 0;
for (auto& inds_dists : all_inds_dists){ for (auto& inds_dists : all_inds_dists){
if (i0 == sum_qb + q_batches[b]){ if (i0 == sum_qb + q_batches[b]){
sum_qb += q_batches[b]; sum_qb += q_batches[b];
sum_sb += s_batches[b]; sum_sb += s_batches[b];
b++; b++;
} }
for (int j = 0; j < max_count; j++){ for (size_t j = 0; j < max_count; j++){
if (j < inds_dists.size()){ if (j < inds_dists.size()){
neighbors_indices[u] = inds_dists[j].first + sum_sb; neighbors_indices[u] = inds_dists[j].first + sum_sb;
neighbors_indices[u + 1] = i0; neighbors_indices[u + 1] = i0;
......
...@@ -107,24 +107,29 @@ def test_radius_graph_pointnet_small(dtype, device): ...@@ -107,24 +107,29 @@ def test_radius_graph_pointnet_small(dtype, device):
[0.3566, -0.7789, -0.3244], [0.3566, -0.7789, -0.3244],
[-0.2904, -0.1869, -0.3244], [-0.2904, -0.1869, -0.3244],
[-0.1890, -0.8423, 0.0057], [-0.1890, -0.8423, 0.0057],
[0.3787, 0.5441, -0.1557]],dtype, device) [0.3787, 0.5441, -0.1557]], dtype, device)
batch = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, batch = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], torch.long, device) 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3], torch.long, device)
row, col = radius_graph(x, r=0.2, flow='source_to_target', batch=batch) row, col = radius_graph(x, r=0.2, flow='source_to_target', batch=batch)
edges = set([(i,j) for (i,j) in zip(row.cpu().numpy(),col.cpu().numpy())]) edges = set([(i, j) for (i, j) in zip(row.cpu().numpy(),
col.cpu().numpy())])
truth_row = [10, 11, 7, 9, 9, 1, 9, 1, 6, 7, 0, 11, 0, 10, 15, 12, 20, 16, 34, 31, 44, 43, 42, 41] truth_row = [10, 11, 7, 9, 9, 1, 9, 1, 6, 7, 0, 11, 0, 10, 15, 12, 20, 16,
truth_col = [0, 0, 1, 1, 6, 7, 7, 9, 9, 9, 10, 10, 11, 11, 12, 15, 16, 20, 31, 34, 41, 42, 43, 44] 34, 31, 44, 43, 42, 41]
truth_col = [0, 0, 1, 1, 6, 7, 7, 9, 9, 9, 10, 10, 11, 11, 12, 15, 16, 20,
31, 34, 41, 42, 43, 44]
truth = set([(i, j) for (i, j) in zip(truth_row, truth_col)])
assert(truth == edges)
truth = set([(i,j) for (i,j) in zip(truth_row, truth_col)])
assert(truth==edges)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph_pointnet_medium(dtype, device): def test_radius_graph_pointnet_medium(dtype, device):
#print('medium test {}'.format(device))
x = tensor([[-4.4043e-02, -5.7983e-01, -9.7623e-02], x = tensor([[-4.4043e-02, -5.7983e-01, -9.7623e-02],
[3.0804e-01, -1.8622e-01, 1.9274e-01], [3.0804e-01, -1.8622e-01, 1.9274e-01],
[1.9475e-02, 1.4221e-01, -2.7513e-02], [1.9475e-02, 1.4221e-01, -2.7513e-02],
...@@ -382,84 +387,206 @@ def test_radius_graph_pointnet_medium(dtype, device): ...@@ -382,84 +387,206 @@ def test_radius_graph_pointnet_medium(dtype, device):
[-1.2279e-01, 1.7300e-01, 1.4925e-01], [-1.2279e-01, 1.7300e-01, 1.4925e-01],
[-4.0297e-01, -1.2408e-01, 1.1571e-02]], dtype, device) [-4.0297e-01, -1.2408e-01, 1.1571e-02]], dtype, device)
batch = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, batch = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], torch.long, device) 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3], torch.long, device)
row_02, col_02 = radius_graph(x, r=0.2, flow='source_to_target', batch=batch) row, col = radius_graph(x, r=0.2, flow='source_to_target', batch=batch)
edges_02 = set([(i,j) for (i,j) in zip(row_02.cpu().numpy(),col_02.cpu().numpy())])
truth_row_02 = [6, 27, 17, 31, 3, 23, 62, 2, 14, 23, 36, 38, 62, 15, 0, 11, 27, 29, 50, 49, 54, 56, 12, 61, 16, 21, 24, 39, 6, 27, 29, 34, 50, 9, 33, 43, 41, 57, 3, 17, 22, 23, 31, 36, 38, 5, 58, 10, 1, 14, 31, 36, 38, 43, 52, 53, 57, 10, 24, 39, 60, 14, 26, 31, 2, 3, 14, 36, 38, 62, 10, 21, 39, 60, 45, 22, 31, 0, 6, 11, 29, 50, 55, 6, 11, 27, 34, 50, 55, 59, 1, 14, 17, 22, 26, 36, 38, 12, 53, 11, 29, 59, 54, 3, 14, 17, 23, 32, 38, 40, 3, 14, 17, 23, 32, 36, 62, 10, 21, 24, 37, 44, 13, 12, 19, 52, 53, 40, 52, 25, 62, 63, 7, 54, 6, 11, 27, 29, 19, 43, 44, 53, 19, 33, 43, 52, 7, 35, 49, 27, 29, 8, 13, 20, 15, 29, 34, 21, 24, 9, 2, 3, 23, 38, 46, 48, 111, 106, 120, 115, 117, 119, 75, 91, 84, 87, 112, 114, 121, 125, 92, 97, 78, 82, 110, 104, 127, 68, 73, 82, 110, 80, 79, 73, 78, 96, 71, 86, 118, 121, 84, 90, 98, 121, 71, 112, 114, 125, 86, 98, 68, 72, 115, 122, 83, 72, 108, 86, 90, 113, 122, 102, 103, 101, 103, 116, 101, 102, 74, 127, 126, 127, 65, 118, 120, 114, 97, 73, 78, 64, 71, 87, 114, 121, 125, 100, 71, 87, 107, 112, 118, 121, 66, 95, 102, 66, 119, 84, 106, 114, 120, 121, 66, 117, 65, 106, 118, 121, 71, 84, 86, 112, 114, 118, 120, 95, 100, 71, 87, 112, 105, 127, 74, 104, 105, 126, 138, 143, 151, 164, 167, 186, 133, 141, 166, 176, 177, 179, 131, 142, 146, 155, 189, 158, 167, 144, 152, 185, 129, 143, 151, 164, 167, 182, 131, 191, 134, 155, 163, 129, 138, 164, 137, 178, 185, 161, 134, 149, 183, 169, 171, 146, 183, 129, 138, 164, 137, 175, 182, 160, 134, 142, 184, 177, 136, 181, 161, 162, 188, 154, 176, 145, 159, 162, 159, 161, 188, 142, 129, 138, 143, 151, 168, 132, 176, 177, 179, 129, 136, 138, 164, 190, 148, 148, 152, 185, 132, 160, 166, 132, 157, 166, 179, 144, 185, 132, 166, 177, 183, 158, 139, 153, 146, 149, 179, 156, 137, 144, 175, 178, 130, 159, 162, 135, 168, 141, 207, 228, 217, 226, 248, 211, 241, 246, 253, 208, 209, 216, 218, 250, 254, 245, 199, 206, 218, 231, 250, 205, 197, 200, 206, 250, 199, 206, 214, 239, 210, 219, 233, 244, 210, 235, 198, 197, 199, 200, 192, 212, 228, 237, 195, 216, 218, 222, 242, 254, 195, 250, 254, 202, 204, 235, 194, 241, 246, 247, 207, 240, 241, 251, 201, 239, 255, 230, 195, 208, 222, 254, 193, 195, 197, 208, 231, 242, 250, 254, 202, 220, 224, 231, 219, 252, 208, 216, 242, 243, 219, 231, 242, 193, 248, 234, 236, 253, 192, 207, 237, 215, 197, 218, 219, 224, 237, 203, 227, 204, 210, 227, 253, 207, 228, 232, 244, 201, 214, 213, 241, 246, 251, 253, 194, 211, 213, 240, 246, 251, 253, 208, 218, 222, 224, 254, 223, 203, 238, 196, 194, 211, 240, 241, 247, 211, 246, 193, 226, 195, 197, 199, 209, 218, 254, 213, 240, 241, 221, 194, 227, 236, 240, 241, 195, 208, 209, 216, 218, 242, 250, 214] edges = set([(i, j) for (i, j) in zip(row.cpu().numpy(),
truth_col_02 = [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 5, 6, 6, 6, 6, 6, 7, 7, 8, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 13, 13, 14, 14, 14, 14, 14, 14, 14, 15, 15, 16, 17, 17, 17, 17, 17, 19, 19, 19, 20, 21, 21, 21, 21, 22, 22, 22, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 25, 26, 26, 27, 27, 27, 27, 27, 27, 29, 29, 29, 29, 29, 29, 29, 31, 31, 31, 31, 31, 32, 32, 33, 33, 34, 34, 34, 35, 36, 36, 36, 36, 36, 36, 37, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 40, 40, 41, 43, 43, 43, 43, 44, 44, 45, 46, 48, 49, 49, 50, 50, 50, 50, 52, 52, 52, 52, 53, 53, 53, 53, 54, 54, 54, 55, 55, 56, 57, 57, 58, 59, 59, 60, 60, 61, 62, 62, 62, 62, 62, 63, 64, 65, 65, 66, 66, 66, 68, 68, 71, 71, 71, 71, 71, 71, 72, 72, 73, 73, 73, 74, 74, 75, 78, 78, 78, 79, 80, 82, 82, 83, 84, 84, 84, 84, 86, 86, 86, 86, 87, 87, 87, 87, 90, 90, 91, 92, 95, 95, 96, 97, 97, 98, 98, 100, 100, 101, 101, 102, 102, 102, 103, 103, 104, 104, 105, 105, 106, 106, 106, 107, 108, 110, 110, 111, 112, 112, 112, 112, 112, 113, 114, 114, 114, 114, 114, 114, 115, 115, 116, 117, 117, 118, 118, 118, 118, 118, 119, 119, 120, 120, 120, 120, 121, 121, 121, 121, 121, 121, 121, 122, 122, 125, 125, 125, 126, 126, 127, 127, 127, 127, 129, 129, 129, 129, 129, 130, 131, 131, 132, 132, 132, 132, 133, 134, 134, 134, 135, 136, 136, 137, 137, 137, 138, 138, 138, 138, 138, 139, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 146, 146, 146, 148, 148, 149, 149, 151, 151, 151, 152, 152, 153, 154, 155, 155, 156, 157, 158, 158, 159, 159, 159, 160, 160, 161, 161, 161, 162, 162, 162, 163, 164, 164, 164, 164, 164, 166, 166, 166, 166, 167, 167, 167, 168, 168, 169, 171, 175, 175, 176, 176, 176, 177, 177, 177, 177, 178, 178, 179, 179, 179, 179, 181, 182, 182, 183, 183, 183, 184, 185, 185, 185, 185, 186, 188, 188, 189, 190, 191, 192, 192, 193, 193, 193, 194, 194, 194, 194, 195, 195, 195, 195, 195, 195, 196, 197, 197, 197, 197, 197, 198, 199, 199, 199, 199, 200, 200, 201, 201, 202, 202, 203, 203, 204, 204, 205, 206, 206, 206, 207, 207, 207, 207, 208, 208, 208, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 211, 212, 213, 213, 213, 214, 214, 214, 215, 216, 216, 216, 216, 217, 218, 218, 218, 218, 218, 218, 218, 219, 219, 219, 219, 220, 221, 222, 222, 222, 223, 224, 224, 224, 226, 226, 227, 227, 227, 228, 228, 228, 230, 231, 231, 231, 231, 232, 233, 234, 235, 235, 236, 236, 237, 237, 237, 238, 239, 239, 240, 240, 240, 240, 240, 241, 241, 241, 241, 241, 241, 241, 242, 242, 242, 242, 242, 243, 244, 244, 245, 246, 246, 246, 246, 246, 247, 247, 248, 248, 250, 250, 250, 250, 250, 250, 251, 251, 251, 252, 253, 253, 253, 253, 253, 254, 254, 254, 254, 254, 254, 254, 255] col.cpu().numpy())])
truth_row = [6, 27, 17, 31, 3, 23, 62, 2, 14, 23, 36, 38, 62, 15, 0, 11,
27, 29, 50, 49, 54, 56, 12, 61, 16, 21, 24, 39, 6, 27, 29,
34, 50, 9, 33, 43, 41, 57, 3, 17, 22, 23, 31, 36, 38, 5, 58,
10, 1, 14, 31, 36, 38, 43, 52, 53, 57, 10, 24, 39, 60, 14,
26, 31, 2, 3, 14, 36, 38, 62, 10, 21, 39, 60, 45, 22, 31, 0,
6, 11, 29, 50, 55, 6, 11, 27, 34, 50, 55, 59, 1, 14, 17, 22,
26, 36, 38, 12, 53, 11, 29, 59, 54, 3, 14, 17, 23, 32, 38,
40, 3, 14, 17, 23, 32, 36, 62, 10, 21, 24, 37, 44, 13, 12,
19, 52, 53, 40, 52, 25, 62, 63, 7, 54, 6, 11, 27, 29, 19, 43,
44, 53, 19, 33, 43, 52, 7, 35, 49, 27, 29, 8, 13, 20, 15, 29,
34, 21, 24, 9, 2, 3, 23, 38, 46, 48, 111, 106, 120, 115, 117,
119, 75, 91, 84, 87, 112, 114, 121, 125, 92, 97, 78, 82, 110,
104, 127, 68, 73, 82, 110, 80, 79, 73, 78, 96, 71, 86, 118,
121, 84, 90, 98, 121, 71, 112, 114, 125, 86, 98, 68, 72, 115,
122, 83, 72, 108, 86, 90, 113, 122, 102, 103, 101, 103, 116,
101, 102, 74, 127, 126, 127, 65, 118, 120, 114, 97, 73, 78,
64, 71, 87, 114, 121, 125, 100, 71, 87, 107, 112, 118, 121,
66, 95, 102, 66, 119, 84, 106, 114, 120, 121, 66, 117, 65,
106, 118, 121, 71, 84, 86, 112, 114, 118, 120, 95, 100, 71,
87, 112, 105, 127, 74, 104, 105, 126, 138, 143, 151, 164, 167,
186, 133, 141, 166, 176, 177, 179, 131, 142, 146, 155, 189,
158, 167, 144, 152, 185, 129, 143, 151, 164, 167, 182, 131,
191, 134, 155, 163, 129, 138, 164, 137, 178, 185, 161, 134,
149, 183, 169, 171, 146, 183, 129, 138, 164, 137, 175, 182,
160, 134, 142, 184, 177, 136, 181, 161, 162, 188, 154, 176,
145, 159, 162, 159, 161, 188, 142, 129, 138, 143, 151, 168,
132, 176, 177, 179, 129, 136, 138, 164, 190, 148, 148, 152,
185, 132, 160, 166, 132, 157, 166, 179, 144, 185, 132, 166,
177, 183, 158, 139, 153, 146, 149, 179, 156, 137, 144, 175,
178, 130, 159, 162, 135, 168, 141, 207, 228, 217, 226, 248,
211, 241, 246, 253, 208, 209, 216, 218, 250, 254, 245, 199,
206, 218, 231, 250, 205, 197, 200, 206, 250, 199, 206, 214,
239, 210, 219, 233, 244, 210, 235, 198, 197, 199, 200, 192,
212, 228, 237, 195, 216, 218, 222, 242, 254, 195, 250, 254,
202, 204, 235, 194, 241, 246, 247, 207, 240, 241, 251, 201,
239, 255, 230, 195, 208, 222, 254, 193, 195, 197, 208, 231,
242, 250, 254, 202, 220, 224, 231, 219, 252, 208, 216, 242,
243, 219, 231, 242, 193, 248, 234, 236, 253, 192, 207, 237,
215, 197, 218, 219, 224, 237, 203, 227, 204, 210, 227, 253,
207, 228, 232, 244, 201, 214, 213, 241, 246, 251, 253, 194,
211, 213, 240, 246, 251, 253, 208, 218, 222, 224, 254, 223,
203, 238, 196, 194, 211, 240, 241, 247, 211, 246, 193, 226,
195, 197, 199, 209, 218, 254, 213, 240, 241, 221, 194, 227,
236, 240, 241, 195, 208, 209, 216, 218, 242, 250, 214]
truth_col = [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 5, 6, 6, 6, 6, 6, 7,
7, 8, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12,
13, 13, 14, 14, 14, 14, 14, 14, 14, 15, 15, 16, 17, 17, 17,
17, 17, 19, 19, 19, 20, 21, 21, 21, 21, 22, 22, 22, 23, 23,
23, 23, 23, 23, 24, 24, 24, 24, 25, 26, 26, 27, 27, 27,
27, 27, 27, 29, 29, 29, 29, 29, 29, 29, 31, 31, 31, 31,
31, 32, 32, 33, 33, 34, 34, 34, 35, 36, 36, 36, 36, 36, 36,
37, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 40, 40, 41, 43,
43, 43, 43, 44, 44, 45, 46, 48, 49, 49, 50, 50, 50, 50, 52,
52, 52, 52, 53, 53, 53, 53, 54, 54, 54, 55, 55, 56, 57, 57,
58, 59, 59, 60, 60, 61, 62, 62, 62, 62, 62, 63, 64, 65, 65,
66, 66, 66, 68, 68, 71, 71, 71, 71, 71, 71, 72, 72, 73, 73,
73, 74, 74, 75, 78, 78, 78, 79, 80, 82, 82, 83, 84, 84, 84,
84, 86, 86, 86, 86, 87, 87, 87, 87, 90, 90, 91, 92, 95, 95,
96, 97, 97, 98, 98, 100, 100, 101, 101, 102, 102, 102, 103,
103, 104, 104, 105, 105, 106, 106, 106, 107, 108, 110, 110,
111, 112, 112, 112, 112, 112, 113, 114, 114, 114, 114, 114,
114, 115, 115, 116, 117, 117, 118, 118, 118, 118, 118, 119,
119, 120, 120, 120, 120, 121, 121, 121, 121, 121, 121, 121,
122, 122, 125, 125, 125, 126, 126, 127, 127, 127, 127, 129,
129, 129, 129, 129, 130, 131, 131, 132, 132, 132, 132, 133,
134, 134, 134, 135, 136, 136, 137, 137, 137, 138, 138, 138,
138, 138, 139, 141, 141, 142, 142, 142, 143, 143, 143, 144,
144, 144, 145, 146, 146, 146, 148, 148, 149, 149, 151, 151,
151, 152, 152, 153, 154, 155, 155, 156, 157, 158, 158, 159,
159, 159, 160, 160, 161, 161, 161, 162, 162, 162, 163, 164,
164, 164, 164, 164, 166, 166, 166, 166, 167, 167, 167, 168,
168, 169, 171, 175, 175, 176, 176, 176, 177, 177, 177, 177,
178, 178, 179, 179, 179, 179, 181, 182, 182, 183, 183, 183,
184, 185, 185, 185, 185, 186, 188, 188, 189, 190, 191, 192,
192, 193, 193, 193, 194, 194, 194, 194, 195, 195, 195, 195,
195, 195, 196, 197, 197, 197, 197, 197, 198, 199, 199, 199,
199, 200, 200, 201, 201, 202, 202, 203, 203, 204, 204, 205,
206, 206, 206, 207, 207, 207, 207, 208, 208, 208, 208, 208,
208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 211, 212,
213, 213, 213, 214, 214, 214, 215, 216, 216, 216, 216, 217,
218, 218, 218, 218, 218, 218, 218, 219, 219, 219, 219, 220,
221, 222, 222, 222, 223, 224, 224, 224, 226, 226, 227, 227,
227, 228, 228, 228, 230, 231, 231, 231, 231, 232, 233, 234,
235, 235, 236, 236, 237, 237, 237, 238, 239, 239, 240, 240,
240, 240, 240, 241, 241, 241, 241, 241, 241, 241, 242, 242,
242, 242, 242, 243, 244, 244, 245, 246, 246, 246, 246, 246,
247, 247, 248, 248, 250, 250, 250, 250, 250, 250, 251, 251,
251, 252, 253, 253, 253, 253, 253, 254, 254, 254, 254, 254,
254, 254, 255]
truth = set([(i, j) for (i, j) in zip(truth_row, truth_col)])
assert(truth == edges)
truth_02 = set([(i,j) for (i,j) in zip(truth_row_02, truth_col_02)])
assert(truth_02==edges_02)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_radius_graph_ndim(dtype, device): def test_radius_graph_ndim(dtype, device):
x = tensor([[-0.9750, -0.7160, 0.7150, -0.1510, -0.3660, 0.6140, -1.0340, 2.4950], x = tensor([[-0.9750, -0.7160, 0.7150, -0.1510, -0.3660, 0.6140, -1.0340,
[ 0.8540, 0.1110, 1.0520, -1.3900, 0.7570, -0.6300, -0.9550, -0.9350], 2.4950],
[ 0.3710, 0.4610, 0.1620, 1.1370, -1.5830, 0.4100, -0.5710, -0.7760], [0.8540, 0.1110, 1.0520, -1.3900, 0.7570, -0.6300, -0.9550,
[ 0.4200, 0.1240, -1.2870, -0.2300, -1.7480, 0.5890, 0.5710, 0.1670], -0.9350],
[-0.6060, 0.8080, -2.2560, 0.4480, -0.8910, 0.2360, -0.0060, -0.6510], [0.3710, 0.4610, 0.1620, 1.1370, -1.5830, 0.4100, -0.5710,
[-0.6960, 0.7190, -0.7330, 0.4660, 0.4400, -0.0490, -1.1350, -0.5990], -0.7760],
[-0.0080, -0.4770, 0.0980, 1.2000, -0.6110, -0.7410, 0.7410, -0.2800], [0.4200, 0.1240, -1.2870, -0.2300, -1.7480, 0.5890, 0.5710,
[-2.5230, -0.8470, -0.8670, 0.4820, -0.9510, -0.9460, 0.3390, -1.6740], 0.1670],
[ 1.0770, -1.4480, 1.8110, 0.0900, 0.7980, 0.4070, 1.9570, -0.2010], [-0.6060, 0.8080, -2.2560, 0.4480, -0.8910, 0.2360, -0.0060,
[ 1.0890, -0.2150, -0.4440, 0.4370, 1.1180, -0.4280, -2.3860, 0.5860], -0.6510],
[ 0.1000, -0.2590, -2.1420, 0.9260, 0.7290, -0.1170, 0.9370, -0.0470], [-0.6960, 0.7190, -0.7330, 0.4660, 0.4400, -0.0490, -1.1350,
[-0.3870, -1.7310, -0.6020, -0.1070, 1.7890, 0.5200, 1.2620, 0.6130], -0.5990],
[-0.0740, 0.5270, 0.4090, -0.9120, -0.1690, 1.4970, -2.4540, -1.0430], [-0.0080, -0.4770, 0.0980, 1.2000, -0.6110, -0.7410, 0.7410,
[-0.9750, -1.3510, 0.0730, 0.1450, -0.9910, -1.8840, 0.1010, 0.4620], -0.2800],
[ 0.6950, 0.3560, 0.2850, -0.1050, -1.8770, 1.4910, 2.0260, -0.8170], [-2.5230, -0.8470, -0.8670, 0.4820, -0.9510, -0.9460, 0.3390,
[-1.3480, 0.1100, 0.8460, -0.1050, -1.9670, -0.0930, 0.2820, 1.7150], -1.6740],
[-0.0340, -0.7420, 0.5450, 1.8170, -0.6030, -0.0990, 0.1650, -0.0450], [1.0770, -1.4480, 1.8110, 0.0900, 0.7980, 0.4070, 1.9570,
[ 0.4490, 1.6170, -1.6880, -0.6180, -0.8350, 1.0560, -0.3860, 0.8380], -0.2010],
[ 0.9530, -0.1970, -0.7030, 1.7750, -1.6860, -1.4290, 0.6280, 0.2730], [1.0890, -0.2150, -0.4440, 0.4370, 1.1180, -0.4280, -2.3860,
[ 0.6630, 1.0780, 1.5650, -0.5490, -0.5530, -0.8070, 0.4100, -2.4380], 0.5860],
[ 0.6350, 0.0490, 0.1990, -1.2340, 0.7630, 0.2670, 1.5810, -0.4250], [0.1000, -0.2590, -2.1420, 0.9260, 0.7290, -0.1170, 0.9370,
[ 1.6700, 0.4440, -2.5800, 0.5020, 0.3520, -0.9110, -1.9960, -0.0000], -0.0470],
[ 0.1970, 0.2390, 2.2290, -0.0910, 1.2710, 0.0280, -0.5530, -1.4650], [-0.3870, -1.7310, -0.6020, -0.1070, 1.7890, 0.5200, 1.2620,
[ 0.1270, 2.5150, -0.3450, -0.8340, 1.0130, -1.3680, -0.1990, -0.5480], 0.6130],
[-1.0470, 0.0200, 2.2200, 1.7030, 0.5460, 0.4350, -1.8560, -0.9750], [-0.0740, 0.5270, 0.4090, -0.9120, -0.1690, 1.4970, -2.4540,
[ 0.7010, -0.7260, -0.2380, 0.6120, 1.1150, -1.2530, -0.2140, 1.0100], -1.0430],
[-0.2590, -0.2690, 0.1200, 1.0380, -0.8370, -0.0070, -0.0800, 0.2130], [-0.9750, -1.3510, 0.0730, 0.1450, -0.9910, -1.8840, 0.1010,
[-0.5460, 0.4000, 0.2040, -0.8370, 1.7400, 1.0940, 0.0930, -0.3370], 0.4620],
[-1.0230, 1.5400, 0.9760, -1.5210, 1.0170, -1.3290, 0.7690, 0.6260], [0.6950, 0.3560, 0.2850, -0.1050, -1.8770, 1.4910, 2.0260,
[-0.7560, 0.1360, -0.2640, -0.6130, -0.2830, 0.6830, -0.8700, -0.5610], -0.8170],
[ 0.4060, 0.3830, 2.4530, -0.4910, -1.3110, -0.0980, -0.0630, 0.3200], [-1.3480, 0.1100, 0.8460, -0.1050, -1.9670, -0.0930, 0.2820,
[ 0.1450, 0.5810, -0.7310, 0.8190, -1.3600, -0.6780, -0.3360, -0.2570]], 1.7150],
dtype, device) [-0.0340, -0.7420, 0.5450, 1.8170, -0.6030, -0.0990, 0.1650,
-0.0450],
batch = tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, [0.4490, 1.6170, -1.6880, -0.6180, -0.8350, 1.0560, -0.3860,
6, 6, 6, 6, 7, 7, 8, 9], torch.long, device) 0.8380],
[0.9530, -0.1970, -0.7030, 1.7750, -1.6860, -1.4290, 0.6280,
row_02, col_02 = radius_graph(x, r=4.4, flow='source_to_target', batch=batch) 0.2730],
[0.6630, 1.0780, 1.5650, -0.5490, -0.5530, -0.8070, 0.4100,
edges_02 = set([(i,j) for (i,j) in zip(row_02.cpu().numpy(),col_02.cpu().numpy())]) -2.4380],
[0.6350, 0.0490, 0.1990, -1.2340, 0.7630, 0.2670, 1.5810,
truth_row_02 = [ 2, 3, 2, 3, 0, 1, 3, 4, 0, 1, 2, 4, 2, 3, 6, 7, 9, 10, -0.4250],
5, 7, 8, 9, 10, 5, 6, 10, 6, 5, 6, 10, 5, 6, 7, 9, 13, 11, [1.6700, 0.4440, -2.5800, 0.5020, 0.3520, -0.9110, -1.9960,
16, 17, 18, 15, 17, 18, 15, 16, 18, 15, 16, 17, 20, 22, 19, 22, 19, 20, -0.0000],
25, 26, 27, 26, 27, 23, 26, 27, 23, 24, 25, 27, 23, 24, 25, 26, 29, 28] [0.1970, 0.2390, 2.2290, -0.0910, 1.2710, 0.0280, -0.5530,
truth_col_02 = [ 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5, -1.4650],
6, 6, 6, 6, 6, 7, 7, 7, 8, 9, 9, 9, 10, 10, 10, 10, 11, 13, [0.1270, 2.5150, -0.3450, -0.8340, 1.0130, -1.3680, -0.1990,
15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 20, 20, 22, 22, -0.5480],
23, 23, 23, 24, 24, 25, 25, 25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 29] [-1.0470, 0.0200, 2.2200, 1.7030, 0.5460, 0.4350, -1.8560,
-0.9750],
truth_02 = set([(i,j) for (i,j) in zip(truth_row_02, truth_col_02)]) [0.7010, -0.7260, -0.2380, 0.6120, 1.1150, -1.2530, -0.2140,
1.0100],
#print(edges_02.symmetric_difference(truth_02)) [-0.2590, -0.2690, 0.1200, 1.0380, -0.8370, -0.0070, -0.0800,
#print('===========') 0.2130],
#print(edges_02) [-0.5460, 0.4000, 0.2040, -0.8370, 1.7400, 1.0940, 0.0930,
#print(truth_02) -0.3370],
assert(truth_02==edges_02) [-1.0230, 1.5400, 0.9760, -1.5210, 1.0170, -1.3290, 0.7690,
\ No newline at end of file 0.6260],
[-0.7560, 0.1360, -0.2640, -0.6130, -0.2830, 0.6830, -0.8700,
-0.5610],
[0.4060, 0.3830, 2.4530, -0.4910, -1.3110, -0.0980, -0.0630,
0.3200],
[0.1450, 0.5810, -0.7310, 0.8190, -1.3600, -0.6780, -0.3360,
-0.2570]], dtype, device)
batch = tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5,
5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 8, 9], torch.long, device)
row, col = radius_graph(x, r=4.4, flow='source_to_target', batch=batch)
edges = set([(i, j) for (i, j) in zip(row.cpu().numpy(),
col.cpu().numpy())])
truth_row = [2, 3, 2, 3, 0, 1, 3, 4, 0, 1, 2, 4, 2, 3, 6, 7, 9, 10, 5, 7,
8, 9, 10, 5, 6, 10, 6, 5, 6, 10, 5, 6, 7, 9, 13, 11, 16, 17,
18, 15, 17, 18, 15, 16, 18, 15, 16, 17, 20, 22, 19, 22, 19,
20, 25, 26, 27, 26, 27, 23, 26, 27, 23, 24, 25, 27, 23, 24,
25, 26, 29, 28]
truth_col = [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5, 6, 6,
6, 6, 6, 7, 7, 7, 8, 9, 9, 9, 10, 10, 10, 10, 11, 13, 15, 15,
15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 20, 20, 22,
22, 23, 23, 23, 24, 24, 25, 25, 25, 26, 26, 26, 26, 27, 27,
27, 27, 28, 29]
truth = set([(i, j) for (i, j) in zip(truth_row, truth_col)])
assert(truth == edges)
from typing import Optional from typing import Optional
import torch import torch
import scipy
def sample(col, count):
if col.size(0) > count:
col = col[torch.randperm(col.size(0))][:count]
return col
def radius(x: torch.Tensor, y: torch.Tensor, r: float, def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_x: Optional[torch.Tensor] = None, batch_x: Optional[torch.Tensor] = None,
...@@ -55,7 +50,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -55,7 +50,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
ptr_x = deg.new_zeros(batch_size + 1) ptr_x = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_x[1:]) torch.cumsum(deg, 0, out=ptr_x[1:])
else: else:
ptr_x = None#torch.tensor([0, x.size(0)], device=x.device) ptr_x = None
if batch_y is not None: if batch_y is not None:
assert y.size(0) == batch_y.numel() assert y.size(0) == batch_y.numel()
...@@ -66,19 +61,11 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -66,19 +61,11 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
ptr_y = deg.new_zeros(batch_size + 1) ptr_y = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_y[1:]) torch.cumsum(deg, 0, out=ptr_y[1:])
else: else:
ptr_y = None#torch.tensor([0, y.size(0)], device=y.device) ptr_y = None
result = torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r, result = torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
max_num_neighbors) max_num_neighbors)
else: else:
#if batch_x is None:
# batch_x = x.new_zeros(x.size(0), dtype=torch.long)
#if batch_y is None:
# batch_y = y.new_zeros(y.size(0), dtype=torch.long)
#batch_x = batch_x.to(x.dtype)
#batch_y = batch_y.to(y.dtype)
assert x.dim() == 2 assert x.dim() == 2
if batch_x is not None: if batch_x is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment