Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
db1a954c
Commit
db1a954c
authored
Sep 15, 2022
by
Paul
Browse files
Merge branch 'develop' into fuse-dot-weights
parents
f92195d0
333860ce
Changes
153
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
514 additions
and
161 deletions
+514
-161
src/tf/parse_conv.cpp
src/tf/parse_conv.cpp
+1
-1
src/tf/parse_depthwiseconv.cpp
src/tf/parse_depthwiseconv.cpp
+1
-1
src/tf/parse_pooling.cpp
src/tf/parse_pooling.cpp
+1
-1
src/tf/parse_relu6.cpp
src/tf/parse_relu6.cpp
+3
-2
src/tf/tf_parser.cpp
src/tf/tf_parser.cpp
+2
-2
src/tmp_dir.cpp
src/tmp_dir.cpp
+1
-1
src/value.cpp
src/value.cpp
+4
-4
test/api/test_custom_op_gpu.cpp
test/api/test_custom_op_gpu.cpp
+1
-1
test/check_shapes_test.cpp
test/check_shapes_test.cpp
+1
-1
test/eval_test.cpp
test/eval_test.cpp
+1
-1
test/fpga/get_target_assignments.cpp
test/fpga/get_target_assignments.cpp
+1
-1
test/include/basic_ops.hpp
test/include/basic_ops.hpp
+5
-4
test/include/test.hpp
test/include/test.hpp
+1
-1
test/literal_test.cpp
test/literal_test.cpp
+2
-2
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+0
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+2
-1
test/rewrite_gelu_test.cpp
test/rewrite_gelu_test.cpp
+125
-0
test/shape_test.cpp
test/shape_test.cpp
+2
-2
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+259
-134
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+101
-0
No files found.
src/tf/parse_conv.cpp
View file @
db1a954c
...
...
@@ -100,7 +100,7 @@ struct parse_conv : op_parser<parse_conv>
{
MIGRAPHX_THROW
(
"padding should have 4 values"
);
}
if
(
padding
[
0
]
!=
padding
[
2
]
||
padding
[
1
]
!=
padding
[
3
])
if
(
padding
[
0
]
!=
padding
[
2
]
or
padding
[
1
]
!=
padding
[
3
])
{
MIGRAPHX_THROW
(
"migraphx does not support asymetric padding"
);
}
...
...
src/tf/parse_depthwiseconv.cpp
View file @
db1a954c
...
...
@@ -90,7 +90,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
calculate_padding
(
0
,
pads
,
input_dims
[
2
],
op
.
stride
[
0
],
op
.
dilation
[
0
],
weight_h
);
calculate_padding
(
1
,
pads
,
input_dims
[
3
],
op
.
stride
[
1
],
op
.
dilation
[
1
],
weight_w
);
if
(
pads
[
0
]
!=
pads
[
2
]
||
pads
[
1
]
!=
pads
[
3
])
if
(
pads
[
0
]
!=
pads
[
2
]
or
pads
[
1
]
!=
pads
[
3
])
{
std
::
vector
<
int64_t
>
padding
=
{
0
,
0
,
pads
[
0
],
pads
[
1
],
0
,
0
,
pads
[
2
],
pads
[
3
]};
l0
=
info
.
add_instruction
(
migraphx
::
make_op
(
"pad"
,
{{
"pads"
,
padding
}}),
l0
);
...
...
src/tf/parse_pooling.cpp
View file @
db1a954c
...
...
@@ -42,7 +42,7 @@ struct parse_pooling : op_parser<parse_pooling>
tf_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
if
(
!
starts_with
(
opd
.
tf_name
,
"Max"
)
&&
!
starts_with
(
opd
.
tf_name
,
"Av"
))
if
(
not
starts_with
(
opd
.
tf_name
,
"Max"
)
and
not
starts_with
(
opd
.
tf_name
,
"Av"
))
{
MIGRAPHX_THROW
(
"tf pooling mode must be Max or Average"
);
}
...
...
src/tf/parse_relu6.cpp
View file @
db1a954c
...
...
@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6>
const
tf_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
min_val
=
info
.
add_literal
(
0.0
f
);
auto
max_val
=
info
.
add_literal
(
6.0
f
);
shape
::
type_t
output_type
=
args
[
0
]
->
get_shape
().
type
();
auto
min_val
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
output_type
},
{
0.0
f
}});
auto
max_val
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
output_type
},
{
6.0
f
}});
return
info
.
add_common_op
(
"clip"
,
args
[
0
],
min_val
,
max_val
);
}
...
...
src/tf/tf_parser.cpp
View file @
db1a954c
...
...
@@ -371,7 +371,7 @@ void tf_parser::parse_node(const std::string& name)
{
result
=
ops
[
node
.
op
()](
*
this
,
{
get_attributes
(
node
),
node
.
op
(),
mm
},
args
);
}
assert
(
!
result
.
empty
());
assert
(
not
result
.
empty
());
// First output has no ":" delimiter
instructions
[
name
]
=
result
.
front
();
for
(
size_t
i
=
1
;
i
<
result
.
size
();
i
++
)
...
...
@@ -458,7 +458,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
{
std
::
vector
<
size_t
>
dims
=
parse_dims
(
t
.
tensor_shape
());
size_t
shape_size
=
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
());
if
(
!
t
.
tensor_content
().
empty
())
// has raw data
if
(
not
t
.
tensor_content
().
empty
())
// has raw data
{
const
std
::
string
&
s
=
t
.
tensor_content
();
switch
(
t
.
dtype
())
...
...
src/tmp_dir.cpp
View file @
db1a954c
...
...
@@ -78,7 +78,7 @@ void tmp_dir::execute(const std::string& exe, const std::string& args) const
tmp_dir
::~
tmp_dir
()
{
if
(
!
enabled
(
MIGRAPHX_DEBUG_SAVE_TEMP_DIR
{}))
if
(
not
enabled
(
MIGRAPHX_DEBUG_SAVE_TEMP_DIR
{}))
{
fs
::
remove_all
(
this
->
path
);
}
...
...
src/value.cpp
View file @
db1a954c
...
...
@@ -400,7 +400,7 @@ std::pair<value*, bool> value::insert(const value& v)
{
if
(
v
.
key
.
empty
())
{
if
(
!
x
)
if
(
not
x
)
x
=
std
::
make_shared
<
array_value_holder
>
();
get_array_impl
(
x
).
push_back
(
v
);
assert
(
this
->
if_array
());
...
...
@@ -408,7 +408,7 @@ std::pair<value*, bool> value::insert(const value& v)
}
else
{
if
(
!
x
)
if
(
not
x
)
x
=
std
::
make_shared
<
object_value_holder
>
();
auto
p
=
x
->
if_object
()
->
emplace
(
v
.
key
,
get_array_impl
(
x
).
size
());
if
(
p
.
second
)
...
...
@@ -420,7 +420,7 @@ std::pair<value*, bool> value::insert(const value& v)
value
*
value
::
insert
(
const
value
*
pos
,
const
value
&
v
)
{
assert
(
v
.
key
.
empty
());
if
(
!
x
)
if
(
not
x
)
x
=
std
::
make_shared
<
array_value_holder
>
();
auto
&&
a
=
get_array_impl
(
x
);
auto
it
=
a
.
insert
(
a
.
begin
()
+
(
pos
-
begin
()),
v
);
...
...
@@ -466,7 +466,7 @@ bool compare(const value& x, const value& y, F f)
value
::
type_t
value
::
get_type
()
const
{
if
(
!
x
)
if
(
not
x
)
return
null_type
;
return
x
->
get_type
();
}
...
...
test/api/test_custom_op_gpu.cpp
View file @
db1a954c
...
...
@@ -55,7 +55,7 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
virtual
migraphx
::
shape
compute_shape
(
migraphx
::
shapes
inputs
)
const
override
{
if
(
!
inputs
[
0
].
standard
())
if
(
not
inputs
[
0
].
standard
())
{
throw
std
::
runtime_error
(
"first arg must be standard shaped"
);
}
...
...
test/check_shapes_test.cpp
View file @
db1a954c
...
...
@@ -49,6 +49,6 @@ bool create_shapes(bool dynamic_allowed)
TEST_CASE
(
allow_dynamic_shape
)
{
EXPECT
(
create_shapes
(
true
));
}
TEST_CASE
(
fail_dynamic_shape
)
{
EXPECT
(
!
create_shapes
(
false
));
}
TEST_CASE
(
fail_dynamic_shape
)
{
EXPECT
(
not
create_shapes
(
false
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/eval_test.cpp
View file @
db1a954c
...
...
@@ -187,7 +187,7 @@ TEST_CASE(print_test)
std
::
stringstream
ss
;
ss
<<
p
;
std
::
string
s
=
ss
.
str
();
EXPECT
(
!
s
.
empty
());
EXPECT
(
not
s
.
empty
());
}
TEST_CASE
(
param_test
)
...
...
test/fpga/get_target_assignments.cpp
View file @
db1a954c
...
...
@@ -47,7 +47,7 @@ TEST_CASE(is_supported)
{
auto
p
=
create_program
();
auto
targets
=
migraphx
::
get_targets
();
EXPECT
(
!
targets
.
empty
());
EXPECT
(
not
targets
.
empty
());
auto
t
=
migraphx
::
make_target
(
"fpga"
);
const
auto
assignments
=
p
.
get_target_assignments
({
t
});
...
...
test/include/basic_ops.hpp
View file @
db1a954c
...
...
@@ -112,12 +112,12 @@ struct mod_pass_op
migraphx
::
shape
compute_shape
(
std
::
vector
<
migraphx
::
shape
>
inputs
,
std
::
vector
<
migraphx
::
module_ref
>
mods
)
const
{
if
(
!
mods
.
empty
())
if
(
not
mods
.
empty
())
{
auto
out_shapes
=
mods
[
0
]
->
get_output_shapes
();
return
out_shapes
[
0
];
}
if
(
!
inputs
.
empty
())
if
(
not
inputs
.
empty
())
{
return
inputs
.
front
();
}
...
...
@@ -186,9 +186,10 @@ struct nop
migraphx
::
shape
compute_shape
(
const
std
::
vector
<
migraphx
::
shape
>&
)
const
{
return
{};
}
};
inline
migraphx
::
literal
get_2x2
()
inline
migraphx
::
literal
get_2x2
(
int
base
=
0
)
{
return
migraphx
::
literal
{{
migraphx
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
return
migraphx
::
literal
{{
migraphx
::
shape
::
float_type
,
{
2
,
2
}},
{
base
+
1
,
base
+
2
,
base
+
3
,
base
+
4
}};
}
inline
migraphx
::
literal
get_2x2_transposed
()
...
...
test/include/test.hpp
View file @
db1a954c
...
...
@@ -345,7 +345,7 @@ inline std::ostream& operator<<(std::ostream& os, const color& c)
template
<
class
T
,
class
F
>
void
failed
(
T
x
,
const
char
*
msg
,
const
char
*
func
,
const
char
*
file
,
int
line
,
F
f
)
{
if
(
!
bool
(
x
.
value
()))
if
(
not
bool
(
x
.
value
()))
{
std
::
cout
<<
func
<<
std
::
endl
;
std
::
cout
<<
file
<<
":"
<<
line
<<
":"
<<
std
::
endl
;
...
...
test/literal_test.cpp
View file @
db1a954c
...
...
@@ -39,8 +39,8 @@ TEST_CASE(literal_test)
migraphx
::
literal
l2
=
l1
;
// NOLINT
EXPECT
(
l1
==
l2
);
EXPECT
(
l1
.
at
<
int
>
(
0
)
==
1
);
EXPECT
(
!
l1
.
empty
());
EXPECT
(
!
l2
.
empty
());
EXPECT
(
not
l1
.
empty
());
EXPECT
(
not
l2
.
empty
());
migraphx
::
literal
l3
{};
migraphx
::
literal
l4
{};
...
...
test/onnx/onnx_test.cpp
View file @
db1a954c
...
...
@@ -38,7 +38,6 @@
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
...
...
test/ref_ops_test.cpp
View file @
db1a954c
...
...
@@ -3988,7 +3988,8 @@ TEST_CASE(not_test)
std
::
vector
<
char
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
bool
>
gold
(
data
.
size
());
std
::
transform
(
data
.
begin
(),
data
.
end
(),
gold
.
begin
(),
[](
bool
n
)
->
bool
{
return
!
n
;
});
std
::
transform
(
data
.
begin
(),
data
.
end
(),
gold
.
begin
(),
[](
bool
n
)
->
bool
{
return
not
n
;
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
...
...
test/rewrite_gelu_test.cpp
0 → 100644
View file @
db1a954c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/verify.hpp>
TEST_CASE
(
bias_gelu
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
half_type
,
{
2
,
4
,
8
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
half_type
};
migraphx
::
module
m1
;
{
auto
a
=
m1
.
add_parameter
(
"a"
,
s1
);
auto
b
=
m1
.
add_parameter
(
"b"
,
s1
);
auto
add1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
a
,
b
);
auto
l1
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.4140625
f
}});
auto
div
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"div"
),
{
add1
,
l1
});
auto
erf
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"erf"
),
div
);
auto
l2
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.0
f
}});
auto
add2
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"add"
),
{
erf
,
l2
});
auto
mul
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
add1
,
add2
);
auto
l3
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
0.5
f
}});
mul
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"mul"
),
{
mul
,
l3
});
m1
.
add_return
({
mul
});
}
migraphx
::
rewrite_gelu
pass
;
pass
.
apply
(
m1
);
migraphx
::
dead_code_elimination
dce
;
dce
.
apply
(
m1
);
migraphx
::
module
m2
;
{
auto
a
=
m2
.
add_parameter
(
"a"
,
s1
);
auto
b
=
m2
.
add_parameter
(
"b"
,
s1
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
a
,
b
);
auto
l1
=
m2
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.702
f
}});
auto
mul
=
add_common_op
(
m2
,
migraphx
::
make_op
(
"mul"
),
{
add
,
l1
});
auto
sig
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"neg"
),
mul
);
sig
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"exp"
),
sig
);
auto
l2
=
m2
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.0
f
}});
sig
=
add_common_op
(
m2
,
migraphx
::
make_op
(
"add"
),
{
sig
,
l2
});
sig
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
add
,
sig
);
m2
.
add_return
({
sig
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
non_bias_gelu
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
half_type
,
{
2
,
4
,
8
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
half_type
};
migraphx
::
module
m1
;
{
auto
a
=
m1
.
add_parameter
(
"a"
,
s1
);
auto
b
=
m1
.
add_parameter
(
"b"
,
s1
);
auto
sub
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"sub"
),
a
,
b
);
auto
l1
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.4140625
f
}});
auto
div
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"div"
),
{
sub
,
l1
});
auto
erf
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"erf"
),
div
);
auto
l2
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.0
f
}});
auto
add2
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"add"
),
{
erf
,
l2
});
auto
mul
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sub
,
add2
);
auto
l3
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
0.5
f
}});
mul
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"mul"
),
{
mul
,
l3
});
m1
.
add_return
({
mul
});
}
migraphx
::
rewrite_gelu
pass
;
pass
.
apply
(
m1
);
migraphx
::
dead_code_elimination
dce
;
dce
.
apply
(
m1
);
migraphx
::
module
m2
;
{
auto
a
=
m2
.
add_parameter
(
"a"
,
s1
);
auto
b
=
m2
.
add_parameter
(
"b"
,
s1
);
auto
sub
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"sub"
),
a
,
b
);
auto
l1
=
m2
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.702
f
}});
auto
mul
=
add_common_op
(
m2
,
migraphx
::
make_op
(
"mul"
),
{
sub
,
l1
});
auto
sig
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"neg"
),
mul
);
sig
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"exp"
),
sig
);
auto
l2
=
m2
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.0
f
}});
sig
=
add_common_op
(
m2
,
migraphx
::
make_op
(
"add"
),
{
sig
,
l2
});
sig
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
sub
,
sig
);
m2
.
add_return
({
sig
});
}
EXPECT
(
m1
==
m2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/shape_test.cpp
View file @
db1a954c
...
...
@@ -43,7 +43,7 @@ TEST_CASE(test_shape_assign)
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
migraphx
::
shape
s2
=
s1
;
// NOLINT
EXPECT
(
s1
==
s2
);
EXPECT
(
!
(
s1
!=
s2
));
EXPECT
(
not
(
s1
!=
s2
));
}
TEST_CASE
(
test_shape_packed_default
)
...
...
@@ -325,7 +325,7 @@ TEST_CASE(test_shape_default_copy)
migraphx
::
shape
s1
{};
migraphx
::
shape
s2
{};
EXPECT
(
s1
==
s2
);
EXPECT
(
!
(
s1
!=
s2
));
EXPECT
(
not
(
s1
!=
s2
));
}
TEST_CASE
(
test_shape_normalize_standard1
)
...
...
test/simplify_algebra_test.cpp
View file @
db1a954c
...
...
@@ -30,7 +30,6 @@
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void
run_pass
(
migraphx
::
module
&
m
)
...
...
@@ -358,7 +357,33 @@ TEST_CASE(simplify_mul_add)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast
)
TEST_CASE
(
simplify_dot_add
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}});
auto
one
=
m1
.
add_literal
(
get_2x2
());
auto
two
=
m1
.
add_literal
(
get_2x2
(
1
));
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
one
,
x
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
two
);
m1
.
add_instruction
(
pass_op
{},
dot
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}});
auto
one
=
m2
.
add_literal
(
get_2x2
());
auto
two
=
m2
.
add_literal
(
get_2x2
(
1
));
auto
dot1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x
,
two
);
auto
dot2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
one
,
two
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
dot1
,
dot2
);
m2
.
add_instruction
(
pass_op
{},
sum
);
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast1
)
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
...
...
@@ -383,6 +408,31 @@ TEST_CASE(simplify_inner_broadcast)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast2
)
{
auto
b
=
migraphx
::
op
::
multibroadcast
{{
2
,
1
,
4
,
5
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
b
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1
,
1
,
1
}});
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_add_conv1
)
{
migraphx
::
module
m
;
...
...
@@ -1477,6 +1527,48 @@ TEST_CASE(simplify_dot_horiz_flipped)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
// test if contiguous is added as necessary for reshapes
TEST_CASE
(
simplify_dot_horiz_reshape
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
4
,
4
}};
migraphx
::
module
m1
;
{
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
a
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
x
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
a
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
b
);
auto
x_rsp
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
,
2
,
2
}}}),
x
);
auto
y_rsp
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
x_rsp
,
y_rsp
});
m1
.
add_instruction
(
pass_op
{},
sum
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
a
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
0
));
auto
b
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s
,
1
));
auto
concat
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
2
}}),
a
,
b
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
concat
);
auto
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
4
}}}),
dot
);
auto
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
8
}}}),
dot
);
auto
x_cont
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
x_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
,
2
,
2
}}}),
x_cont
);
auto
y_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
x_rsp
,
y_rsp
});
m2
.
add_instruction
(
pass_op
{},
sum
);
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
simplify_conv_horiz
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
8
,
3
,
64
,
64
}};
...
...
@@ -1782,13 +1874,19 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice
)
template
<
std
::
size_t
BS
,
bool
TransposeInput
>
void
reorder_reshape_slice
()
{
std
::
vector
<
int64_t
>
perm0
=
{
0
,
2
,
1
,
3
};
std
::
vector
<
int64_t
>
perm1
=
{
0
,
2
,
3
,
1
};
auto
create_m1
=
[
&
](
std
::
size_t
batch_size
)
{
migraphx
::
module
m1
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
128
,
1920
}};
migraphx
::
module
m1
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
}};
if
(
TransposeInput
)
{
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
},
{
165120
,
1
,
128
}};
}
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
640
}}}),
input
);
...
...
@@ -1803,7 +1901,7 @@ TEST_CASE(reorder_reshape_slice)
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
batch_size
),
128
,
10
,
64
};
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
128
,
10
,
64
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
...
...
@@ -1815,16 +1913,23 @@ TEST_CASE(reorder_reshape_slice)
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
t2
);
m1
.
add_return
({
ret
});
return
m1
;
};
auto
create_m2
=
[
&
](
std
::
size_t
batch_size
)
{
migraphx
::
module
m2
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
128
,
1920
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
batch_size
),
128
,
30
,
64
};
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
}};
if
(
TransposeInput
)
{
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
},
{
165120
,
1
,
128
}};
}
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
rsp_input
=
input
;
if
(
TransposeInput
)
{
rsp_input
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
{
input
});
}
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
128
,
30
,
64
};
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
rsp_input
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
10
}}}),
r
);
...
...
@@ -1843,27 +1948,25 @@ TEST_CASE(reorder_reshape_slice)
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
t2
);
m2
.
add_return
({
ret
});
return
m2
;
};
auto
test
=
[
&
](
std
::
size_t
batch_size
)
{
auto
m1
=
create_m1
(
batch_size
);
run_pass
(
m1
);
auto
m2
=
create_m2
(
batch_size
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
};
test
(
1
);
test
(
4
);
test
(
8
);
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_move_axis1
)
TEST_CASE_REGISTER
(
reorder_reshape_slice
<
1
,
true
>
);
// test if contiguous is added as necessary if
// input is transposed
TEST_CASE_REGISTER
(
reorder_reshape_slice
<
4
,
true
>
);
TEST_CASE_REGISTER
(
reorder_reshape_slice
<
8
,
true
>
);
TEST_CASE_REGISTER
(
reorder_reshape_slice
<
1
,
false
>
);
TEST_CASE_REGISTER
(
reorder_reshape_slice
<
4
,
false
>
);
TEST_CASE_REGISTER
(
reorder_reshape_slice
<
8
,
false
>
);
template
<
std
::
size_t
BS
>
void
reorder_reshape_slice_move_axis1
()
{
auto
create_m1
=
[](
std
::
size_t
batch_size
)
{
migraphx
::
module
m1
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
256
,
96
}};
migraphx
::
module
m1
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
256
,
96
}};
std
::
vector
<
int64_t
>
perm0
=
{
0
,
2
,
1
,
3
};
std
::
vector
<
int64_t
>
perm1
=
{
0
,
2
,
3
,
1
};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
...
...
@@ -1878,7 +1981,7 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
batch_size
),
64
,
4
,
32
};
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
64
,
4
,
32
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
...
...
@@ -1890,50 +1993,45 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
t2
);
m1
.
add_return
({
ret
});
return
m1
;
};
auto
create_m2
=
[](
std
::
size_t
batch_size
)
{
migraphx
::
module
m
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
256
,
96
}};
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
256
,
96
}};
std
::
vector
<
int64_t
>
perm0
=
{
0
,
2
,
1
,
3
};
std
::
vector
<
int64_t
>
perm1
=
{
0
,
2
,
3
,
1
};
auto
input
=
m
.
add_parameter
(
"input"
,
s
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
batch_size
),
64
,
4
,
96
};
auto
rsp
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
auto
slc0
=
m
.
add_instruction
(
auto
input
=
m
2
.
add_parameter
(
"input"
,
s
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
64
,
4
,
96
};
auto
rsp
=
m
2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
auto
slc0
=
m
2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
32
}}}),
rsp
);
auto
t0
=
m
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm0
}}),
slc0
);
auto
slc1
=
m
.
add_instruction
(
auto
t0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm0
}}),
slc0
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
32
}},
{
"ends"
,
{
64
}}}),
rsp
);
auto
t1
=
m
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm0
}}),
slc1
);
auto
slc2
=
m
.
add_instruction
(
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm0
}}),
slc1
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
64
}},
{
"ends"
,
{
96
}}}),
rsp
);
auto
t2
=
m
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm1
}}),
slc2
);
auto
sum
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
t2
);
m
.
add_return
({
ret
});
return
m
;
};
auto
t2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm1
}}),
slc2
);
auto
test
=
[
&
](
std
::
size_t
batch_size
)
{
auto
m1
=
create_m1
(
batch_size
);
auto
m2
=
create_m2
(
batch_size
);
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
t2
);
m2
.
add_return
({
ret
});
};
test
(
4
);
tes
t
(
8
);
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sor
t
(
)
);
}
TEST_CASE_REGISTER
(
reorder_reshape_slice_move_axis1
<
4
>
);
TEST_CASE_REGISTER
(
reorder_reshape_slice_move_axis1
<
8
>
);
TEST_CASE
(
reorder_reshape_slice_move_axis2
)
{
auto
create_m1
=
[]
{
migraphx
::
module
m1
;
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
...
...
@@ -1955,32 +2053,75 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
r0
,
r1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
r2
);
m1
.
add_return
({
ret
});
return
m1
;
};
auto
create_m2
=
[]
{
migraphx
::
module
m
;
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m
.
add_parameter
(
"input"
,
s
);
auto
input
=
m
2
.
add_parameter
(
"input"
,
s
);
std
::
vector
<
int64_t
>
lens
=
{
1
,
16
,
8
,
96
};
auto
rsp
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
auto
slc0
=
m
.
add_instruction
(
auto
rsp
=
m
2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
auto
slc0
=
m
2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
32
}}}),
rsp
);
auto
slc1
=
m
.
add_instruction
(
auto
slc1
=
m
2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
32
}},
{
"ends"
,
{
64
}}}),
rsp
);
auto
slc2
=
m
.
add_instruction
(
auto
slc2
=
m
2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
64
}},
{
"ends"
,
{
96
}}}),
rsp
);
auto
sum
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
slc0
,
slc1
);
auto
ret
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
slc2
);
m
.
add_return
({
ret
});
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
slc0
,
slc1
);
auto
ret
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
slc2
);
m2
.
add_return
({
ret
});
};
return
m
;
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_len_1
)
{
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
1
,
128
,
3
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
input
);
auto
slc1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
input
);
auto
slc2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
input
);
auto
c0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc0
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
1
,
128
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
r0
,
r1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
r2
);
m1
.
add_return
({
ret
});
};
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
128
,
3
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
std
::
vector
<
int64_t
>
lens
=
{
1
,
384
};
auto
rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
128
}}}),
rsp
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
128
}},
{
"ends"
,
{
256
}}}),
rsp
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
256
}},
{
"ends"
,
{
384
}}}),
rsp
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
slc0
,
slc1
);
auto
ret
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
slc2
);
m2
.
add_return
({
ret
});
};
auto
m1
=
create_m1
();
auto
m2
=
create_m2
();
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
...
...
@@ -2020,15 +2161,14 @@ TEST_CASE(reorder_reshape_slice_not_apply)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_diff_dims
)
template
<
std
::
size_t
BS
>
void
reorder_reshape_slice_diff_dims
()
{
auto
create_m1
=
[](
std
::
size_t
batch_size
)
{
migraphx
::
module
m1
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
96
,
96
}};
std
::
vector
<
int64_t
>
perm0
=
{
0
,
2
,
1
,
3
};
std
::
vector
<
int64_t
>
perm1
=
{
0
,
2
,
3
,
1
};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
module
m1
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
96
,
96
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
32
}}}),
input
);
auto
slc1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
32
}},
{
"ends"
,
{
64
}}}),
input
);
...
...
@@ -2039,34 +2179,31 @@ TEST_CASE(reorder_reshape_slice_diff_dims)
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
batch_size
),
32
,
3
,
32
};
std
::
vector
<
int64_t
>
lens1
=
{
static_cast
<
int64_t
>
(
batch_size
),
48
,
2
,
32
};
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
32
,
3
,
32
};
std
::
vector
<
int64_t
>
lens1
=
{
static_cast
<
int64_t
>
(
BS
),
48
,
2
,
32
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens1
}}),
c2
);
m1
.
add_return
({
r0
,
r1
,
r2
});
return
m1
;
};
auto
test
=
[
&
](
std
::
size_t
batch_size
)
{
auto
m1
=
create_m1
(
batch_size
);
auto
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
};
test
(
4
);
test
(
8
);
auto
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_slice_trans
)
TEST_CASE_REGISTER
(
reorder_reshape_slice_diff_dims
<
4
>
);
TEST_CASE_REGISTER
(
reorder_reshape_slice_diff_dims
<
8
>
);
template
<
std
::
size_t
BS
>
void
reorder_slice_trans
()
{
std
::
vector
<
int64_t
>
perm
=
{
0
,
2
,
1
};
auto
create_m1
=
[
&
](
std
::
size_t
batch_size
)
{
migraphx
::
module
m1
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
128
,
1920
}};
migraphx
::
module
m1
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
640
}}}),
input
);
...
...
@@ -2084,13 +2221,11 @@ TEST_CASE(reorder_slice_trans)
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
t2
);
m1
.
add_return
({
ret
});
return
m1
;
};
auto
create_m2
=
[
&
](
std
::
size_t
batch_size
)
{
migraphx
::
module
m2
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
128
,
1920
}};
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
input
);
...
...
@@ -2104,26 +2239,21 @@ TEST_CASE(reorder_slice_trans)
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
slc0
,
slc1
);
auto
ret
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
slc2
);
m2
.
add_return
({
ret
});
return
m2
;
};
auto
test
=
[
&
](
std
::
size_t
batch_size
)
{
auto
m1
=
create_m1
(
batch_size
);
run_pass
(
m1
);
auto
m2
=
create_m2
(
batch_size
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
};
test
(
1
);
test
(
8
);
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_slice_trans_diff_perm
)
TEST_CASE_REGISTER
(
reorder_slice_trans
<
1
>
);
TEST_CASE_REGISTER
(
reorder_slice_trans
<
8
>
);
template
<
std
::
size_t
BS
>
void
reorder_slice_trans_diff_perm
()
{
auto
create_m1
=
[](
std
::
size_t
batch_size
)
{
migraphx
::
module
m1
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
128
,
1920
}};
migraphx
::
module
m1
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
128
,
1920
}};
std
::
vector
<
int64_t
>
perm0
=
{
0
,
2
,
1
};
std
::
vector
<
int64_t
>
perm1
=
{
0
,
1
,
2
};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
...
...
@@ -2146,21 +2276,16 @@ TEST_CASE(reorder_slice_trans_diff_perm)
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
t0
,
t1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
sum
,
t2
);
m1
.
add_return
({
ret
});
return
m1
;
};
auto
test
=
[
&
](
std
::
size_t
batch_size
)
{
auto
m1
=
create_m1
(
batch_size
);
run_pass
(
m1
);
auto
m2
=
m1
;
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
};
test
(
1
);
test
(
4
);
run_pass
(
m1
);
auto
m2
=
m1
;
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE_REGISTER
(
reorder_slice_trans_diff_perm
<
1
>
);
TEST_CASE_REGISTER
(
reorder_slice_trans_diff_perm
<
4
>
);
TEST_CASE
(
reorder_slice_ins_deps
)
{
auto
create_module
=
[]
{
...
...
test/simplify_reshapes_test.cpp
View file @
db1a954c
...
...
@@ -48,6 +48,26 @@ inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx:
return
result
;
}
migraphx
::
module
make_concat_multibroadcast
(
const
std
::
vector
<
size_t
>&
in_lens
,
const
std
::
vector
<
size_t
>&
mbcast_lens
,
const
int
axis
)
{
migraphx
::
module
m
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
in_lens
};
auto
x
=
m
.
add_parameter
(
"x"
,
s
);
auto
y
=
m
.
add_parameter
(
"y"
,
s
);
auto
z
=
m
.
add_parameter
(
"z"
,
s
);
auto
xm
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
mbcast_lens
}}),
x
);
auto
ym
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
mbcast_lens
}}),
y
);
auto
zm
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
mbcast_lens
}}),
z
);
auto
concat
=
m
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
axis
}}),
xm
,
ym
,
zm
);
m
.
add_return
({
concat
});
return
m
;
}
TEST_CASE
(
double_contig
)
{
migraphx
::
program
p
;
...
...
@@ -337,6 +357,87 @@ TEST_CASE(nop_convert)
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
-
1
);
}
TEST_CASE
(
concat_multibroadcasts1
)
{
// Broadcasted batch dim, new axis < old axis
std
::
vector
<
std
::
size_t
>
in_lens
=
{
3
,
4
};
std
::
vector
<
std
::
size_t
>
mbcast_lens
=
{
2
,
3
,
4
};
const
int
axis
=
2
;
auto
m
=
make_concat_multibroadcast
(
in_lens
,
mbcast_lens
,
axis
);
auto
out_shape
=
m
.
get_output_shapes
().
back
();
auto
n
=
std
::
distance
(
m
.
begin
(),
m
.
end
());
run_pass
(
m
);
EXPECT
(
m
.
get_output_shapes
().
back
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
-
2
);
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
auto
cd
=
std
::
distance
(
m
.
begin
(),
new_concat
);
auto
new_mb
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
}
TEST_CASE
(
concat_multibroadcasts2
)
{
// Broadcasted middle dim, new axis == old axis
std
::
vector
<
std
::
size_t
>
in_lens
=
{
3
,
1
,
4
};
std
::
vector
<
std
::
size_t
>
mbcast_lens
=
{
3
,
2
,
4
};
const
int
axis
=
0
;
auto
m
=
make_concat_multibroadcast
(
in_lens
,
mbcast_lens
,
axis
);
auto
out_shape
=
m
.
get_output_shapes
().
back
();
auto
n
=
std
::
distance
(
m
.
begin
(),
m
.
end
());
run_pass
(
m
);
EXPECT
(
m
.
get_output_shapes
().
back
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
-
2
);
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
auto
cd
=
std
::
distance
(
m
.
begin
(),
new_concat
);
auto
new_mb
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
0
);
}
TEST_CASE
(
concat_multibroadcasts3
)
{
// Broadcasted middle dim, new axis == old axis
std
::
vector
<
std
::
size_t
>
in_lens
=
{
3
,
1
,
4
};
std
::
vector
<
std
::
size_t
>
mbcast_lens
=
{
3
,
2
,
4
};
const
int
axis
=
2
;
auto
m
=
make_concat_multibroadcast
(
in_lens
,
mbcast_lens
,
axis
);
auto
out_shape
=
m
.
get_output_shapes
().
back
();
auto
n
=
std
::
distance
(
m
.
begin
(),
m
.
end
());
run_pass
(
m
);
EXPECT
(
m
.
get_output_shapes
().
back
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
m
.
begin
(),
m
.
end
())
==
n
-
2
);
auto
new_concat
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
m
.
end
()});
auto
cd
=
std
::
distance
(
m
.
begin
(),
new_concat
);
auto
new_mb
=
std
::
find_if
(
m
.
begin
(),
m
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"multibroadcast"
;
});
auto
md
=
std
::
distance
(
m
.
begin
(),
new_mb
);
EXPECT
(
cd
==
md
-
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
2
);
}
TEST_CASE
(
concat_multibroadcasts4
)
{
// Broadcasted batch dim, axis is broadcasted dim
std
::
vector
<
std
::
size_t
>
in_lens
=
{
3
,
4
};
std
::
vector
<
std
::
size_t
>
mbcast_lens
=
{
2
,
3
,
4
};
const
int
axis
=
0
;
auto
m
=
make_concat_multibroadcast
(
in_lens
,
mbcast_lens
,
axis
);
auto
m1
=
m
;
run_pass
(
m
);
EXPECT
(
m1
==
m
);
}
TEST_CASE
(
concat_transpose1
)
{
migraphx
::
module
m
;
...
...
Prev
1
…
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment