Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
01fa24ee
Unverified
Commit
01fa24ee
authored
Apr 20, 2022
by
Matthias Fey
Committed by
GitHub
Apr 20, 2022
Browse files
adjust tensor creation (#127)
parent
516d988d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
7 additions
and
7 deletions
+7
-7
csrc/cpu/fps_cpu.cpp
csrc/cpu/fps_cpu.cpp
+1
-1
csrc/cpu/grid_cpu.cpp
csrc/cpu/grid_cpu.cpp
+1
-1
csrc/cuda/fps_cuda.cu
csrc/cuda/fps_cuda.cu
+3
-3
csrc/cuda/grid_cuda.cu
csrc/cuda/grid_cuda.cu
+1
-1
csrc/cuda/knn_cuda.cu
csrc/cuda/knn_cuda.cu
+1
-1
No files found.
csrc/cpu/fps_cpu.cpp
View file @
01fa24ee
...
...
@@ -24,7 +24,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
auto
out_ptr
=
deg
.
toType
(
torch
::
kFloat
)
*
ratio
;
out_ptr
=
out_ptr
.
ceil
().
toType
(
torch
::
kLong
).
cumsum
(
0
);
auto
out
=
torch
::
empty
(
out_ptr
[
-
1
].
data_ptr
<
int64_t
>
()[
0
],
ptr
.
options
());
auto
out
=
torch
::
empty
(
{
out_ptr
[
-
1
].
data_ptr
<
int64_t
>
()[
0
]
}
,
ptr
.
options
());
auto
ptr_data
=
ptr
.
data_ptr
<
int64_t
>
();
auto
out_ptr_data
=
out_ptr
.
data_ptr
<
int64_t
>
();
...
...
csrc/cpu/grid_cpu.cpp
View file @
01fa24ee
...
...
@@ -35,7 +35,7 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
auto
num_voxels
=
(
end
-
start
).
true_divide
(
size
).
toType
(
torch
::
kLong
)
+
1
;
num_voxels
=
num_voxels
.
cumprod
(
0
);
num_voxels
=
torch
::
cat
({
torch
::
ones
(
1
,
num_voxels
.
options
()),
num_voxels
},
0
);
torch
::
cat
({
torch
::
ones
(
{
1
}
,
num_voxels
.
options
()),
num_voxels
},
0
);
num_voxels
=
num_voxels
.
narrow
(
0
,
0
,
size
.
size
(
0
));
auto
out
=
pos
.
true_divide
(
size
.
view
({
1
,
-
1
})).
toType
(
torch
::
kLong
);
...
...
csrc/cuda/fps_cuda.cu
View file @
01fa24ee
...
...
@@ -80,14 +80,14 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
auto
deg
=
ptr
.
narrow
(
0
,
1
,
batch_size
)
-
ptr
.
narrow
(
0
,
0
,
batch_size
);
auto
out_ptr
=
deg
.
toType
(
ratio
.
scalar_type
())
*
ratio
;
out_ptr
=
out_ptr
.
ceil
().
toType
(
torch
::
kLong
).
cumsum
(
0
);
out_ptr
=
torch
::
cat
({
torch
::
zeros
(
1
,
ptr
.
options
()),
out_ptr
},
0
);
out_ptr
=
torch
::
cat
({
torch
::
zeros
(
{
1
}
,
ptr
.
options
()),
out_ptr
},
0
);
torch
::
Tensor
start
;
if
(
random_start
)
{
start
=
torch
::
rand
(
batch_size
,
src
.
options
());
start
=
(
start
*
deg
.
toType
(
ratio
.
scalar_type
())).
toType
(
torch
::
kLong
);
}
else
{
start
=
torch
::
zeros
(
batch_size
,
ptr
.
options
());
start
=
torch
::
zeros
(
{
batch_size
}
,
ptr
.
options
());
}
auto
dist
=
torch
::
full
(
src
.
size
(
0
),
5e4
,
src
.
options
());
...
...
@@ -95,7 +95,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
auto
out_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
cudaMemcpy
(
out_size
,
out_ptr
[
-
1
].
data_ptr
<
int64_t
>
(),
sizeof
(
int64_t
),
cudaMemcpyDeviceToHost
);
auto
out
=
torch
::
empty
(
out_size
[
0
],
out_ptr
.
options
());
auto
out
=
torch
::
empty
(
{
out_size
[
0
]
}
,
out_ptr
.
options
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
scalar_type
=
src
.
scalar_type
();
...
...
csrc/cuda/grid_cuda.cu
View file @
01fa24ee
...
...
@@ -58,7 +58,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
auto
start
=
optional_start
.
value
();
auto
end
=
optional_end
.
value
();
auto
out
=
torch
::
empty
(
pos
.
size
(
0
),
pos
.
options
().
dtype
(
torch
::
kLong
));
auto
out
=
torch
::
empty
(
{
pos
.
size
(
0
)
}
,
pos
.
options
().
dtype
(
torch
::
kLong
));
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_ALL_TYPES_AND
(
at
::
ScalarType
::
Half
,
pos
.
scalar_type
(),
"_"
,
[
&
]
{
...
...
csrc/cuda/knn_cuda.cu
View file @
01fa24ee
...
...
@@ -115,7 +115,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
cudaSetDevice
(
x
.
get_device
());
auto
row
=
torch
::
empty
(
y
.
size
(
0
)
*
k
,
ptr_y
.
value
().
options
());
auto
row
=
torch
::
empty
(
{
y
.
size
(
0
)
*
k
}
,
ptr_y
.
value
().
options
());
auto
col
=
torch
::
full
(
y
.
size
(
0
)
*
k
,
-
1
,
ptr_y
.
value
().
options
());
dim3
BLOCKS
((
y
.
size
(
0
)
+
THREADS
-
1
)
/
THREADS
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment