Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
54d23137
Commit
54d23137
authored
Jul 16, 2020
by
rusty1s
Browse files
use last index for dim_size in segment_coo
parent
524326d0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
3 deletions
+8
-3
csrc/cpu/segment_coo_cpu.cpp
csrc/cpu/segment_coo_cpu.cpp
+5
-2
csrc/cuda/segment_coo_cuda.cu
csrc/cuda/segment_coo_cuda.cu
+3
-1
No files found.
csrc/cpu/segment_coo_cpu.cpp
View file @
54d23137
...
@@ -36,8 +36,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
...
@@ -36,8 +36,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
sizes
[
dim
]
=
dim_size
.
value
();
sizes
[
dim
]
=
dim_size
.
value
();
else
if
(
index
.
numel
()
==
0
)
else
if
(
index
.
numel
()
==
0
)
sizes
[
dim
]
=
0
;
sizes
[
dim
]
=
0
;
else
else
{
sizes
[
dim
]
=
1
+
*
index
.
max
().
data_ptr
<
int64_t
>
();
auto
tmp
=
index
.
select
(
dim
,
index
.
size
(
dim
)
-
1
);
tmp
=
tmp
.
numel
()
>
1
?
tmp
.
max
()
:
tmp
;
sizes
[
dim
]
=
1
+
*
tmp
.
data_ptr
<
int64_t
>
();
}
out
=
torch
::
empty
(
sizes
,
src
.
options
());
out
=
torch
::
empty
(
sizes
,
src
.
options
());
}
}
...
...
csrc/cuda/segment_coo_cuda.cu
View file @
54d23137
...
@@ -184,7 +184,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
...
@@ -184,7 +184,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
else
if
(
index
.
numel
()
==
0
)
else
if
(
index
.
numel
()
==
0
)
sizes
[
dim
]
=
0
;
sizes
[
dim
]
=
0
;
else
{
else
{
auto
d_size
=
index
.
max
().
data_ptr
<
int64_t
>
();
auto
tmp
=
index
.
select
(
dim
,
index
.
size
(
dim
)
-
1
);
tmp
=
tmp
.
numel
()
>
1
?
tmp
.
max
()
:
tmp
;
auto
d_size
=
tmp
.
data_ptr
<
int64_t
>
();
auto
h_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
auto
h_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
cudaMemcpy
(
h_size
,
d_size
,
sizeof
(
int64_t
),
cudaMemcpyDeviceToHost
);
cudaMemcpy
(
h_size
,
d_size
,
sizeof
(
int64_t
),
cudaMemcpyDeviceToHost
);
sizes
[
dim
]
=
1
+
*
h_size
;
sizes
[
dim
]
=
1
+
*
h_size
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment