Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1d8840b8
Unverified
Commit
1d8840b8
authored
Sep 07, 2023
by
Paul Fultz II
Committed by
GitHub
Sep 07, 2023
Browse files
Fuse pointwise modules across reshapes (#1940)
parent
a0894c2a
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
480 additions
and
2 deletions
+480
-2
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/common_dims.cpp
src/common_dims.cpp
+156
-0
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+53
-0
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+49
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+2
-0
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+1
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-1
test/common_dims.cpp
test/common_dims.cpp
+65
-0
test/fuse_pointwise.cpp
test/fuse_pointwise.cpp
+152
-1
No files found.
src/CMakeLists.txt
View file @
1d8840b8
...
@@ -36,6 +36,7 @@ add_library(migraphx
...
@@ -36,6 +36,7 @@ add_library(migraphx
argument.cpp
argument.cpp
auto_contiguous.cpp
auto_contiguous.cpp
common.cpp
common.cpp
common_dims.cpp
compile_src.cpp
compile_src.cpp
convert_to_json.cpp
convert_to_json.cpp
cpp_generator.cpp
cpp_generator.cpp
...
...
src/common_dims.cpp
0 → 100644
View file @
1d8840b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>
dim
;
});
if
(
x
<
dim
)
return
start
;
return
it
;
}
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
{
return
std
::
accumulate
(
r
.
begin
(),
r
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
struct
common_dim_state
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{
}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
size_t
rem
=
1
;
std
::
size_t
get
()
const
{
return
*
it
/
rem
;
}
bool
is_end
()
const
{
return
it
==
dims
->
end
();
}
void
next
(
std
::
size_t
i
=
1
)
{
it
+=
i
;
}
auto
dims_for
(
std
::
size_t
d
)
const
{
auto
dim_end
=
compute_end_dim
(
it
,
dims
->
end
(),
d
);
return
range
(
it
,
dim_end
);
}
void
add_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
axes_map
->
push_back
(
std
::
move
(
axes
));
}
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
MIGRAPHX_TIDY_CONST
{
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
naxes
++
;
start
--
;
}
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
start
);
return
axes
;
}
};
static
bool
compute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
auto
dims
=
state1
.
dims_for
(
d2
);
auto
n
=
elements
(
dims
);
auto
naxes
=
distance
(
dims
);
if
(
naxes
==
0
)
return
false
;
// If not divisible then we can't compute a common dim
if
((
d2
%
n
)
!=
0
)
return
false
;
auto
rem
=
d2
/
n
;
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
==
1
?
naxes
:
naxes
+
1
,
cd_dims
.
size
());
state1
.
rem
=
rem
;
state2
.
rem
=
1
;
cd_dims
.
insert
(
cd_dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
state1
.
rem
!=
1
)
cd_dims
.
push_back
(
state1
.
rem
);
state1
.
next
(
distance
(
dims
));
state2
.
next
();
return
true
;
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
assert
(
elements
(
dims1
)
>
0
);
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
common_dim_state
state1
{
dims1
,
cd
.
axes_map1
};
common_dim_state
state2
{
dims2
,
cd
.
axes_map2
};
while
(
not
state1
.
is_end
()
and
not
state2
.
is_end
())
{
auto
d1
=
state1
.
get
();
auto
d2
=
state2
.
get
();
if
(
d1
<=
d2
)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
}
else
// if(d1 > d2)
{
if
(
not
compute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
}
}
assert
(
elements
(
dims1
)
==
elements
(
cd
.
dims
));
return
cd
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/fuse_pointwise.cpp
View file @
1d8840b8
...
@@ -24,11 +24,14 @@
...
@@ -24,11 +24,14 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common_dims.hpp>
#include <iterator>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
)
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
...
@@ -189,6 +192,54 @@ static bool find_pointwise_modules(module& m)
}
}
return
changed
;
return
changed
;
}
}
namespace
{
struct
find_pointwise_reshape_pointwise
{
auto
matcher
()
const
{
auto
reshape
=
match
::
name
(
"reshape"
,
"squeeze"
,
"unsqueeze"
,
"flatten"
)(
match
::
used_once
());
auto
skip_contiguous
=
[](
auto
...
ms
)
{
return
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
)(
match
::
used_once
()))(
ms
...));
};
auto
pointwise
=
match
::
name
(
"pointwise"
)(
match
::
used_once
());
auto
reshape_pointwise
=
reshape
(
skip_contiguous
(
pointwise
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise
));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
cd
=
common_dims
::
compute
(
ins
->
get_shape
().
lens
(),
x_ins
->
get_shape
().
lens
());
if
(
cd
.
dims
.
empty
())
return
;
auto
reshape_input
=
[
&
](
const
auto
&
ins_to_insert
)
{
return
[
&
](
auto
input
)
{
auto
c
=
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
c
);
};
};
auto
x_inputs
=
x_ins
->
inputs
();
std
::
transform
(
x_inputs
.
begin
(),
x_inputs
.
end
(),
x_inputs
.
begin
(),
reshape_input
(
x_ins
));
auto
new_x_ins
=
m
.
insert_instruction
(
x_ins
,
x_ins
->
get_operator
(),
x_inputs
,
x_ins
->
module_inputs
());
auto
inputs
=
ins
->
inputs
();
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
==
reshape_ins
)
return
new_x_ins
;
return
reshape_input
(
ins
)(
input
);
});
auto
pw
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
,
ins
->
module_inputs
());
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
ins
->
get_shape
().
lens
()}}),
pw
);
}
};
}
// namespace
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_pointwise
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
...
@@ -200,6 +251,8 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
}
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
match
::
find_matches
(
mpm
.
get_module
(),
find_pointwise_reshape_pointwise
{});
mpm
.
run_pass
(
simplify_reshapes
{
1
});
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
if
(
not
find_pointwise_modules
(
mpm
.
get_module
()))
break
;
break
;
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/include/migraphx/common_dims.hpp
0 → 100644
View file @
1d8840b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/// This will compute a higher dimensional space that will preserve the axes
/// for both sets of dimensions. Two axes_maps are provided for each of the
/// dims that will map the axis to the axes that are used by the result of
/// common_dims.
struct
MIGRAPHX_EXPORT
common_dims
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/matcher.hpp
View file @
1d8840b8
...
@@ -623,6 +623,8 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
...
@@ -623,6 +623,8 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
skip
(
Ms
...
ms
)
auto
skip
(
Ms
...
ms
)
{
{
static_assert
(((
not
std
::
is_convertible
<
Ms
,
std
::
string
>
{})
and
...),
"Use a matcher not a string for skip."
);
auto
m
=
any_of
(
ms
...);
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
optional
<
instruction_ref
>>
(
return
fix
<
optional
<
instruction_ref
>>
(
...
...
src/include/migraphx/simplify_reshapes.hpp
View file @
1d8840b8
...
@@ -38,6 +38,7 @@ struct module;
...
@@ -38,6 +38,7 @@ struct module;
*/
*/
struct
MIGRAPHX_EXPORT
simplify_reshapes
struct
MIGRAPHX_EXPORT
simplify_reshapes
{
{
size_t
depth
=
4
;
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
...
...
src/simplify_reshapes.cpp
View file @
1d8840b8
...
@@ -784,7 +784,7 @@ struct find_transpose_slice
...
@@ -784,7 +784,7 @@ struct find_transpose_slice
void
simplify_reshapes
::
apply
(
module
&
m
)
const
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
for
(
int
i
=
0
;
i
<
depth
;
i
++
)
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_where_op
{},
find_where_op
{},
...
...
test/common_dims.cpp
0 → 100644
View file @
1d8840b8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <test.hpp>
using
axes_map
=
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
;
TEST_CASE
(
common_d1_less
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
2
,
32
,
40
,
8
},
{
2
,
1280
,
8
});
EXPECT
(
cd
.
dims
==
std
::
vector
<
std
::
size_t
>
{
2
,
32
,
40
,
8
});
EXPECT
(
cd
.
axes_map1
==
axes_map
{{
0
},
{
1
},
{
2
},
{
3
}});
EXPECT
(
cd
.
axes_map2
==
axes_map
{{
0
},
{
1
,
2
},
{
3
}});
}
TEST_CASE
(
common1
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
2
,
32
,
2560
},
{
2
,
1280
,
8
,
8
});
EXPECT
(
cd
.
dims
==
std
::
vector
<
std
::
size_t
>
{
2
,
32
,
40
,
8
,
8
});
EXPECT
(
cd
.
axes_map1
==
axes_map
{{
0
},
{
1
},
{
2
,
3
,
4
}});
EXPECT
(
cd
.
axes_map2
==
axes_map
{{
0
},
{
1
,
2
},
{
3
},
{
4
}});
}
TEST_CASE
(
common2
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
2
,
1280
,
8
,
8
},
{
2
,
32
,
2560
});
EXPECT
(
cd
.
dims
==
std
::
vector
<
std
::
size_t
>
{
2
,
32
,
40
,
8
,
8
});
EXPECT
(
cd
.
axes_map1
==
axes_map
{{
0
},
{
1
,
2
},
{
3
},
{
4
}});
EXPECT
(
cd
.
axes_map2
==
axes_map
{{
0
},
{
1
},
{
2
,
3
,
4
}});
}
TEST_CASE
(
common_error1
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
6
,
35
},
{
3
,
7
,
2
,
5
});
EXPECT
(
cd
.
dims
.
empty
());
}
TEST_CASE
(
common_error2
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
3
,
7
,
2
,
5
},
{
6
,
35
});
EXPECT
(
cd
.
dims
.
empty
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fuse_pointwise.cpp
View file @
1d8840b8
...
@@ -21,8 +21,9 @@
...
@@ -21,8 +21,9 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
...
@@ -361,4 +362,154 @@ TEST_CASE(no_input)
...
@@ -361,4 +362,154 @@ TEST_CASE(no_input)
EXPECT
(
p
==
p2
);
EXPECT
(
p
==
p2
);
}
}
TEST_CASE
(
add_reshape_add
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
3
,
10
,
16
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
40
,
2
,
2
}};
migraphx
::
shape
s3
{
migraphx
::
shape
::
float_type
,
{
3
,
10
,
4
,
2
,
2
}};
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
add1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
add1
);
auto
add2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reshape
,
z
);
mm
->
add_return
({
add2
});
}
run_pass
(
p1
);
migraphx
::
program
p2
;
{
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
x2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
x
);
auto
y2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
y
);
auto
z2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
z
);
auto
fadd
=
add_pointwise
(
p2
,
"main:pointwise0"
,
{
x2
,
y2
,
z2
},
[
=
](
auto
*
pm
,
const
auto
&
inputs
)
{
auto
add1
=
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
inputs
[
0
],
inputs
[
1
]);
return
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add1
,
inputs
[
2
]);
});
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
fadd
);
mm
->
add_return
({
reshape
});
}
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
TEST_CASE
(
add_reshape_add_nonstandard
)
{
migraphx
::
shape
s1
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
3
,
10
,
16
},
{
2
,
0
,
1
});
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
40
,
2
,
2
}};
migraphx
::
shape
s3
{
migraphx
::
shape
::
float_type
,
{
3
,
10
,
4
,
2
,
2
}};
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
add1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
c
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
add1
);
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
c
);
auto
add2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reshape
,
z
);
mm
->
add_return
({
add2
});
}
run_pass
(
p1
);
migraphx
::
program
p2
;
{
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
cx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
cy
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
x2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
cx
);
auto
y2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
cy
);
auto
z2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
z
);
auto
fadd
=
add_pointwise
(
p2
,
"main:pointwise0"
,
{
x2
,
y2
,
z2
},
[
=
](
auto
*
pm
,
const
auto
&
inputs
)
{
auto
add1
=
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
inputs
[
0
],
inputs
[
1
]);
return
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add1
,
inputs
[
2
]);
});
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
fadd
);
mm
->
add_return
({
reshape
});
}
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
TEST_CASE
(
add_unsqueeze_add_nonstandard
)
{
migraphx
::
shape
s1
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
3
,
10
,
16
},
{
2
,
0
,
1
});
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
10
,
1
,
16
}};
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
add1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
unsqueeze
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
add1
);
auto
add2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
unsqueeze
,
z
);
mm
->
add_return
({
add2
});
}
run_pass
(
p1
);
migraphx
::
program
p2
;
{
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
cx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
cy
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
x2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
cx
);
auto
y2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
cy
);
auto
fadd
=
add_pointwise
(
p2
,
"main:pointwise0"
,
{
x2
,
y2
,
z
},
[
=
](
auto
*
pm
,
const
auto
&
inputs
)
{
auto
add1
=
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
inputs
[
0
],
inputs
[
1
]);
return
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add1
,
inputs
[
2
]);
});
mm
->
add_return
({
fadd
});
}
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
TEST_CASE
(
add_reshape_add_error
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
6
,
35
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
7
,
2
,
5
}};
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
add1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
add1
);
auto
add2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reshape
,
z
);
mm
->
add_return
({
add2
});
}
run_pass
(
p1
);
migraphx
::
program
p2
;
{
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
fadd1
=
add_pointwise
(
p2
,
"main:pointwise0"
,
{
x
,
y
},
single_pointwise
(
"add"
));
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
fadd1
);
auto
fadd2
=
add_pointwise
(
p2
,
"main:pointwise1"
,
{
reshape
,
z
},
single_pointwise
(
"add"
));
mm
->
add_return
({
fadd2
});
}
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment