Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d37f0677
Commit
d37f0677
authored
Aug 25, 2018
by
Paul
Browse files
Formatting
parent
442581b9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
22 deletions
+30
-22
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+30
-22
No files found.
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
d37f0677
...
@@ -16,9 +16,8 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -16,9 +16,8 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
auto
data
=
auto
data
=
pack
(
pack
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
inputs
.
data
())...);
inputs
.
data
())...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
...
@@ -42,21 +41,29 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
...
@@ -42,21 +41,29 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
b_shape
.
lens
()[
bdim
];
auto
bdim_len
=
b_shape
.
lens
()[
bdim
];
auto
outer_size
=
std
::
accumulate
(
output_shape
.
lens
().
begin
(),
output_shape
.
lens
().
begin
()
+
bdim
+
1
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
outer_size
=
std
::
accumulate
(
output_shape
.
lens
().
begin
(),
auto
inner_size
=
std
::
accumulate
(
output_shape
.
lens
().
begin
()
+
bdim
+
1
,
output_shape
.
lens
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
output_shape
.
lens
().
begin
()
+
bdim
+
1
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
inner_size
=
std
::
accumulate
(
output_shape
.
lens
().
begin
()
+
bdim
+
1
,
output_shape
.
lens
().
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
auto
*
xp
=
input1
.
data
();
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
outer_size
)(
gs_launch
(
outer_size
)([
=
](
auto
i
)
{
[
=
](
auto
i
)
{
auto
*
outp2
=
outp
+
i
;
auto
*
outp2
=
outp
+
i
;
auto
*
xp2
=
xp
+
i
;
auto
*
xp2
=
xp
+
i
;
auto
b
=
yp
[
i
%
bdim_len
];
auto
b
=
yp
[
i
%
bdim_len
];
for
(
std
::
size_t
j
=
0
;
j
<
inner_size
;
j
++
)
for
(
std
::
size_t
j
=
0
;
j
<
inner_size
;
j
++
)
{
{
outp2
[
j
]
=
f
(
xp2
[
j
],
b
);
outp2
[
j
]
=
f
(
xp2
[
j
],
b
);
}
}
...
@@ -86,7 +93,8 @@ auto nary_impl(argument result, Arguments... args)
...
@@ -86,7 +93,8 @@ auto nary_impl(argument result, Arguments... args)
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
if
(
standard
or
(
packed
and
same_shapes
))
if
(
standard
or
(
packed
and
same_shapes
))
nary_standard
(
result
,
args
...)(
f
);
nary_standard
(
result
,
args
...)(
f
);
else
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment