OpenDAS / torch-cluster / Commits / e8620a86

Unverified commit e8620a86, authored Dec 28, 2021 by Matthias Fey, committed by GitHub on Dec 28, 2021.

Half-precision support (#119)

* half support
* deprecation
* typo
* test half
* fix test

Parent: 0d735d7e

Changes: 12 changed files, with 25 additions and 18 deletions (+25 / -18).
csrc/cpu/graclus_cpu.cpp   +2 -1
csrc/cpu/knn_cpu.cpp       +1 -1
csrc/cpu/radius_cpu.cpp    +1 -1
csrc/cuda/fps_cuda.cu      +5 -4
csrc/cuda/graclus_cuda.cu  +4 -2
csrc/cuda/grid_cuda.cu     +1 -1
csrc/cuda/knn_cuda.cu      +3 -2
csrc/cuda/nearest_cuda.cu  +2 -1
csrc/cuda/radius_cuda.cu   +2 -1
test/test_knn.py           +1 -1
test/test_radius.py        +1 -1
test/utils.py              +2 -2
csrc/cpu/graclus_cpu.cpp

@@ -46,7 +46,8 @@ torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
     }
   } else {
     auto weight = optional_weight.value();
-    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
+    auto scalar_type = weight.scalar_type();
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
       auto weight_data = weight.data_ptr<scalar_t>();
       for (auto n = 0; n < num_nodes; n++) {
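The hunk above is the change repeated across most of this commit: AT_DISPATCH_ALL_TYPES (and, below, AT_DISPATCH_FLOATING_TYPES) expands to a switch over the tensor's runtime dtype and instantiates its lambda once per supported type, with scalar_t bound to the matching C++ type; the _AND variants accept one extra ScalarType, here at::ScalarType::Half, so the same code paths are also compiled for fp16. A minimal sketch of the pattern, assuming a libtorch build environment (the function sum_any is illustrative and not part of the patch):

#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Illustrative only: sums a 1-D tensor, dispatching over all "standard"
// dtypes plus Half, the same way the patched call sites do.
double sum_any(const torch::Tensor &t) {
  double total = 0.0;
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, t.scalar_type(), "_", [&] {
    // Inside this lambda, `scalar_t` names the concrete element type
    // (at::Half, float, double, int64_t, ...) selected at runtime.
    const auto *data = t.data_ptr<scalar_t>();
    for (int64_t i = 0; i < t.numel(); i++)
      total += static_cast<double>(data[i]);
  });
  return total;
}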
csrc/cpu/knn_cpu.cpp

@@ -25,7 +25,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
   std::vector<size_t> out_vec = std::vector<size_t>();
-  AT_DISPATCH_ALL_TYPES(x.scalar_type(), "knn_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
     // See: nanoflann/examples/vector_of_vectors_example.cpp
     auto x_data = x.data_ptr<scalar_t>();
csrc/cpu/radius_cpu.cpp

@@ -25,7 +25,7 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
   std::vector<size_t> out_vec = std::vector<size_t>();
-  AT_DISPATCH_ALL_TYPES(x.scalar_type(), "radius_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
     // See: nanoflann/examples/vector_of_vectors_example.cpp
     auto x_data = x.data_ptr<scalar_t>();
csrc/cuda/fps_cuda.cu

@@ -78,19 +78,19 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
   auto batch_size = ptr.numel() - 1;
   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
-  auto out_ptr = deg.toType(torch::kFloat) * ratio;
+  auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
   out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
   out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);

   torch::Tensor start;
   if (random_start) {
     start = torch::rand(batch_size, src.options());
-    start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
+    start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
   } else {
     start = torch::zeros(batch_size, ptr.options());
   }

-  auto dist = torch::full(src.size(0), 1e38, src.options());
+  auto dist = torch::full(src.size(0), 5e4, src.options());

   auto out_size = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),

@@ -98,7 +98,8 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
   auto out = torch::empty(out_size[0], out_ptr.options());
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
+  auto scalar_type = src.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
         src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
         out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
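The sentinel change from 1e38 to 5e4 above (and from 1e10 to 5e4 in csrc/cuda/knn_cuda.cu below) goes hand in hand with the wider dispatch: IEEE binary16 tops out at a maximum finite value of 65504, so the old "effectively infinite" initial distances are not representable once dist can be allocated with half-precision src.options(), while 5e4 fits in every supported floating dtype and still exceeds any plausible squared distance in the tests. A standalone check, independent of the patch (the constant and printout are purely for illustration):

#include <cstdio>

int main() {
  // Largest finite IEEE binary16 ("half") value; PyTorch reports the
  // same number via torch.finfo(torch.half).max.
  const float fp16_max = 65504.0f;
  printf("1e38 fits in fp16? %s\n", 1e38f <= fp16_max ? "yes" : "no"); // no
  printf("1e10 fits in fp16? %s\n", 1e10f <= fp16_max ? "yes" : "no"); // no
  printf("5e4  fits in fp16? %s\n", 5e4f <= fp16_max ? "yes" : "no"); // yes
  return 0;
}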
csrc/cuda/graclus_cuda.cu

@@ -113,7 +113,8 @@ void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
         rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
   } else {
     auto weight = optional_weight.value();
-    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
+    auto scalar_type = weight.scalar_type();
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
       weighted_propose_kernel<scalar_t>
           <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
               out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),

@@ -201,7 +202,8 @@ void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
         rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
   } else {
     auto weight = optional_weight.value();
-    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
+    auto scalar_type = weight.scalar_type();
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
       weighted_respond_kernel<scalar_t>
           <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
               out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
csrc/cuda/grid_cuda.cu

@@ -61,7 +61,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
   auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] {
     grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
         pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
         start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
csrc/cuda/knn_cuda.cu

@@ -45,7 +45,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
   int64_t best_idx[100];
   for (int e = 0; e < k; e++) {
-    best_dist[e] = 1e10;
+    best_dist[e] = 5e4;
     best_idx[e] = -1;
   }

@@ -121,7 +121,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
   dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
+  auto scalar_type = x.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
csrc/cuda/nearest_cuda.cu

@@ -79,7 +79,8 @@ torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
   auto out = torch::empty({x.size(0)}, ptr_x.options());
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
+  auto scalar_type = x.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
csrc/cuda/radius_cuda.cu

@@ -80,7 +80,8 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
   dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
+  auto scalar_type = x.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
test/test_knn.py

@@ -67,7 +67,7 @@ def test_knn_graph(dtype, device):
                           (3, 2), (0, 3), (2, 3)])
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
 def test_knn_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
test/test_radius.py

@@ -66,7 +66,7 @@ def test_radius_graph(dtype, device):
                           (3, 2), (0, 3), (2, 3)])
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
 def test_radius_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
test/utils.py

 import torch

-dtypes = [torch.float, torch.double, torch.int, torch.long]
-grad_dtypes = [torch.float, torch.double]
+dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
+grad_dtypes = [torch.half, torch.float, torch.double]

 devices = [torch.device('cpu')]
 if torch.cuda.is_available():