Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
3994f3ab
Commit
3994f3ab
authored
Jan 12, 2020
by
rusty1s
Browse files
faster segment coo cpu implementation
parent
4a5379c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
19 deletions
+27
-19
cpu/segment.cpp
cpu/segment.cpp
+27
-19
No files found.
cpu/segment.cpp
View file @
3994f3ab
...
@@ -180,6 +180,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -180,6 +180,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
}
auto
E
=
index
.
numel
();
auto
E_1
=
index
.
numel
()
/
src
.
size
(
reduce_dim
);
auto
E_1
=
index
.
numel
()
/
src
.
size
(
reduce_dim
);
auto
E_2
=
src
.
size
(
reduce_dim
);
auto
E_2
=
src
.
size
(
reduce_dim
);
auto
K
=
src
.
numel
()
/
index
.
numel
();
auto
K
=
src
.
numel
()
/
index
.
numel
();
...
@@ -191,41 +192,48 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -191,41 +192,48 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
val
;
scalar_t
val
s
[
K
]
;
int64_t
idx
,
next_idx
,
row_start
,
arg
;
int64_t
idx
,
next_idx
,
row_start
,
arg
s
[
K
]
;
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
for
(
int
e_1
=
0
;
e_1
<
E_1
;
e_1
++
)
{
for
(
int
e_1
=
0
;
e_1
<
E_1
;
e_1
++
)
{
int
offset
=
IndexToOffset
<
int64_t
>::
get
(
e_1
*
E_2
,
index_info
);
int
offset
=
IndexToOffset
<
int64_t
>::
get
(
e_1
*
E_2
,
index_info
);
idx
=
index_info
.
data
[
offset
];
row_start
=
0
;
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
idx
=
index_info
.
data
[
offset
];
vals
[
k
]
=
out_data
[
e_1
*
N
*
K
+
k
];
row_start
=
0
;
}
val
=
out_data
[
e_1
*
N
*
K
+
k
];
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
update
(
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
src_data
[
e_1
*
E_2
*
K
+
e_2
*
K
+
k
],
&
arg
,
e_2
);
&
vals
[
k
],
src_data
[
e_1
*
E_2
*
K
+
e_2
*
K
+
k
],
&
args
[
k
],
e_2
);
}
if
(
e_2
==
E_2
-
1
)
{
if
(
e_2
==
E_2
-
1
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
write
(
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
val
,
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
val
s
[
k
]
,
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
arg
,
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
arg
s
[
k
]
,
e_2
+
1
-
row_start
);
e_2
+
1
-
row_start
);
}
else
{
}
next_idx
=
index_info
.
data
[
offset
+
(
e_2
+
1
)
*
stride
];
}
else
{
next_idx
=
index_info
.
data
[
offset
+
(
e_2
+
1
)
*
stride
];
if
(
idx
!=
next_idx
)
{
if
(
idx
!=
next_idx
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
Reducer
<
scalar_t
,
REDUCE
>::
write
(
Reducer
<
scalar_t
,
REDUCE
>::
write
(
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
val
,
out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
val
s
[
k
]
,
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
arg
,
arg_out_data
+
e_1
*
N
*
K
+
idx
*
K
+
k
,
arg
s
[
k
]
,
e_2
+
1
-
row_start
);
e_2
+
1
-
row_start
);
row_start
=
e_2
+
1
;
vals
[
k
]
=
out_data
[
e_1
*
N
*
K
+
next_idx
*
K
+
k
];
val
=
out_data
[
e_1
*
N
*
K
+
next_idx
*
K
+
k
];
}
}
row_start
=
e_2
+
1
;
idx
=
next_idx
;
}
}
idx
=
next_idx
;
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment