Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0c4d99b6
Commit
0c4d99b6
authored
Nov 20, 2023
by
charlie
Browse files
More progress
parent
9280150b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
209 additions
and
61 deletions
+209
-61
src/common.cpp
src/common.cpp
+8
-14
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+2
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+14
-5
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+37
-1
src/split_single_dyn_dim.cpp
src/split_single_dyn_dim.cpp
+81
-40
test/split_single_dyn_dim_test.cpp
test/split_single_dyn_dim_test.cpp
+67
-0
No files found.
src/common.cpp
View file @
0c4d99b6
...
@@ -169,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
...
@@ -169,24 +169,18 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
auto
c_dyn_dims
=
compute_common_dyn_dims
(
input_shapes
);
auto
c_dyn_dims
=
compute_common_dyn_dims
(
input_shapes
);
auto
s0
=
inputs
[
0
]
->
get_shape
();
auto
s0
=
inputs
[
0
]
->
get_shape
();
if
(
not
s0
.
dynamic
()
or
s0
.
dyn_dims
()
!=
c_dyn_dims
)
// changed to always add the multibroadcast to handle the cases from split_single_dyn_dim
{
inputs
[
0
]
=
m
.
insert_instruction
(
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
);
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
);
}
std
::
transform
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
inputs
.
begin
()
+
1
,
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
()
+
1
,
inputs
.
end
(),
inputs
.
begin
()
+
1
,
[
&
](
auto
input
)
{
// uses previous input to avoid recalculating the common shape from the
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
// full set of input shapes at runtime
auto
s
=
input
->
get_shape
();
auto
s
=
input
->
get_shape
();
if
(
not
s
.
dynamic
()
or
s
.
dyn_dims
()
!=
c_dyn_dims
)
return
m
.
insert_instruction
(
{
ins
,
return
m
.
insert_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
ins
,
input
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
0
]);
input
,
inputs
[
0
]);
}
return
input
;
});
});
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
if
(
input
->
get_shape
().
type
()
!=
c_type
)
...
...
src/eliminate_data_type.cpp
View file @
0c4d99b6
...
@@ -41,7 +41,8 @@ void eliminate_data_type::apply(module& m) const
...
@@ -41,7 +41,8 @@ void eliminate_data_type::apply(module& m) const
"nonmaxsuppression"
,
"nonmaxsuppression"
,
"scatternd_add"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_mul"
,
"scatternd_none"
};
"scatternd_none"
,
"select_module"
};
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()[
0
]
==
'@'
)
if
(
ins
->
name
()[
0
]
==
'@'
)
...
...
src/include/migraphx/op/dot.hpp
View file @
0c4d99b6
...
@@ -34,6 +34,9 @@ namespace migraphx {
...
@@ -34,6 +34,9 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Matrix multiplication of two tensors.
*/
struct
dot
struct
dot
{
{
std
::
string
name
()
const
{
return
"dot"
;
}
std
::
string
name
()
const
{
return
"dot"
;
}
...
@@ -50,22 +53,26 @@ struct dot
...
@@ -50,22 +53,26 @@ struct dot
}
}
if
(
a
.
dynamic
()
or
b
.
dynamic
())
if
(
a
.
dynamic
()
or
b
.
dynamic
())
{
{
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
auto
s0
=
a
.
to_dynamic
();
auto
s0
=
a
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
()))
s1
.
dyn_dims
().
rend
(),
[
&
](
auto
x
,
auto
y
)
{
return
(
dd_within_range
(
x
,
y
)
or
dd_within_range
(
y
,
x
));
}))
{
{
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch: {"
+
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
auto
dd_within_range
=
[
&
](
shape
::
dynamic_dimension
x
,
shape
::
dynamic_dimension
y
)
{
return
(
x
.
min
>=
y
.
min
and
x
.
max
<=
y
.
max
);
};
auto
x
=
s0
.
dyn_dims
()[
dim_1
];
auto
x
=
s0
.
dyn_dims
()[
dim_1
];
auto
y
=
s1
.
dyn_dims
()[
dim_0
];
auto
y
=
s1
.
dyn_dims
()[
dim_0
];
if
(
not
dd_within_range
(
x
,
y
)
and
not
dd_within_range
(
y
,
x
))
if
(
not
dd_within_range
(
x
,
y
)
and
not
dd_within_range
(
y
,
x
))
...
@@ -74,6 +81,8 @@ struct dot
...
@@ -74,6 +81,8 @@ struct dot
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
}
// NOTE could make this compute_shape more precise by using outer dimensions of the
// shape that's dd_within_range. currently this just uses the outer dimensions of s0.
auto
out_dyn_dims
=
s0
.
dyn_dims
();
auto
out_dyn_dims
=
s0
.
dyn_dims
();
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
return
{
t
,
out_dyn_dims
};
return
{
t
,
out_dyn_dims
};
...
...
src/simplify_dyn_ops.cpp
View file @
0c4d99b6
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -318,6 +319,40 @@ struct find_const_alloc_fill
...
@@ -318,6 +319,40 @@ struct find_const_alloc_fill
}
}
};
};
/**
* Simplify dot_broadcast instructions with two static shaped arguments
* From:
* dot_broadcast(static_shape_arg, static_shape_arg)
* To:
* multibroadcast(static_shape_arg); output_lens = static_dot_broadcasted_shape
*/
struct
find_static_dot_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"dot_broadcast"
)(
match
::
arg
(
0
)(
match
::
static_shape
()),
match
::
arg
(
1
)(
match
::
static_shape
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
dot_broadcast_ins
=
mr
.
result
;
auto
inputs
=
dot_broadcast_ins
->
inputs
();
auto
s0
=
inputs
.
at
(
0
)
->
get_shape
();
auto
s1
=
inputs
.
at
(
1
)
->
get_shape
();
auto
l0_it
=
s0
.
lens
().
begin
()
+
s0
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
s0
.
lens
().
begin
(),
l0_it
);
auto
l1_it
=
s1
.
lens
().
begin
()
+
s1
.
ndim
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
s1
.
lens
().
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
output_lens
.
insert
(
output_lens
.
end
(),
l0_it
,
s0
.
lens
().
end
());
m
.
replace_instruction
(
dot_broadcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
output_lens
}}),
inputs
.
at
(
0
));
}
};
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
void
simplify_dyn_ops
::
apply
(
module
&
m
)
const
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
...
@@ -327,7 +362,8 @@ void simplify_dyn_ops::apply(module& m) const
...
@@ -327,7 +362,8 @@ void simplify_dyn_ops::apply(module& m) const
find_const_2in_slice
{},
find_const_2in_slice
{},
find_const_3in_slice
{},
find_const_3in_slice
{},
find_const_4in_slice
{},
find_const_4in_slice
{},
find_const_alloc_fill
{});
find_const_alloc_fill
{},
find_static_dot_broadcast
{});
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/split_single_dyn_dim.cpp
View file @
0c4d99b6
...
@@ -36,36 +36,79 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -36,36 +36,79 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
dynamic_dimensions_check
struct
dynamic_dimensions_check
{
{
std
::
string
dyn_param_str
;
std
::
string
dyn_param_str
;
size_t
dyn_index
;
shape
::
dynamic_dimension
dd
;
size_t
min_dim
;
size_t
max_dim
;
};
};
optional
<
dynamic_dimensions_check
>
/**
has_one_dyn_dim
(
const
std
::
unordered_map
<
std
::
string
,
shape
>&
param_shapes
)
* Returns value if the parameters contain one non-fixed dynamic_dimension that is the same between
* all of the dynamic shape parameters.
* Returns the parameters and the dynamic dimension in a vector of dynamic_dimensions_check objects.
*/
optional
<
std
::
vector
<
dynamic_dimensions_check
>>
has_one_unique_dyn_dim
(
const
std
::
unordered_map
<
std
::
string
,
shape
>&
param_shapes
)
{
{
// True if parameters contain exactly one dynamic shape with exactly one non-fixed
if
(
param_shapes
.
empty
())
// dynamic_dimension.
{
return
std
::
nullopt
;
}
auto
is_dynamic
=
[](
const
auto
&
p
)
{
return
p
.
second
.
dynamic
();
};
auto
is_dynamic
=
[](
const
auto
&
p
)
{
return
p
.
second
.
dynamic
();
};
auto
ps_it
=
std
::
find_if
(
param_shapes
.
begin
(),
param_shapes
.
end
(),
is_dynamic
);
std
::
vector
<
std
::
decay_t
<
decltype
(
param_shapes
)
>::
value_type
>
dyn_params
{};
if
(
ps_it
==
param_shapes
.
end
())
std
::
copy_if
(
param_shapes
.
begin
(),
param_shapes
.
end
(),
std
::
back_inserter
(
dyn_params
),
is_dynamic
);
if
(
dyn_params
.
empty
())
return
std
::
nullopt
;
return
std
::
nullopt
;
// Check if there is a second dynamic parameter
std
::
vector
<
dynamic_dimensions_check
>
ret
{};
if
(
std
::
any_of
(
std
::
next
(
ps_it
),
param_shapes
.
end
(),
is_dynamic
))
// get non-fixed dynamic_dimension from all parameters
for
(
const
auto
&
param
:
dyn_params
)
{
const
auto
&
dds
=
param
.
second
.
dyn_dims
();
int
num_non_fixed
=
0
;
for
(
auto
dds_it
=
dds
.
begin
();
dds_it
!=
dds
.
end
();
++
dds_it
)
{
if
(
not
dds_it
->
is_fixed
())
{
num_non_fixed
+=
1
;
// catch more than one non-fixed dynamic_dimension
if
(
num_non_fixed
>
1
)
{
return
std
::
nullopt
;
}
ret
.
push_back
(
dynamic_dimensions_check
{
param
.
first
,
*
dds_it
});
}
}
}
if
(
ret
.
empty
())
{
return
std
::
nullopt
;
return
std
::
nullopt
;
const
auto
&
dds
=
ps_it
->
second
.
dyn_dims
();
}
// check all the same dynamic_dimension
bool
same_dd
=
std
::
all_of
(
ret
.
begin
()
+
1
,
ret
.
end
(),
[
&
](
auto
ddc
)
{
return
ddc
.
dd
==
ret
.
at
(
0
).
dd
;
});
if
(
same_dd
)
{
return
ret
;
}
return
std
::
nullopt
;
}
auto
is_non_fixed
=
[](
const
auto
&
dd
)
{
return
not
dd
.
is_fixed
();
};
/**
auto
dds_it
=
std
::
find_if
(
dds
.
begin
(),
dds
.
end
(),
is_non_fixed
);
* Check the parameters in std::vector<dynamic_dimensions_check> object to see if any of the
if
(
dds_it
==
dds
.
end
())
* parameters outputs to a select_module operator.
return
std
::
nullopt
;
*/
// Check if there is a second non-fixed dynamic_dimension
bool
any_sm_next
(
module_ref
mm
,
const
std
::
vector
<
dynamic_dimensions_check
>&
ddcs
)
if
(
std
::
any_of
(
std
::
next
(
dds_it
),
dds
.
end
(),
is_non_fixed
))
{
return
std
::
nullopt
;
for
(
const
auto
&
ddc
:
ddcs
)
return
dynamic_dimensions_check
{
ps_it
->
first
,
{
static_cast
<
std
::
size_t
>
(
std
::
distance
(
dds
.
begin
(),
dds_it
)),
auto
p_outputs
=
mm
->
get_parameter
(
ddc
.
dyn_param_str
)
->
outputs
();
dds_it
->
min
,
bool
is_sm_next
=
std
::
any_of
(
p_outputs
.
cbegin
(),
p_outputs
.
cend
(),
[](
auto
ins
)
{
dds_it
->
max
};
return
ins
->
name
()
==
"select_module"
;
});
if
(
is_sm_next
)
{
return
true
;
};
}
return
false
;
}
}
/**
/**
...
@@ -79,29 +122,27 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
...
@@ -79,29 +122,27 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
module_ref
mm
=
&
mpm
.
get_module
();
module_ref
mm
=
&
mpm
.
get_module
();
auto
param_names
=
mm
->
get_parameter_names
();
auto
param_names
=
mm
->
get_parameter_names
();
auto
param_shapes
=
mm
->
get_parameter_shapes
();
auto
param_shapes
=
mm
->
get_parameter_shapes
();
optional
<
dynamic_dimensions_check
>
dd_check
=
has_one_dyn_dim
(
param_shapes
);
optional
<
std
::
vector
<
dynamic_dimensions_check
>>
dd_check_vec
=
auto
any_sm_next
=
[
&
](
auto
ddc
)
{
has_one_unique_dyn_dim
(
param_shapes
);
auto
p_outputs
=
mm
->
get_parameter
(
ddc
->
dyn_param_str
)
->
outputs
();
if
(
dd_check_vec
.
has_value
()
and
not
any_sm_next
(
mm
,
dd_check_vec
.
value
()))
return
std
::
any_of
(
p_outputs
.
cbegin
(),
p_outputs
.
cend
(),
[](
auto
ins
)
{
return
ins
->
name
()
==
"select_module"
;
});
};
if
(
dd_check
.
has_value
()
and
not
any_sm_next
(
dd_check
))
{
{
const
auto
&
dyn_param
=
mm
->
get_parameter
(
dd_check
->
dyn_param_str
);
// all dynamic dimension objects should be the same for all parameters in dd_check_vec
auto
dyn_param_shape
=
mm
->
get_parameter_shape
(
dd_check
->
dyn_param_str
);
auto
dyn_dim
=
dd_check_vec
->
at
(
0
).
dd
;
std
::
vector
<
module_ref
>
submodules
;
// create submodules for each dimension size
// create submodules for each dimension size
for
(
size_t
dim_size
:
migraphx
::
range
(
dd_check
->
min_dim
,
dd_check
->
max_dim
+
1
))
std
::
vector
<
module_ref
>
submodules
;
for
(
size_t
dim_size
:
migraphx
::
range
(
dyn_dim
.
min
,
dyn_dim
.
max
+
1
))
{
{
auto
*
submod
=
mpm
.
create_module
(
"dim_"
+
std
::
to_string
(
dim_size
));
auto
*
submod
=
mpm
.
create_module
(
"dim_"
+
std
::
to_string
(
dim_size
));
// instruction map for new static shaped submodule parameters
// instruction map for new static shaped submodule parameters
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// create static shape using dim_size
for
(
const
auto
&
dd_check
:
dd_check_vec
.
value
())
auto
static_lens
=
dyn_param_shape
.
max_lens
();
{
static_lens
.
at
(
dd_check
->
dyn_index
)
=
dim_size
;
// create static shape using dim_size
map_ins
[
dyn_param
]
=
submod
->
add_parameter
(
const
auto
&
dyn_param
=
mm
->
get_parameter
(
dd_check
.
dyn_param_str
);
dd_check
->
dyn_param_str
,
migraphx
::
shape
{
dyn_param_shape
.
type
(),
static_lens
});
auto
dyn_param_shape
=
mm
->
get_parameter_shape
(
dd_check
.
dyn_param_str
);
auto
static_shape
=
dyn_param_shape
.
to_static
(
dim_size
);
map_ins
[
dyn_param
]
=
submod
->
add_parameter
(
dd_check
.
dyn_param_str
,
static_shape
);
}
auto
outputs
=
submod
->
add_instructions
(
mm
,
map_ins
);
auto
outputs
=
submod
->
add_instructions
(
mm
,
map_ins
);
submod
->
add_return
({
outputs
});
submod
->
add_return
({
outputs
});
submodules
.
push_back
(
submod
);
submodules
.
push_back
(
submod
);
...
...
test/split_single_dyn_dim_test.cpp
View file @
0c4d99b6
...
@@ -94,6 +94,73 @@ TEST_CASE(dynamic_batch)
...
@@ -94,6 +94,73 @@ TEST_CASE(dynamic_batch)
EXPECT
(
p0
==
p1
);
EXPECT
(
p0
==
p1
);
}
}
TEST_CASE
(
dynamic_batch_multiple_input
)
{
migraphx
::
program
p0
;
{
auto
*
mm0
=
p0
.
get_main_module
();
// create batch submodules
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
const
std
::
string
&
module_name
)
{
auto
*
submod
=
p0
.
create_module
(
module_name
);
migraphx
::
shape
sm_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
4
}};
auto
sm_input0
=
submod
->
add_parameter
(
"data0"
,
sm_shape
);
auto
sm_input1
=
submod
->
add_parameter
(
"data1"
,
sm_shape
);
auto
sm_input2
=
submod
->
add_parameter
(
"data2"
,
sm_shape
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}}};
auto
literal_ins
=
submod
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
literal_ins
,
sm_input0
);
auto
add_ins0
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
sm_input0
,
broadcast_lit
);
auto
add_ins1
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add_ins0
,
sm_input1
);
auto
add_ins2
=
submod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add_ins1
,
sm_input2
);
submod
->
add_return
({
add_ins2
});
return
submod
;
};
auto
*
dim1
=
create_submodule
(
1
,
"dim_1"
);
auto
*
dim2
=
create_submodule
(
2
,
"dim_2"
);
auto
*
dim3
=
create_submodule
(
3
,
"dim_3"
);
auto
*
dim4
=
create_submodule
(
4
,
"dim_4"
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input0
=
mm0
->
add_parameter
(
"data0"
,
s
);
auto
input1
=
mm0
->
add_parameter
(
"data1"
,
s
);
auto
input2
=
mm0
->
add_parameter
(
"data2"
,
s
);
std
::
vector
<
migraphx
::
shape
>
sub_shapes
=
{};
sub_shapes
.
push_back
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}});
migraphx
::
shape
out_attr
=
migraphx
::
shape
{
sub_shapes
};
auto
sm_ins
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"select_module"
,
{{
"output_dyn_shapes"
,
migraphx
::
to_value
(
out_attr
)}}),
{
input0
,
input1
,
input2
},
{
dim1
,
dim2
,
dim3
,
dim4
});
auto
ret
=
mm0
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
sm_ins
);
mm0
->
add_return
({
ret
});
}
migraphx
::
program
p1
;
{
auto
*
mm1
=
p1
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
4
}}};
auto
input0
=
mm1
->
add_parameter
(
"data0"
,
s
);
auto
input1
=
mm1
->
add_parameter
(
"data1"
,
s
);
auto
input2
=
mm1
->
add_parameter
(
"data2"
,
s
);
migraphx
::
shape
lit_s
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}}};
auto
literal_ins
=
mm1
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
auto
broadcast_lit
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
),
literal_ins
,
input0
);
auto
add_ins0
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
input0
,
broadcast_lit
);
auto
add_ins1
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add_ins0
,
input1
);
auto
add_ins2
=
mm1
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add_ins1
,
input2
);
mm1
->
add_return
({
add_ins2
});
}
run_pass
(
p1
);
EXPECT
(
p0
==
p1
);
}
TEST_CASE
(
multiple_outputs
)
TEST_CASE
(
multiple_outputs
)
{
{
migraphx
::
program
p0
;
migraphx
::
program
p0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment