Commit c4cb5022 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-improve' into jit-layernorm

parents 6ac586a9 a0edd061
...@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast) ...@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_unsqueeze_concat)
{
migraphx::module m1;
{
auto l0 = m1.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l1 = m1.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
auto l2 = m1.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 3;
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
m1.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), unsqueezed_args);
}
// TODO: This could be simplified to a single transpose after concat
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice) TEST_CASE(transpose_slice)
{ {
migraphx::module m1; migraphx::module m1;
......
...@@ -37,6 +37,8 @@ ...@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,6 +63,13 @@ struct target ...@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg) ...@@ -105,11 +114,18 @@ argument copy_from_target(T&, const argument& arg)
return arg; return arg;
} }
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True), virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
input = 'const argument&', input = 'const argument&',
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
##################################################################################### #####################################################################################
import string, sys, re import string, sys, re
trivial = ['std::size_t', 'instruction_ref'] trivial = ['std::size_t', 'instruction_ref', 'support_metric']
headers = ''' headers = '''
#include <algorithm> #include <algorithm>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment