Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c4cb5022
Commit
c4cb5022
authored
Jul 11, 2022
by
Paul
Browse files
Merge branch 'jit-improve' into jit-layernorm
parents
6ac586a9
a0edd061
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
1 deletion
+49
-1
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+32
-0
tools/include/target.hpp
tools/include/target.hpp
+16
-0
tools/te.py
tools/te.py
+1
-1
No files found.
test/simplify_reshapes_test.cpp
View file @
c4cb5022
...
@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
...
@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
transpose_unsqueeze_concat
)
{
migraphx
::
module
m1
;
{
auto
l0
=
m1
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l0
);
auto
l1
=
m1
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l1
);
auto
l2
=
m1
.
add_parameter
(
"2"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
l2
);
std
::
vector
<
migraphx
::
instruction_ref
>
args
{
lt0
,
lt1
,
lt2
};
std
::
vector
<
migraphx
::
instruction_ref
>
unsqueezed_args
;
int64_t
axis
=
3
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
unsqueezed_args
),
[
&
](
migraphx
::
instruction_ref
arg
)
{
return
m1
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
axis
}}}),
arg
);
});
m1
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
axis
}}),
unsqueezed_args
);
}
// TODO: This could be simplified to a single transpose after concat
migraphx
::
module
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
transpose_slice
)
TEST_CASE
(
transpose_slice
)
{
{
migraphx
::
module
m1
;
migraphx
::
module
m1
;
...
...
tools/include/target.hpp
View file @
c4cb5022
...
@@ -37,6 +37,8 @@
...
@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -61,6 +63,13 @@ struct target
...
@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution.
* @return The context to be used during compilation and execution.
*/
*/
context
get_context
()
const
;
context
get_context
()
const
;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float
is_supported
(
T
&
,
instruction_ref
ins
,
support_metric
m
)
const
;
/**
/**
* @brief copy an argument to the current target.
* @brief copy an argument to the current target.
*
*
...
@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg)
...
@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg)
return
arg
;
return
arg
;
}
}
template
<
class
T
>
float
target_is_supported
(
T
&
,
instruction_ref
,
support_metric
)
{
return
0
;
}
<%
<%
interface
(
'
target
'
,
interface
(
'
target
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
get_passes
'
,
ctx
=
'
context
&
'
,
options
=
'
const
compile_options
&
'
,
returns
=
'
std
::
vector
<
pass
>
'
,
const
=
True
),
virtual
(
'
get_passes
'
,
ctx
=
'
context
&
'
,
options
=
'
const
compile_options
&
'
,
returns
=
'
std
::
vector
<
pass
>
'
,
const
=
True
),
virtual
(
'
get_context
'
,
returns
=
'
context
'
,
const
=
True
),
virtual
(
'
get_context
'
,
returns
=
'
context
'
,
const
=
True
),
virtual
(
'
is_supported
'
,
returns
=
'
float
'
,
ins
=
'
instruction_ref
'
,
m
=
'
support_metric
'
,
const
=
True
,
default
=
'
target_is_supported
'
),
virtual
(
'
copy_to
'
,
virtual
(
'
copy_to
'
,
returns
=
'
argument
'
,
returns
=
'
argument
'
,
input
=
'
const
argument
&
'
,
input
=
'
const
argument
&
'
,
...
...
tools/te.py
View file @
c4cb5022
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#####################################################################################
#####################################################################################
import
string
,
sys
,
re
import
string
,
sys
,
re
trivial
=
[
'std::size_t'
,
'instruction_ref'
]
trivial
=
[
'std::size_t'
,
'instruction_ref'
,
'support_metric'
]
headers
=
'''
headers
=
'''
#include <algorithm>
#include <algorithm>
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment