Commit 3a4d36cf authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f e19f78ae
......@@ -21,43 +21,26 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_CONVERT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONVERT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/convert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_convert
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
{
op::convert op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::convert"; }
shape compute_shape(std::vector<shape> inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
migraphx::program create_program() const
{
return shapes.size() - 1;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}),
l2);
mm->add_instruction(migraphx::make_op("dot"), l1, l2);
return p;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -560,7 +560,7 @@ lifetime get_lifetime_op(const T&)
inline bool operator!=(const operation& x, const operation& y)
{
return !(x == y);
return not(x == y);
}
inline value
......
......@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -64,12 +66,12 @@ struct target
*/
context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
* @brief Get the ranges of instructions that are supported on a target
* @param module Module to check for supported instructions
* @param metric Used to define how the quality of the support should be measured
* @return the supported segments of the graph
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
/**
* @brief copy an argument to the current target.
*
......@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
}
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
supported_segments target_find_supported(T&, const_module_ref, support_metric)
{
return 0;
return {};
}
<%
......@@ -125,7 +127,7 @@ interface('target',
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'),
virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
......
......@@ -23,7 +23,9 @@
#####################################################################################
import string, sys, re
trivial = ['std::size_t', 'instruction_ref', 'support_metric']
trivial = [
'std::size_t', 'instruction_ref', 'support_metric', 'const_module_ref'
]
headers = '''
#include <algorithm>
......@@ -134,7 +136,7 @@ private:
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type (PrivateDetailTypeErasedT value,
typename std::enable_if<
!std::is_reference<PrivateDetailTypeErasedU>::value,
not std::is_reference<PrivateDetailTypeErasedU>::value,
int
>::type * = nullptr) noexcept :
private_detail_te_value (std::move(value))
......@@ -176,7 +178,7 @@ private:
private_detail_te_handle_base_type & private_detail_te_get_handle ()
{
assert(private_detail_te_handle_mem_var != nullptr);
if (!private_detail_te_handle_mem_var.unique())
if (not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment