Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c539a7b0
Commit
c539a7b0
authored
Oct 18, 2022
by
charlie
Browse files
Refactor into precomputing dyn output shape
also adding limitations on broadcasting dynamic shapes
parent
5fc6afe6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
125 additions
and
48 deletions
+125
-48
src/common.cpp
src/common.cpp
+70
-41
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+0
-3
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+12
-3
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+43
-1
No files found.
src/common.cpp
View file @
c539a7b0
...
...
@@ -66,33 +66,50 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return
out_lens
;
}
// Handling opt dyn_dims calculation
std
::
vector
<
std
::
size_t
>
compute_broadcasted_opt_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
{
if
(
s0
==
s1
)
return
s0
;
if
(
s0
.
size
()
>
s1
.
size
())
s0
.
swap
(
s1
);
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
return
a
;
}
else
if
((
a
==
1
or
b
==
1
)
and
a
!=
0
and
b
!=
0
)
{
return
std
::
max
(
a
,
b
);
}
else
{
// if not matching nor 1, set to 0
return
static_cast
<
std
::
size_t
>
(
0
);
}
});
return
out_lens
;
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
// change both shapes to dynamic_dimension representation
if
(
not
s0
.
dynamic
())
s0
=
s0
.
to_dynamic
();
if
(
not
s1
.
dynamic
())
s1
=
s1
.
to_dynamic
();
if
(
s0
.
rank
()
>
s1
.
rank
())
{
std
::
swap
(
s0
,
s1
);
}
auto
offset
=
s1
.
rank
()
-
s0
.
rank
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
vector
<
shape
::
dynamic_dimension
>
one_dyn_dims
{{
1
,
1
,
0
},
{
1
,
1
,
1
}};
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cend
(),
s1
.
dyn_dims
().
cbegin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
return
a
;
}
else
if
(
contains
(
one_dyn_dims
,
a
)
or
contains
(
one_dyn_dims
,
b
))
{
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
std
::
max
(
a
.
opt
,
b
.
opt
)};
}
else
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
migraphx
::
to_string_range
(
s0
.
dyn_dims
())
+
"} and {"
+
migraphx
::
to_string_range
(
s1
.
dyn_dims
())
+
"} mismatch!"
);
}
});
}
else
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: given two static shapes"
);
}
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
...
...
@@ -149,24 +166,36 @@ instruction_ref insert_common_op(module& m,
if
(
std
::
any_of
(
inputs
.
cbegin
(),
inputs
.
cend
(),
[](
auto
input
)
{
return
input
->
get_shape
().
dynamic
();
}))
{
// currently only handles the binary case
if
(
inputs
.
size
()
!=
2
)
{
MIGRAPHX_THROW
(
"INSERT_COMMON_OP: not handled; "
+
migraphx
::
to_string
(
inputs
.
size
())
+
"inputs, only handle two inputs"
);
}
auto
c_type
=
compute_common_types
(
to_shapes
(
inputs
));
// broadcast all inputs combinations
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
a_input
)
{
const
auto
&
ori_input
=
a_input
;
// multibroadcast this input between every other input
std
::
for_each
(
inputs
.
cbegin
(),
inputs
.
cend
(),
[
&
](
auto
b_input
)
{
if
(
b_input
!=
ori_input
)
{
a_input
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
),
a_input
,
b_input
);
}
});
if
(
a_input
->
get_shape
().
type
()
!=
c_type
)
auto
c_dyn_dims
=
compute_broadcasted_dyn_dims
(
inputs
[
0
]
->
get_shape
(),
inputs
[
1
]
->
get_shape
());
// following should work for a static or dynamic shape
// TODO: compute_broadcasted_dyn_dims() is going to be called again in the multibroadcast
// compute_shape should figure out a way to get around recomputing that. Attribute in
// multibroadcast?
if
(
inputs
[
0
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
),
inputs
[
0
],
inputs
[
1
]);
}
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
inputs
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
),
inputs
[
1
],
inputs
[
0
]);
}
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
{
a_
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
c_type
}}),
a_
input
);
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
c_type
}}),
input
);
}
return
a_
input
;
return
input
;
});
}
else
...
...
src/include/migraphx/common.hpp
View file @
c539a7b0
...
...
@@ -37,9 +37,6 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
compute_broadcasted_opt_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/include/migraphx/op/binary.hpp
View file @
c539a7b0
...
...
@@ -28,6 +28,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -63,7 +64,15 @@ struct binary : op_name<Derived>
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
2
).
same_type
().
same_dims
();
auto
s0
=
inputs
.
at
(
0
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
==
s1
and
s0
.
packed
())
if
(
s0
.
dynamic
()
and
s1
.
dynamic
()
and
s0
==
s1
)
{
return
s0
;
}
else
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
MIGRAPHX_THROW
(
"BINARY: "
+
point_function
()
+
": fixed-dyn shape for inputs"
);
}
else
if
(
s0
==
s1
and
s0
.
packed
())
{
return
s0
;
}
...
...
@@ -81,9 +90,9 @@ struct binary : op_name<Derived>
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
std
::
transform
(
input1
.
begin
(),
input1
.
end
(),
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
c539a7b0
...
...
@@ -104,10 +104,51 @@ struct multibroadcast
if
(
s0
.
dynamic
()
and
s1
.
dynamic
())
{
// TODO handle both dynamic case
MIGRAPHX_THROW
(
"MULTIBROADCAST_2IN: two dynamic shape inputs not handled."
);
MIGRAPHX_THROW
(
"MULTIBROADCAST_2IN: not handled; two dynamic shape inputs not handled"
);
}
else
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
{
// only handles the case when broadcasting static shape to dynamic shape
// all the dimensions in the static shape must match to a fixed dimension in the
// dynamic shape or be 1
// TODO: handling the other possibilities
if
(
s1
.
dynamic
())
{
std
::
swap
(
s0
,
s1
);
}
auto
static_rank
=
s1
.
lens
().
size
();
auto
dyn_rank
=
s0
.
max_lens
().
size
();
if
(
static_rank
>
dyn_rank
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST_2IN: not handled; static shape has a higher "
"rank than dynamic shape"
);
}
return
s0
;
auto
offset
=
dyn_rank
-
static_rank
;
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s0
.
dyn_dims
());
std
::
transform
(
s0
.
dyn_dims
().
begin
(),
s0
.
dyn_dims
().
end
(),
s1
.
lens
().
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
return
a
;
}
else
if
((
a
==
1
or
b
==
1
)
and
a
!=
0
and
b
!=
0
)
{
return
std
::
max
(
a
,
b
);
}
else
{
// if not matching nor 1, set to 0
return
static_cast
<
std
::
size_t
>
(
0
);
}
});
/*
auto bcast_min_lens = compute_broadcasted_lens(s0.min_lens(), s1.min_lens());
auto bcast_max_lens = compute_broadcasted_lens(s0.max_lens(), s1.max_lens());
auto bcast_opt_lens = compute_broadcasted_opt_lens(s0.opt_lens(), s1.opt_lens());
...
...
@@ -115,6 +156,7 @@ struct multibroadcast
std::move(bcast_min_lens),
std::move(bcast_max_lens),
std::move(bcast_opt_lens)};
*/
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment