Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
442581b9
Commit
442581b9
authored
Aug 25, 2018
by
Paul
Browse files
Try to optimize broadcast add
parent
6ce07611
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
6 deletions
+72
-6
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+57
-6
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+15
-0
No files found.
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
442581b9
...
@@ -16,8 +16,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -16,8 +16,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
auto
data
=
pack
(
auto
data
=
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
inputs
.
data
())...);
pack
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
inputs
.
data
())...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
...
@@ -36,6 +37,34 @@ auto nary_nonstandard(argument result, Arguments... args)
...
@@ -36,6 +37,34 @@ auto nary_nonstandard(argument result, Arguments... args)
return
[
=
](
auto
f
)
{
return
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
return
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
}
}
// Fast path for a binary op where arg2 is broadcast along exactly one
// dimension (a single non-zero stride). Returns a callable taking the
// binary functor f; arg1 and result are assumed to share the (standard)
// output shape -- the #if 0'd caller below checks those preconditions.
inline auto binary_broadcast(argument result, argument arg1, argument arg2)
{
    return [=](auto f) {
        const auto& output_shape = result.get_shape();
        const auto& b_shape      = arg2.get_shape();
        // bdim = index of the single dimension arg2 actually varies over
        // (the first -- and by precondition only -- non-zero stride).
        auto bdim = std::distance(b_shape.strides().begin(),
                                  std::find_if(b_shape.strides().begin(),
                                               b_shape.strides().end(),
                                               [](auto x) { return x != 0; }));
        auto bdim_len = b_shape.lens()[bdim];
        // outer_size = product of output dims [0, bdim]; one launch item per
        // outer index, each covering inner_size contiguous elements.
        auto outer_size = std::accumulate(output_shape.lens().begin(),
                                          output_shape.lens().begin() + bdim + 1,
                                          std::size_t{1},
                                          std::multiplies<>{});
        // inner_size = product of output dims (bdim, end): the contiguous run
        // over which the broadcast value b is constant.
        auto inner_size = std::accumulate(output_shape.lens().begin() + bdim + 1,
                                          output_shape.lens().end(),
                                          std::size_t{1},
                                          std::multiplies<>{});
        visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
            auto* xp   = input1.data();
            auto* yp   = input2.data();
            auto* outp = output.data();
            gs_launch(outer_size)([=](auto i) {
                // BUG FIX: outer index i addresses a block of inner_size
                // elements, so the flat offset is i * inner_size. The
                // original used `outp + i` / `xp + i`, which makes adjacent
                // launch items overlap whenever inner_size > 1.
                auto* outp2 = outp + i * inner_size;
                auto* xp2   = xp + i * inner_size;
                // i runs over dims [0, bdim] flattened, so i % bdim_len is
                // the coordinate along the broadcast dimension.
                auto b = yp[i % bdim_len];
                for(std::size_t j = 0; j < inner_size; j++)
                {
                    outp2[j] = f(xp2[j], b);
                }
            });
        });
    };
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
{
{
...
@@ -52,13 +81,12 @@ auto nary_standard(argument result, Arguments... args)
...
@@ -52,13 +81,12 @@ auto nary_standard(argument result, Arguments... args)
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
argument
result
,
Arguments
...
args
)
auto
nary
_impl
(
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
same_shapes
=
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
if
(
standard
or
(
packed
and
same_shapes
))
if
(
standard
or
(
packed
and
same_shapes
))
nary_standard
(
result
,
args
...)(
f
);
nary_standard
(
result
,
args
...)(
f
);
else
else
...
@@ -67,6 +95,29 @@ auto nary(argument result, Arguments... args)
...
@@ -67,6 +95,29 @@ auto nary(argument result, Arguments... args)
};
};
}
}
// Public entry point for elementwise n-ary device operations.
// Forwards to nary_impl, which picks a kernel based on the shapes
// (standard/packed layouts vs. the generic non-standard path).
template <class... Arguments>
auto nary(argument result, Arguments... args)
{
    auto dispatch = nary_impl(result, args...);
    return dispatch;
}
// Disabled experiment: route a standard+broadcast binary op through the
// specialized binary_broadcast kernel instead of the generic nary_impl.
// Kept under #if 0 -- presumably the optimization was not ready when this
// was committed (commit message: "Try to optimize broadcast add").
#if 0
inline auto nary(argument result, argument arg1, argument arg2)
{
    return [=](auto f) {
        // TODO: Check for one broadcast stride
        // TODO: Check result and arg1 shape is the same
        // Fast path requires: arg1 in standard layout, arg2 broadcasted, and
        // exactly one non-zero stride in arg2 (single broadcast dimension).
        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and std::count_if(arg2.get_shape().strides().begin(), arg2.get_shape().strides().end(), [](auto x) { return x != 0; }) == 1)
        {
            binary_broadcast(result, arg1, arg2)(f);
        }
        else
        {
            // Fall back to the general dispatcher.
            nary_impl(result, arg1, arg2)(f);
        }
    };
}
#endif
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
...
...
test/gpu/miopen.cpp
View file @
442581b9
...
@@ -210,6 +210,20 @@ struct test_add_broadcast
...
@@ -210,6 +210,20 @@ struct test_add_broadcast
}
}
};
};
// Verifies add with a mid-axis broadcast: y of shape {3} is broadcast along
// axis 1 of x {2, 3, 4}. Exercises the single-broadcast-stride case targeted
// by the binary_broadcast optimization.
struct test_add_broadcast2
{
    migraph::program create_program() const
    {
        migraph::program p;
        migraph::shape s{migraph::shape::float_type, {3}};
        auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 4}});
        // Fix: s was declared but unused; it is the intended shape of y.
        auto y  = p.add_parameter("y", s);
        auto by = p.add_instruction(migraph::broadcast{1}, x, y);
        p.add_instruction(migraph::add{}, x, by);
        return p;
    }
};
struct
test_conv_relu
struct
test_conv_relu
{
{
migraph
::
program
create_program
()
const
migraph
::
program
create_program
()
const
...
@@ -418,6 +432,7 @@ int main()
...
@@ -418,6 +432,7 @@ int main()
{
{
verify_program
<
test_add
>
();
verify_program
<
test_add
>
();
verify_program
<
test_add_broadcast
>
();
verify_program
<
test_add_broadcast
>
();
verify_program
<
test_add_broadcast2
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_add_relu
>
();
verify_program
<
test_add_relu
>
();
verify_program
<
test_conv_pooling
>
();
verify_program
<
test_conv_pooling
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment