OpenDAS / torch-cluster / Commits

Commit 4e2e69be
authored Jun 22, 2020 by rusty1s
parent 1bbf8bdc

major clean up
Showing 15 changed files with 241 additions and 1024 deletions.
csrc/cpu/knn_cpu.cpp      +52  -128
csrc/cpu/knn_cpu.h         +4   -10
csrc/cpu/radius_cpu.cpp   +52  -127
csrc/cpu/radius_cpu.h      +4   -10
csrc/cuda/knn_cuda.cu     +21    -7
csrc/cuda/knn_cuda.h       +4    -2
csrc/cuda/radius_cuda.cu  +23    -9
csrc/cuda/radius_cuda.h    +3    -2
csrc/knn.cpp               +9   -39
csrc/radius.cpp            +8   -38
setup.py                   +2    -2
test/test_knn.py          +17   -44
test/test_radius.py       +18  -593
torch_cluster/knn.py      +12    -5
torch_cluster/radius.py   +12    -8

The per-file diffs below are rendered inline, so removed (old) and added (new) lines of each hunk appear next to each other.
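The change running through all of these files is that the C++/CUDA entry points now take optional CSR-style `ptr_x`/`ptr_y` offset tensors instead of per-point batch vectors, with `num_workers` replacing `n_threads`. As a rough orientation (not part of the commit), the sketch below shows one way a sorted per-point `batch` vector maps onto such offsets; the `batch_to_ptr` helper and its `scatter_add_` degree computation are illustrative assumptions, while the `deg.new_zeros` / `torch.cumsum` pattern mirrors what `torch_cluster/knn.py` does further down in this diff.

    import torch

    def batch_to_ptr(batch: torch.Tensor) -> torch.Tensor:
        # `batch` is assumed sorted, one entry per point, e.g. [0, 0, 0, 0, 1, 1, 1, 1].
        batch_size = int(batch.max()) + 1
        deg = batch.new_zeros(batch_size)
        deg.scatter_add_(0, batch, torch.ones_like(batch))  # points per example
        ptr = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr[1:])                   # exclusive prefix sum
        return ptr

    print(batch_to_ptr(torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])))  # tensor([0, 4, 8])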
csrc/cpu/knn_cpu.cpp
#include "radius_cpu.h"
#include <algorithm>
#include "utils.h"
#include <cstdint>
torch
::
Tensor
knn_cpu
(
torch
::
Tensor
support
,
torch
::
Tensor
query
,
int64_t
k
,
int64_t
n_threads
){
CHECK_CPU
(
query
);
CHECK_CPU
(
support
);
torch
::
Tensor
out
;
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
int
max_count
=
0
;
AT_DISPATCH_ALL_TYPES
(
query
.
scalar_type
(),
"radius_cpu"
,
[
&
]
{
auto
data_q
=
query
.
data_ptr
<
scalar_t
>
();
auto
data_s
=
support
.
data_ptr
<
scalar_t
>
();
std
::
vector
<
scalar_t
>
queries_stl
=
std
::
vector
<
scalar_t
>
(
data_q
,
data_q
+
query
.
size
(
0
)
*
query
.
size
(
1
));
std
::
vector
<
scalar_t
>
supports_stl
=
std
::
vector
<
scalar_t
>
(
data_s
,
data_s
+
support
.
size
(
0
)
*
support
.
size
(
1
));
int
dim
=
torch
::
size
(
query
,
1
);
max_count
=
nanoflann_neighbors
<
scalar_t
>
(
queries_stl
,
supports_stl
,
neighbors_indices
,
0
,
dim
,
0
,
n_threads
,
k
,
0
);
});
size_t
*
neighbors_indices_ptr
=
neighbors_indices
->
data
();
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
->
size
()
/
2
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
out
.
t
();
auto
result
=
torch
::
zeros_like
(
out
);
auto
index
=
torch
::
tensor
({
1
,
0
});
result
.
index_copy_
(
0
,
index
,
out
);
return
result
;
}
#include "knn_cpu.h"
void
get_size_batch
(
const
std
::
vector
<
long
>&
batch
,
std
::
vector
<
long
>&
res
){
res
.
resize
(
batch
[
batch
.
size
()
-
1
]
-
batch
[
0
]
+
1
,
0
);
long
ind
=
batch
[
0
];
long
incr
=
1
;
for
(
unsigned
long
i
=
1
;
i
<
batch
.
size
();
i
++
){
if
(
batch
[
i
]
==
ind
)
incr
++
;
else
{
res
[
ind
-
batch
[
0
]]
=
incr
;
incr
=
1
;
ind
=
batch
[
i
];
}
}
res
[
ind
-
batch
[
0
]]
=
incr
;
#include "utils.h"
#include "utils/neighbors.cpp"
torch
::
Tensor
knn_cpu
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
int64_t
k
,
int64_t
num_workers
)
{
CHECK_CPU
(
x
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_CPU
(
y
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
if
(
ptr_x
.
has_value
())
{
CHECK_CPU
(
ptr_x
.
value
());
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
}
if
(
ptr_y
.
has_value
())
{
CHECK_CPU
(
ptr_y
.
value
());
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
std
::
vector
<
size_t
>
*
out_vec
=
new
std
::
vector
<
size_t
>
();
AT_DISPATCH_ALL_TYPES
(
x
.
scalar_type
(),
"radius_cpu"
,
[
&
]
{
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
y_data
=
y
.
data_ptr
<
scalar_t
>
();
auto
x_vec
=
std
::
vector
<
scalar_t
>
(
x_data
,
x_data
+
x
.
numel
());
auto
y_vec
=
std
::
vector
<
scalar_t
>
(
y_data
,
y_data
+
y
.
numel
());
if
(
!
ptr_x
.
has_value
())
{
nanoflann_neighbors
<
scalar_t
>
(
y_vec
,
x_vec
,
out_vec
,
0
,
x
.
size
(
-
1
),
0
,
num_workers
,
k
,
0
);
}
else
{
auto
sx
=
(
ptr_x
.
value
().
narrow
(
0
,
1
,
ptr_x
.
value
().
numel
()
-
1
)
-
ptr_x
.
value
().
narrow
(
0
,
0
,
ptr_x
.
value
().
numel
()
-
1
));
auto
sy
=
(
ptr_y
.
value
().
narrow
(
0
,
1
,
ptr_y
.
value
().
numel
()
-
1
)
-
ptr_y
.
value
().
narrow
(
0
,
0
,
ptr_y
.
value
().
numel
()
-
1
));
auto
sx_data
=
sx
.
data_ptr
<
int64_t
>
();
auto
sy_data
=
sy
.
data_ptr
<
int64_t
>
();
auto
sx_vec
=
std
::
vector
<
long
>
(
sx_data
,
sx_data
+
sx
.
numel
());
auto
sy_vec
=
std
::
vector
<
long
>
(
sy_data
,
sy_data
+
sy
.
numel
());
batch_nanoflann_neighbors
<
scalar_t
>
(
y_vec
,
x_vec
,
sy_vec
,
sx_vec
,
out_vec
,
k
,
x
.
size
(
-
1
),
0
,
k
,
0
);
}
});
const
int64_t
size
=
out_vec
->
size
()
/
2
;
auto
out
=
torch
::
from_blob
(
out_vec
->
data
(),
{
size
,
2
},
x
.
options
().
dtype
(
torch
::
kLong
));
return
out
.
t
().
index_select
(
0
,
torch
::
tensor
({
1
,
0
}));
}
torch
::
Tensor
batch_knn_cpu
(
torch
::
Tensor
support
,
torch
::
Tensor
query
,
torch
::
Tensor
support_batch
,
torch
::
Tensor
query_batch
,
int64_t
k
)
{
CHECK_CPU
(
query
);
CHECK_CPU
(
support
);
CHECK_CPU
(
query_batch
);
CHECK_CPU
(
support_batch
);
torch
::
Tensor
out
;
auto
data_qb
=
query_batch
.
data_ptr
<
int64_t
>
();
auto
data_sb
=
support_batch
.
data_ptr
<
int64_t
>
();
std
::
vector
<
long
>
query_batch_stl
=
std
::
vector
<
long
>
(
data_qb
,
data_qb
+
query_batch
.
size
(
0
));
std
::
vector
<
long
>
size_query_batch_stl
;
CHECK_INPUT
(
std
::
is_sorted
(
query_batch_stl
.
begin
(),
query_batch_stl
.
end
()));
get_size_batch
(
query_batch_stl
,
size_query_batch_stl
);
std
::
vector
<
long
>
support_batch_stl
=
std
::
vector
<
long
>
(
data_sb
,
data_sb
+
support_batch
.
size
(
0
));
std
::
vector
<
long
>
size_support_batch_stl
;
CHECK_INPUT
(
std
::
is_sorted
(
support_batch_stl
.
begin
(),
support_batch_stl
.
end
()));
get_size_batch
(
support_batch_stl
,
size_support_batch_stl
);
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
int
max_count
=
0
;
AT_DISPATCH_ALL_TYPES
(
query
.
scalar_type
(),
"batch_radius_cpu"
,
[
&
]
{
auto
data_q
=
query
.
data_ptr
<
scalar_t
>
();
auto
data_s
=
support
.
data_ptr
<
scalar_t
>
();
std
::
vector
<
scalar_t
>
queries_stl
=
std
::
vector
<
scalar_t
>
(
data_q
,
data_q
+
query
.
size
(
0
)
*
query
.
size
(
1
));
std
::
vector
<
scalar_t
>
supports_stl
=
std
::
vector
<
scalar_t
>
(
data_s
,
data_s
+
support
.
size
(
0
)
*
support
.
size
(
1
));
int
dim
=
torch
::
size
(
query
,
1
);
max_count
=
batch_nanoflann_neighbors
<
scalar_t
>
(
queries_stl
,
supports_stl
,
size_query_batch_stl
,
size_support_batch_stl
,
neighbors_indices
,
0
,
dim
,
0
,
k
,
0
);
});
size_t
*
neighbors_indices_ptr
=
neighbors_indices
->
data
();
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
->
size
()
/
2
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
out
.
t
();
auto
result
=
torch
::
zeros_like
(
out
);
auto
index
=
torch
::
tensor
({
1
,
0
});
result
.
index_copy_
(
0
,
index
,
out
);
return
result
;
}
\ No newline at end of file
csrc/cpu/knn_cpu.h
#pragma once

#include <torch/extension.h>
#include "utils/neighbors.cpp"
#include <iostream>

// Old declarations:
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query, int64_t k,
                      int64_t n_threads);

torch::Tensor batch_knn_cpu(torch::Tensor support, torch::Tensor query,
                            torch::Tensor support_batch,
                            torch::Tensor query_batch, int64_t k);

// New declaration:
torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
                      torch::optional<torch::Tensor> ptr_x,
                      torch::optional<torch::Tensor> ptr_y, int64_t k,
                      int64_t num_workers);
csrc/cpu/radius_cpu.cpp
#include "radius_cpu.h"
#include <algorithm>
#include "utils.h"
#include <cstdint>
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
support
,
torch
::
Tensor
query
,
double
radius
,
int64_t
max_num
,
int64_t
n_threads
){
CHECK_CPU
(
query
);
CHECK_CPU
(
support
);
torch
::
Tensor
out
;
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
int
max_count
=
0
;
AT_DISPATCH_ALL_TYPES
(
query
.
scalar_type
(),
"radius_cpu"
,
[
&
]
{
auto
data_q
=
query
.
data_ptr
<
scalar_t
>
();
auto
data_s
=
support
.
data_ptr
<
scalar_t
>
();
std
::
vector
<
scalar_t
>
queries_stl
=
std
::
vector
<
scalar_t
>
(
data_q
,
data_q
+
query
.
size
(
0
)
*
query
.
size
(
1
));
std
::
vector
<
scalar_t
>
supports_stl
=
std
::
vector
<
scalar_t
>
(
data_s
,
data_s
+
support
.
size
(
0
)
*
support
.
size
(
1
));
int
dim
=
torch
::
size
(
query
,
1
);
max_count
=
nanoflann_neighbors
<
scalar_t
>
(
queries_stl
,
supports_stl
,
neighbors_indices
,
radius
,
dim
,
max_num
,
n_threads
,
0
,
1
);
});
size_t
*
neighbors_indices_ptr
=
neighbors_indices
->
data
();
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
->
size
()
/
2
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
out
.
t
();
auto
result
=
torch
::
zeros_like
(
out
);
auto
index
=
torch
::
tensor
({
1
,
0
});
result
.
index_copy_
(
0
,
index
,
out
);
return
result
;
}
void
get_size_batch
(
const
std
::
vector
<
long
>&
batch
,
std
::
vector
<
long
>&
res
){
res
.
resize
(
batch
[
batch
.
size
()
-
1
]
-
batch
[
0
]
+
1
,
0
);
long
ind
=
batch
[
0
];
long
incr
=
1
;
for
(
unsigned
long
i
=
1
;
i
<
batch
.
size
();
i
++
){
if
(
batch
[
i
]
==
ind
)
incr
++
;
else
{
res
[
ind
-
batch
[
0
]]
=
incr
;
incr
=
1
;
ind
=
batch
[
i
];
}
}
res
[
ind
-
batch
[
0
]]
=
incr
;
#include "utils.h"
#include "utils/neighbors.cpp"
torch
::
Tensor
radius_cpu
(
torch
::
Tensor
x
,
torch
::
Tensor
y
,
torch
::
optional
<
torch
::
Tensor
>
ptr_x
,
torch
::
optional
<
torch
::
Tensor
>
ptr_y
,
double
r
,
int64_t
max_num_neighbors
,
int64_t
num_workers
)
{
CHECK_CPU
(
x
);
CHECK_INPUT
(
x
.
dim
()
==
2
);
CHECK_CPU
(
y
);
CHECK_INPUT
(
y
.
dim
()
==
2
);
if
(
ptr_x
.
has_value
())
{
CHECK_CPU
(
ptr_x
.
value
());
CHECK_INPUT
(
ptr_x
.
value
().
dim
()
==
1
);
}
if
(
ptr_y
.
has_value
())
{
CHECK_CPU
(
ptr_y
.
value
());
CHECK_INPUT
(
ptr_y
.
value
().
dim
()
==
1
);
}
std
::
vector
<
size_t
>
*
out_vec
=
new
std
::
vector
<
size_t
>
();
AT_DISPATCH_ALL_TYPES
(
x
.
scalar_type
(),
"radius_cpu"
,
[
&
]
{
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
y_data
=
y
.
data_ptr
<
scalar_t
>
();
auto
x_vec
=
std
::
vector
<
scalar_t
>
(
x_data
,
x_data
+
x
.
numel
());
auto
y_vec
=
std
::
vector
<
scalar_t
>
(
y_data
,
y_data
+
y
.
numel
());
if
(
!
ptr_x
.
has_value
())
{
nanoflann_neighbors
<
scalar_t
>
(
y_vec
,
x_vec
,
out_vec
,
r
,
x
.
size
(
-
1
),
max_num_neighbors
,
num_workers
,
0
,
1
);
}
else
{
auto
sx
=
(
ptr_x
.
value
().
narrow
(
0
,
1
,
ptr_x
.
value
().
numel
()
-
1
)
-
ptr_x
.
value
().
narrow
(
0
,
0
,
ptr_x
.
value
().
numel
()
-
1
));
auto
sy
=
(
ptr_y
.
value
().
narrow
(
0
,
1
,
ptr_y
.
value
().
numel
()
-
1
)
-
ptr_y
.
value
().
narrow
(
0
,
0
,
ptr_y
.
value
().
numel
()
-
1
));
auto
sx_data
=
sx
.
data_ptr
<
int64_t
>
();
auto
sy_data
=
sy
.
data_ptr
<
int64_t
>
();
auto
sx_vec
=
std
::
vector
<
long
>
(
sx_data
,
sx_data
+
sx
.
numel
());
auto
sy_vec
=
std
::
vector
<
long
>
(
sy_data
,
sy_data
+
sy
.
numel
());
batch_nanoflann_neighbors
<
scalar_t
>
(
y_vec
,
x_vec
,
sy_vec
,
sx_vec
,
out_vec
,
r
,
x
.
size
(
-
1
),
max_num_neighbors
,
0
,
1
);
}
});
const
int64_t
size
=
out_vec
->
size
()
/
2
;
auto
out
=
torch
::
from_blob
(
out_vec
->
data
(),
{
size
,
2
},
x
.
options
().
dtype
(
torch
::
kLong
));
return
out
.
t
().
index_select
(
0
,
torch
::
tensor
({
1
,
0
}));
}
torch
::
Tensor
batch_radius_cpu
(
torch
::
Tensor
support
,
torch
::
Tensor
query
,
torch
::
Tensor
support_batch
,
torch
::
Tensor
query_batch
,
double
radius
,
int64_t
max_num
)
{
CHECK_CPU
(
query
);
CHECK_CPU
(
support
);
CHECK_CPU
(
query_batch
);
CHECK_CPU
(
support_batch
);
torch
::
Tensor
out
;
auto
data_qb
=
query_batch
.
data_ptr
<
int64_t
>
();
auto
data_sb
=
support_batch
.
data_ptr
<
int64_t
>
();
std
::
vector
<
long
>
query_batch_stl
=
std
::
vector
<
long
>
(
data_qb
,
data_qb
+
query_batch
.
size
(
0
));
std
::
vector
<
long
>
size_query_batch_stl
;
CHECK_INPUT
(
std
::
is_sorted
(
query_batch_stl
.
begin
(),
query_batch_stl
.
end
()));
get_size_batch
(
query_batch_stl
,
size_query_batch_stl
);
std
::
vector
<
long
>
support_batch_stl
=
std
::
vector
<
long
>
(
data_sb
,
data_sb
+
support_batch
.
size
(
0
));
std
::
vector
<
long
>
size_support_batch_stl
;
CHECK_INPUT
(
std
::
is_sorted
(
support_batch_stl
.
begin
(),
support_batch_stl
.
end
()));
get_size_batch
(
support_batch_stl
,
size_support_batch_stl
);
std
::
vector
<
size_t
>*
neighbors_indices
=
new
std
::
vector
<
size_t
>
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kLong
).
device
(
torch
::
kCPU
);
int
max_count
=
0
;
AT_DISPATCH_ALL_TYPES
(
query
.
scalar_type
(),
"batch_radius_cpu"
,
[
&
]
{
auto
data_q
=
query
.
data_ptr
<
scalar_t
>
();
auto
data_s
=
support
.
data_ptr
<
scalar_t
>
();
std
::
vector
<
scalar_t
>
queries_stl
=
std
::
vector
<
scalar_t
>
(
data_q
,
data_q
+
query
.
size
(
0
)
*
query
.
size
(
1
));
std
::
vector
<
scalar_t
>
supports_stl
=
std
::
vector
<
scalar_t
>
(
data_s
,
data_s
+
support
.
size
(
0
)
*
support
.
size
(
1
));
int
dim
=
torch
::
size
(
query
,
1
);
max_count
=
batch_nanoflann_neighbors
<
scalar_t
>
(
queries_stl
,
supports_stl
,
size_query_batch_stl
,
size_support_batch_stl
,
neighbors_indices
,
radius
,
dim
,
max_num
,
0
,
1
);
});
size_t
*
neighbors_indices_ptr
=
neighbors_indices
->
data
();
const
long
long
tsize
=
static_cast
<
long
long
>
(
neighbors_indices
->
size
()
/
2
);
out
=
torch
::
from_blob
(
neighbors_indices_ptr
,
{
tsize
,
2
},
options
=
options
);
out
=
out
.
t
();
auto
result
=
torch
::
zeros_like
(
out
);
auto
index
=
torch
::
tensor
({
1
,
0
});
result
.
index_copy_
(
0
,
index
,
out
);
return
result
;
}
\ No newline at end of file
csrc/cpu/radius_cpu.h
#pragma once

#include <torch/extension.h>
#include "utils/neighbors.cpp"
#include <iostream>

// Old declarations:
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
                         double radius, int64_t max_num, int64_t n_threads);

torch::Tensor batch_radius_cpu(torch::Tensor query, torch::Tensor support,
                               torch::Tensor query_batch,
                               torch::Tensor support_batch, double radius,
                               int64_t max_num);

// New declaration:
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
                         torch::optional<torch::Tensor> ptr_x,
                         torch::optional<torch::Tensor> ptr_y, double r,
                         int64_t max_num_neighbors, int64_t num_workers);
csrc/cuda/knn_cuda.cu
...
@@ -75,16 +75,30 @@ __global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
  }
}

// Old signature:
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
                       torch::Tensor ptr_y, int64_t k, bool cosine) {
// New signature:
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
                       torch::optional<torch::Tensor> ptr_x,
                       torch::optional<torch::Tensor> ptr_y, int64_t k,
                       bool cosine) {
  CHECK_CUDA(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);  // old
  CHECK_CUDA(ptr_y);  // old
  CHECK_INPUT(y.dim() == 2);
  cudaSetDevice(x.get_device());

  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
  } else {
    ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
  }
  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
  } else {
    ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
  }
  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

  auto dist = torch::full(y.size(0) * k, 1e38, y.options());
  auto row = torch::empty(y.size(0) * k, ptr_y.options());
...
@@ -94,7 +108,7 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
    knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),                  // old
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),  // new
        dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
        col.data_ptr<int64_t>(), k, x.size(1), cosine);
  });
...
csrc/cuda/knn_cuda.h
...
@@ -2,5 +2,7 @@
#include <torch/extension.h>

// Old declaration:
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
                       torch::Tensor ptr_y, int64_t k, bool cosine);
// New declaration:
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
                       torch::optional<torch::Tensor> ptr_x,
                       torch::optional<torch::Tensor> ptr_y, int64_t k,
                       bool cosine);
csrc/cuda/radius_cuda.cu
...
@@ -44,26 +44,40 @@ __global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
  }
}

// Old signature:
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::Tensor ptr_x, torch::Tensor ptr_y, double r,
// New signature:
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::optional<torch::Tensor> ptr_x,
                          torch::optional<torch::Tensor> ptr_y, double r,
                          int64_t max_num_neighbors) {
  CHECK_CUDA(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);  // old
  CHECK_CUDA(ptr_y);  // old
  CHECK_INPUT(y.dim() == 2);
  cudaSetDevice(x.get_device());

  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
  } else {
    ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
  }
  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
  } else {
    ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
  }
  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

  // old:
  auto row = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
  auto col = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
  // new:
  auto row =
      torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
  auto col =
      torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
    radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),                  // old
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),  // new
        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
        x.size(1));
  });
...
csrc/cuda/radius_cuda.h
...
@@ -2,6 +2,7 @@
#include <torch/extension.h>

// Old declaration:
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::Tensor ptr_x, torch::Tensor ptr_y, double r,
// New declaration:
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::optional<torch::Tensor> ptr_x,
                          torch::optional<torch::Tensor> ptr_y, double r,
                          int64_t max_num_neighbors);
csrc/knn.cpp
#include <Python.h>
#include <torch/script.h>

#include "cpu/knn_cpu.h"

#ifdef WITH_CUDA
#include "cuda/knn_cuda.h"
#endif

#include "cpu/knn_cpu.h"

#ifdef _WIN32
PyMODINIT_FUNC PyInit__knn(void) { return NULL; }
#endif

// Old signature (`n_threads`):
torch::Tensor knn(torch::Tensor x, torch::Tensor y,
                  torch::optional<torch::Tensor> ptr_x,
                  torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine,
                  int64_t n_threads) {
// New signature (`num_workers`):
torch::Tensor knn(torch::Tensor x, torch::Tensor y,
                  torch::optional<torch::Tensor> ptr_x,
                  torch::optional<torch::Tensor> ptr_y, int64_t k, bool cosine,
                  int64_t num_workers) {
  if (x.device().is_cuda()) {
#ifdef WITH_CUDA
    // Old dispatch: materialize batch tensors and call the batch-based CUDA op.
    if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
      auto batch_x = torch::tensor({0, torch::size(x, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      auto batch_y = torch::tensor({0, torch::size(y, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      return knn_cuda(x, y, batch_x, batch_y, k, cosine);
    } else if (!(ptr_x.has_value())) {
      auto batch_x = torch::tensor({0, torch::size(x, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      auto batch_y = ptr_y.value();
      return knn_cuda(x, y, batch_x, batch_y, k, cosine);
    } else if (!(ptr_y.has_value())) {
      auto batch_x = ptr_x.value();
      auto batch_y = torch::tensor({0, torch::size(y, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      return knn_cuda(x, y, batch_x, batch_y, k, cosine);
    }
    auto batch_x = ptr_x.value();
    auto batch_y = ptr_y.value();
    return knn_cuda(x, y, batch_x, batch_y, k, cosine);
    // New dispatch: forward the optional ptr tensors directly.
    return knn_cuda(x, y, ptr_x, ptr_y, k, cosine);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    if (cosine) {  // old
    if (cosine)    // new
      AT_ERROR("`cosine` argument not supported on CPU");
    }              // old
    // Old dispatch: convert to batch tensors and call (batch_)knn_cpu.
    if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
      return knn_cpu(x, y, k, n_threads);
    }
    if (!(ptr_x.has_value())) {
      auto batch_x = torch::zeros({torch::size(x, 0)}).to(torch::kLong);
      auto batch_y = ptr_y.value();
      return batch_knn_cpu(x, y, batch_x, batch_y, k);
    } else if (!(ptr_y.has_value())) {
      auto batch_x = ptr_x.value();
      auto batch_y = torch::zeros({torch::size(y, 0)}).to(torch::kLong);
      return batch_knn_cpu(x, y, batch_x, batch_y, k);
    }
    auto batch_x = ptr_x.value();
    auto batch_y = ptr_y.value();
    return batch_knn_cpu(x, y, batch_x, batch_y, k);
    // New dispatch: forward the optional ptr tensors directly.
    return knn_cpu(x, y, ptr_x, ptr_y, k, num_workers);
  }
}
...
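For orientation only: after this change, the Python wrappers reach the dispatcher above through `torch.ops`. A minimal sketch of a direct op-level call with CSR-style offsets, assuming the extension is built and that importing `torch_cluster` registers the compiled ops (the tensors and values here are made up):

    import torch
    import torch_cluster  # noqa: F401 (loads the compiled ops)

    x = torch.randn(8, 3)          # two examples of 4 points each
    y = torch.randn(4, 3)          # two examples of 2 query points each
    ptr_x = torch.tensor([0, 4, 8])
    ptr_y = torch.tensor([0, 2, 4])

    # Argument order mirrors the new signature wired up in csrc/knn.cpp:
    # (x, y, ptr_x, ptr_y, k, cosine, num_workers).
    edge_index = torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, 2, False, 1)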
csrc/radius.cpp
#include <Python.h>
#include <torch/script.h>
#include <iostream>

#include "cpu/radius_cpu.h"

#ifdef WITH_CUDA
#include "cuda/radius_cuda.h"
#endif

#include "cpu/radius_cpu.h"

#ifdef _WIN32
PyMODINIT_FUNC PyInit__radius(void) { return NULL; }
#endif

// Old signature (`n_threads`):
torch::Tensor radius(torch::Tensor x, torch::Tensor y,
                     torch::optional<torch::Tensor> ptr_x,
                     torch::optional<torch::Tensor> ptr_y, double r,
                     int64_t max_num_neighbors, int64_t n_threads) {
// New signature (`num_workers`):
torch::Tensor radius(torch::Tensor x, torch::Tensor y,
                     torch::optional<torch::Tensor> ptr_x,
                     torch::optional<torch::Tensor> ptr_y, double r,
                     int64_t max_num_neighbors, int64_t num_workers) {
  if (x.device().is_cuda()) {
#ifdef WITH_CUDA
    // Old dispatch: materialize batch tensors and call the batch-based CUDA op.
    if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
      auto batch_x = torch::tensor({0, torch::size(x, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      auto batch_y = torch::tensor({0, torch::size(y, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
    } else if (!(ptr_x.has_value())) {
      auto batch_x = torch::tensor({0, torch::size(x, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      auto batch_y = ptr_y.value();
      return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
    } else if (!(ptr_y.has_value())) {
      auto batch_x = ptr_x.value();
      auto batch_y = torch::tensor({0, torch::size(y, 0)})
                         .to(torch::kLong)
                         .to(torch::kCUDA);
      return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
    }
    auto batch_x = ptr_x.value();
    auto batch_y = ptr_y.value();
    return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
    // New dispatch: forward the optional ptr tensors directly.
    return radius_cuda(x, y, ptr_x, ptr_y, r, max_num_neighbors);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    // Old dispatch: convert to batch tensors and call (batch_)radius_cpu.
    if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
      return radius_cpu(x, y, r, max_num_neighbors, n_threads);
    }
    if (!(ptr_x.has_value())) {
      auto batch_x = torch::zeros({torch::size(x, 0)}).to(torch::kLong);
      auto batch_y = ptr_y.value();
      return batch_radius_cpu(x, y, batch_x, batch_y, r, max_num_neighbors);
    } else if (!(ptr_y.has_value())) {
      auto batch_x = ptr_x.value();
      auto batch_y = torch::zeros({torch::size(y, 0)}).to(torch::kLong);
      return batch_radius_cpu(x, y, batch_x, batch_y, r, max_num_neighbors);
    }
    auto batch_x = ptr_x.value();
    auto batch_y = ptr_y.value();
    return batch_radius_cpu(x, y, batch_x, batch_y, r, max_num_neighbors);
    // New dispatch: forward the optional ptr tensors directly.
    return radius_cpu(x, y, ptr_x, ptr_y, r, max_num_neighbors, num_workers);
  }
}
...
setup.py
...
@@ -57,9 +57,9 @@ def get_extensions():
    return extensions


install_requires = ['scipy']                       # old
install_requires = []                              # new
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']           # old
tests_require = ['pytest', 'pytest-cov', 'scipy']  # new

setup(
    name='torch_cluster',
...
test/test_knn.py
...
@@ -2,7 +2,9 @@ from itertools import product

import pytest
import torch
import scipy.spatial
from torch_cluster import knn, knn_graph

from .utils import grad_dtypes, devices, tensor
...
@@ -26,9 +28,11 @@ def test_knn(dtype, device):
    batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    batch_y = tensor([0, 1], torch.long, device)

    # old:
    row, col = knn(x, y, 2, batch_x, batch_y)
    col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
    # new:
    row, col = knn(x, y, 2)
    assert row.tolist() == [0, 0, 1, 1]
    assert col.tolist() == [2, 3, 0, 1]

    row, col = knn(x, y, 2, batch_x, batch_y)
    assert row.tolist() == [0, 0, 1, 1]
    assert col.tolist() == [2, 3, 4, 5]
...
@@ -48,55 +52,24 @@ def test_knn_graph(dtype, device):
    ], dtype, device)

    row, col = knn_graph(x, k=2, flow='target_to_source')
    col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
    assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
    assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]

    row, col = knn_graph(x, k=2, flow='source_to_target')
    row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
    assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
    assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]


# Removed test: fixed point cloud with explicit `n_threads`.
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_knn_graph_large(dtype, device):
    x = torch.tensor([[-1.0320, 0.2380, 0.2380], [-1.3050, -0.0930, 0.6420],
                      [-0.3190, -0.0410, 1.2150], [1.1400, -0.5390, -0.3140],
                      [0.8410, 0.8290, 0.6090], [-1.4380, -0.2420, -0.3260],
                      [-2.2980, 0.7160, 0.9320], [-1.3680, -0.4390, 0.1380],
                      [-0.6710, 0.6060, 1.1800],
                      [0.3950, -0.0790, 1.4920]]).to(device)
    k = 3
    truth = set({(4, 8), (2, 8), (9, 8), (8, 0), (0, 7), (2, 1), (9, 4),
                 (5, 1), (4, 9), (2, 9), (8, 1), (1, 5), (5, 0), (3, 2),
                 (8, 2), (7, 1), (6, 0), (3, 9), (0, 5), (7, 5), (4, 2),
                 (1, 0), (0, 1), (7, 0), (6, 8), (9, 2), (6, 1), (5, 7),
                 (1, 7), (3, 4)})

    row, col = knn_graph(x, k=k, flow='target_to_source', batch=None,
                         n_threads=24, loop=False)
    edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
                                          list(col.cpu().numpy()))])
    assert (truth == edges)

    row, col = knn_graph(x, k=k, flow='target_to_source', batch=None,
                         n_threads=12, loop=False)
    edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
                                          list(col.cpu().numpy()))])
    assert (truth == edges)

    row, col = knn_graph(x, k=k, flow='target_to_source', batch=None,
                         n_threads=1, loop=False)
    edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
                                          list(col.cpu().numpy()))])
    assert (truth == edges)

    # Added check: random points verified against scipy's cKDTree.
    x = torch.randn(1000, 3)
    row, col = knn_graph(x, k=5, flow='target_to_source', loop=True,
                         num_workers=6)
    pred = set([(i, j) for i, j in zip(row.tolist(), col.tolist())])

    tree = scipy.spatial.cKDTree(x.numpy())
    _, col = tree.query(x.cpu(), k=5)
    truth = set([(i, j) for i, ns in enumerate(col) for j in ns])

    assert pred == truth
test/test_radius.py
This diff is collapsed.
torch_cluster/knn.py
...
@@ -48,7 +48,9 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,

    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y
    x, y = x.contiguous(), y.contiguous()

    ptr_x: Optional[torch.Tensor] = None
    if batch_x is not None:
        assert x.size(0) == batch_x.numel()
        batch_size = int(batch_x.max()) + 1
...
@@ -59,6 +61,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
        ptr_x = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr_x[1:])

    ptr_y: Optional[torch.Tensor] = None
    if batch_y is not None:
        assert y.size(0) == batch_y.numel()
        batch_size = int(batch_y.max()) + 1
...
@@ -68,8 +71,6 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
        ptr_y = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr_y[1:])
    else:                                                      # old
        ptr_y = torch.tensor([0, y.size(0)], device=y.device)  # old

    return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
                                       num_workers)
...
@@ -114,10 +115,16 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
    """

    assert flow in ['source_to_target', 'target_to_source']
    # old:
    row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine,
                   num_workers)
    row, col = (col, row) if flow == 'source_to_target' else (row, col)
    # new:
    edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
                     num_workers)
    if flow == 'source_to_target':
        row, col = edge_index[1], edge_index[0]
    else:
        row, col = edge_index[0], edge_index[1]

    if not loop:
        mask = row != col
        row, col = row[mask], col[mask]

    return torch.stack([row, col], dim=0)
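For orientation only, a small usage sketch of the Python API this diff touches; the point coordinates are made up and the calls mirror those in test/test_knn.py:

    import torch
    from torch_cluster import knn, knn_graph

    x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
    y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])

    # For each point in y, find its 2 nearest points in x.  After this commit
    # the result is a stacked (2, num_edges) index tensor: row 0 indexes y,
    # row 1 indexes x.
    edge_index = knn(x, y, 2)

    # k-NN graph over a single point set; `loop=False` drops self-loops and
    # `flow` decides which way the returned edges point.
    edge_index = knn_graph(x, k=2, loop=False, flow='source_to_target')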
torch_cluster/radius.py
...
@@ -45,7 +45,9 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,

    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y
    x, y = x.contiguous(), y.contiguous()

    ptr_x: Optional[torch.Tensor] = None
    if batch_x is not None:
        assert x.size(0) == batch_x.numel()
        batch_size = int(batch_x.max()) + 1
...
@@ -55,9 +57,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
        ptr_x = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr_x[1:])
    else:             # old
        ptr_x = None  # old

    ptr_y: Optional[torch.Tensor] = None
    if batch_y is not None:
        assert y.size(0) == batch_y.numel()
        batch_size = int(batch_y.max()) + 1
...
@@ -67,8 +68,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
        ptr_y = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr_y[1:])
    else:             # old
        ptr_y = None  # old

    return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
                                          max_num_neighbors, num_workers)
...
@@ -113,11 +112,16 @@ def radius_graph(x: torch.Tensor, r: float,
    """

    assert flow in ['source_to_target', 'target_to_source']
    # old:
    row, col = radius(x, x, r, batch, batch,
                      max_num_neighbors if loop else max_num_neighbors + 1,
                      num_workers)
    row, col = (col, row) if flow == 'source_to_target' else (row, col)
    # new:
    edge_index = radius(x, x, r, batch, batch,
                        max_num_neighbors if loop else max_num_neighbors + 1,
                        num_workers)
    if flow == 'source_to_target':
        row, col = edge_index[1], edge_index[0]
    else:
        row, col = edge_index[0], edge_index[1]

    if not loop:
        mask = row != col
        row, col = row[mask], col[mask]

    return torch.stack([row, col], dim=0)
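Likewise for the radius-based API, a hedged usage sketch; the values are made up and the argument order follows the `radius(x, y, r, batch_x, batch_y, ...)` calls visible in the diff above:

    import torch
    from torch_cluster import radius, radius_graph

    x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
    batch_x = torch.tensor([0, 0, 1, 1])
    y = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
    batch_y = torch.tensor([0, 1])

    # All x within radius 2.0 of each y, restricted to matching batch entries.
    edge_index = radius(x, y, 2.0, batch_x, batch_y, max_num_neighbors=4)

    # Radius graph over one point set; self-loops are dropped when loop=False.
    edge_index = radius_graph(x, r=2.5, batch=batch_x, loop=False)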