Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
2693efc9
Commit
2693efc9
authored
Nov 23, 2020
by
rusty1s
Browse files
fix segment coo indexing
parent
4c4a2e6c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
1 deletion
+7
-1
csrc/cuda/segment_coo_cuda.cu
csrc/cuda/segment_coo_cuda.cu
+7
-1
No files found.
csrc/cuda/segment_coo_cuda.cu
View file @
2693efc9
...
@@ -85,7 +85,7 @@ __global__ void segment_coo_broadcast_kernel(
...
@@ -85,7 +85,7 @@ __global__ void segment_coo_broadcast_kernel(
int
D
=
index_info
.
sizes
[
index_info
.
dims
-
1
];
int
D
=
index_info
.
sizes
[
index_info
.
dims
-
1
];
int
E_1
=
E
/
D
;
int
E_1
=
E
/
D
;
int
E_2
=
D
+
TB
-
(
D
%
TB
);
int
E_2
=
(
D
-
1
)
+
TB
-
(
(
D
-
1
)
%
TB
);
int
row_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
int
row_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
int
col_idx
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
col_idx
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -215,6 +215,12 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
...
@@ -215,6 +215,12 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
auto
N
=
out
.
size
(
dim
);
auto
N
=
out
.
size
(
dim
);
auto
avg_len
=
(
float
)
E_2
/
(
float
)
N
;
auto
avg_len
=
(
float
)
E_2
/
(
float
)
N
;
std
::
cout
<<
"E "
<<
E
<<
std
::
endl
;
std
::
cout
<<
"E2 "
<<
E_2
<<
std
::
endl
;
std
::
cout
<<
"E1 "
<<
E_1
<<
std
::
endl
;
std
::
cout
<<
"K "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
N
<<
std
::
endl
;
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int
>
(
index
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int
>
(
index
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"segment_coo_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"segment_coo_kernel"
,
[
&
]
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment