Commit 264a7647 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into multinomial_parse_merge

parents d99729f8 8e18544f
...@@ -45,7 +45,7 @@ namespace migraphx { ...@@ -45,7 +45,7 @@ namespace migraphx {
${preamble} ${preamble}
extern "C" { extern "C" {
__global__ void reduce_kernel(void* input_p, void* output_p) MIGRAPHX_GLOBAL void reduce_kernel(void* input_p, void* output_p)
{ {
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) { transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
......
...@@ -41,7 +41,7 @@ namespace migraphx { ...@@ -41,7 +41,7 @@ namespace migraphx {
extern "C" { extern "C" {
__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y) MIGRAPHX_GLOBAL void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
{ {
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) { make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}), auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
......
...@@ -42,7 +42,7 @@ namespace migraphx { ...@@ -42,7 +42,7 @@ namespace migraphx {
extern "C" { extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output) MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{ {
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) { make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., ${reduction}{}); scatternd(xs..., ${reduction}{});
......
...@@ -45,7 +45,7 @@ static const char* const softmax_kernel = R"__migraphx__( ...@@ -45,7 +45,7 @@ static const char* const softmax_kernel = R"__migraphx__(
namespace migraphx { namespace migraphx {
extern "C" { extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p) MIGRAPHX_GLOBAL void softmax_kernel(void* input_p, void* output_p)
{ {
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) { transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output); softmax<${axis}>(input, output);
......
...@@ -122,12 +122,14 @@ struct source_location_capture ...@@ -122,12 +122,14 @@ struct source_location_capture
{ {
T x; T x;
source_location loc; source_location loc;
template <class U, class = decltype(T(U{}))> // declval is a workaround since default constructor for "U" is not working with rocm-5.6
template <class U>
static U&& declval();
template <class U, class = decltype(T(declval<U>()))>
constexpr source_location_capture(U px, source_location ploc = source_location{}) constexpr source_location_capture(U px, source_location ploc = source_location{})
: x(px), loc(ploc) : x(px), loc(ploc)
{ {
} }
constexpr operator source_location() const { return loc; } constexpr operator source_location() const { return loc; }
constexpr operator T() const { return x; } constexpr operator T() const { return x; }
......
...@@ -106,7 +106,7 @@ struct miopen_apply ...@@ -106,7 +106,7 @@ struct miopen_apply
add_extend_op("topk"); add_extend_op("topk");
add_convolution_op("convolution"); add_convolution_op("convolution");
add_convolution_op("deconvolution"); add_convolution_op("convolution_backwards");
add_convolution_op("quant_convolution"); add_convolution_op("quant_convolution");
add_gemm_op<op::dot>("dot"); add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot"); add_gemm_op<op::quant_dot>("quant_dot");
......
...@@ -389,14 +389,20 @@ struct mlir_program ...@@ -389,14 +389,20 @@ struct mlir_program
mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs) mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs)
{ {
auto attributes = prog->name_attributes(named_attrs); auto attributes = prog->name_attributes(named_attrs);
mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data()); if(not attributes.empty())
{
mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
}
return *this; return *this;
} }
mlir_operation_state& add_attribute_value(const value& v) mlir_operation_state& add_attribute_value(const value& v)
{ {
auto attributes = prog->name_attributes(v); auto attributes = prog->name_attributes(v);
mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data()); if(not attributes.empty())
{
mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
}
return *this; return *this;
} }
...@@ -419,13 +425,19 @@ struct mlir_program ...@@ -419,13 +425,19 @@ struct mlir_program
return shape{r.type(), r.lens()}; return shape{r.type(), r.lens()};
}); });
auto x = prog->make_tensors(reshaped); auto x = prog->make_tensors(reshaped);
mlirOperationStateAddResults(&op_state, x.size(), x.data()); if(not x.empty())
{
mlirOperationStateAddResults(&op_state, x.size(), x.data());
}
return *this; return *this;
} }
mlir_operation_state& add_operands(const std::vector<MlirValue>& inputs) mlir_operation_state& add_operands(const std::vector<MlirValue>& inputs)
{ {
mlirOperationStateAddOperands(&op_state, inputs.size(), inputs.data()); if(not inputs.empty())
{
mlirOperationStateAddOperands(&op_state, inputs.size(), inputs.data());
}
return *this; return *this;
} }
...@@ -435,7 +447,10 @@ struct mlir_program ...@@ -435,7 +447,10 @@ struct mlir_program
std::transform(regions.begin(), regions.end(), mregions.begin(), [](const auto& r) { std::transform(regions.begin(), regions.end(), mregions.begin(), [](const auto& r) {
return r.get(); return r.get();
}); });
mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data()); if(not mregions.empty())
{
mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data());
}
mlir_operation op(mlirOperationCreate(&op_state)); mlir_operation op(mlirOperationCreate(&op_state));
// Release memory since mlir_operation owns it // Release memory since mlir_operation owns it
for(auto& r : regions) for(auto& r : regions)
...@@ -607,12 +622,12 @@ struct mlir_program ...@@ -607,12 +622,12 @@ struct mlir_program
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
// 1st pipeline to call // 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm_front.get()); mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRun(pm_front.get(), mmodule.get()); mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()));
// 2nd pipeline to call // 2nd pipeline to call
get_module_tuned(); get_module_tuned();
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str()); mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRun(pm_back.get(), mmodule.get()); mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get()));
code_object_op op{}; code_object_op op{};
op.symbol_name = sym_name; op.symbol_name = sym_name;
...@@ -701,6 +716,11 @@ struct mlir_program ...@@ -701,6 +716,11 @@ struct mlir_program
bool get_module_tuned() const bool get_module_tuned() const
{ {
static mlir_tuning_table tuning_table = create_tuning_table(); static mlir_tuning_table tuning_table = create_tuning_table();
// The tuning table as currently implemented is currently not
// thread safe. This will be fixed in the future. For now,
// stick a mutex around all tuning table interaction.
static std::mutex lock;
std::lock_guard<std::mutex> guard(lock);
if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get())) if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get()))
{ {
const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get()); const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get());
...@@ -778,9 +798,6 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct ...@@ -778,9 +798,6 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
{ {
adjust_param_shapes(m, inputs); adjust_param_shapes(m, inputs);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
// set mutex while llvm thread support is disabled.
static std::mutex g_mlirc_mutex; // NOLINT
const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
if(trace) if(trace)
std::cout << m << std::endl; std::cout << m << std::endl;
......
...@@ -55,9 +55,16 @@ bool get_compute_fp32_flag() ...@@ -55,9 +55,16 @@ bool get_compute_fp32_flag()
bool get_int8_x4_format(context& ctx) bool get_int8_x4_format(context& ctx)
{ {
#if ROCBLAS_VERSION_MAJOR >= 3
(void)(ctx);
return false;
#else
// int8x4 packed format is only available starting from rocblas-v2.38 and it is deprecated in
// v3.0 and will be removed in v4.0
rocblas_gemm_flags flag; rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag); rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
return flag == rocblas_gemm_flags_pack_int8x4; return flag == rocblas_gemm_flags_pack_int8x4;
#endif
} }
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -75,7 +75,9 @@ namespace gpu { ...@@ -75,7 +75,9 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
#ifdef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
struct id_pass struct id_pass
{ {
...@@ -136,7 +138,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -136,7 +138,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{}, dead_code_elimination{},
#ifdef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif
dead_code_elimination{}, dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}), enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -37,6 +37,8 @@ target_link_libraries(migraphx_ref PUBLIC migraphx) ...@@ -37,6 +37,8 @@ target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE}) target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS) target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
migraphx_generate_export_header(migraphx_ref)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_ref TARGETS migraphx_ref
INCLUDE INCLUDE
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP #define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/ref/export.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -24,14 +24,14 @@ ...@@ -24,14 +24,14 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP #define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#include <migraphx/ref/context.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace ref { namespace ref {
struct lowering struct MIGRAPHX_REF_EXPORT lowering
{ {
std::string name() const { return "ref::lowering"; } std::string name() const { return "ref::lowering"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct pass; struct pass;
namespace ref { namespace ref {
struct target struct MIGRAPHX_REF_EXPORT target
{ {
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const; std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp> #include <migraphx/op/convolution_backwards.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
......
...@@ -42,8 +42,9 @@ target_compile_options(tf-proto PRIVATE -w) ...@@ -42,8 +42,9 @@ target_compile_options(tf-proto PRIVATE -w)
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB TF_SRCS ${CONFIGURE_DEPENDS} *.cpp) file(GLOB TF_SRCS CONFIGURE_DEPENDS *.cpp)
add_library(migraphx_tf ${TF_SRCS}) add_library(migraphx_tf ${TF_SRCS})
migraphx_generate_export_header(migraphx_tf)
target_include_directories(migraphx_tf PRIVATE include) target_include_directories(migraphx_tf PRIVATE include)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf) set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
......
...@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_type = args[0]->get_shape().type(); auto x_type = args[0]->get_shape().type();
// unsqueeze tensors of shape (C) to broadcast correctly // unsqueeze tensors of shape (C) to broadcast correctly
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = auto scale_unsqueeze =
...@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto var_unsqueeze = auto var_unsqueeze =
info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]); info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze); auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps); auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt); auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom); auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze); auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze); return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
} }
}; };
......
...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name, ...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name,
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error; double error;
passed = verify_range(ref, target, tolerance, &error); passed = verify::verify_range(ref, target, tolerance, &error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
...@@ -45,27 +45,27 @@ bool verify_args(const std::string& name, ...@@ -45,27 +45,27 @@ bool verify_args(const std::string& name,
std::cout << "ref:" << ref << std::endl; std::cout << "ref:" << ref << std::endl;
if(target.size() < 32) if(target.size() < 32)
std::cout << "target:" << target << std::endl; std::cout << "target:" << target << std::endl;
if(range_zero(ref)) if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl; std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target)) if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl; std::cout << "Target data is all zeros" << std::endl;
auto mxdiff = max_diff(ref, target); auto mxdiff = verify::max_diff(ref, target);
std::cout << "Max diff: " << mxdiff << std::endl; std::cout << "Max diff: " << mxdiff << std::endl;
auto idx = mismatch_idx(ref, target, float_equal); auto idx = verify::mismatch_idx(ref, target, float_equal);
if(idx < range_distance(ref)) if(idx < verify::range_distance(ref))
{ {
std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
<< std::endl; << std::endl;
} }
auto ref_nan_idx = find_idx(ref, not_finite); auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0) if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl; << ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite); auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0) if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": " std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl; << target[target_nan_idx] << std::endl;
...@@ -73,27 +73,27 @@ bool verify_args(const std::string& name, ...@@ -73,27 +73,27 @@ bool verify_args(const std::string& name,
} }
else else
{ {
if(range_zero(ref)) if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl; std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target)) if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl; std::cout << "Target data is all zeros" << std::endl;
// auto mxdiff = max_diff(ref, target); // auto mxdiff = max_diff(ref, target);
// std::cout << "Max diff: " << mxdiff << std::endl; // std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(ref, target, float_equal); // auto idx = mismatch_idx(ref, target, float_equal);
// if(idx < range_distance(ref)) // if(idx < verify::range_distance(ref))
// { // {
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] // std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// << std::endl; // << std::endl;
// } // }
auto ref_nan_idx = find_idx(ref, not_finite); auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0) if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl; << ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite); auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0) if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": " std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl; << target[target_nan_idx] << std::endl;
......
...@@ -112,7 +112,7 @@ function(add_test_executable TEST_NAME) ...@@ -112,7 +112,7 @@ function(add_test_executable TEST_NAME)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
file(GLOB TESTS ${CONFIGURE_DEPENDS} *.cpp) file(GLOB TESTS CONFIGURE_DEPENDS *.cpp)
foreach(TEST ${TESTS}) foreach(TEST ${TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -122,7 +122,7 @@ endforeach() ...@@ -122,7 +122,7 @@ endforeach()
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
# gpu tests # gpu tests
file(GLOB GPU_TESTS ${CONFIGURE_DEPENDS} gpu/*.cpp) file(GLOB GPU_TESTS CONFIGURE_DEPENDS gpu/*.cpp)
foreach(TEST ${GPU_TESTS}) foreach(TEST ${GPU_TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -141,7 +141,7 @@ endif() ...@@ -141,7 +141,7 @@ endif()
if(MIGRAPHX_ENABLE_FPGA) if(MIGRAPHX_ENABLE_FPGA)
# fpga tests # fpga tests
file(GLOB FPGA_TESTS ${CONFIGURE_DEPENDS} fpga/*.cpp) file(GLOB FPGA_TESTS CONFIGURE_DEPENDS fpga/*.cpp)
foreach(TEST ${FPGA_TESTS}) foreach(TEST ${FPGA_TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -191,7 +191,7 @@ endif() ...@@ -191,7 +191,7 @@ endif()
# multitarget test # multitarget test
if(MIGRAPHX_ENABLE_GPU AND MIGRAPHX_ENABLE_CPU AND MIGRAPHX_ENABLE_FPGA) if(MIGRAPHX_ENABLE_GPU AND MIGRAPHX_ENABLE_CPU AND MIGRAPHX_ENABLE_FPGA)
set(TEST_MULTI_TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR}/multi_target) set(TEST_MULTI_TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR}/multi_target)
file(GLOB MULTI_TARGET_TESTS ${CONFIGURE_DEPENDS} ${TEST_MULTI_TARGET_DIR}/*.cpp) file(GLOB MULTI_TARGET_TESTS CONFIGURE_DEPENDS ${TEST_MULTI_TARGET_DIR}/*.cpp)
foreach(MULTI_TARGET_TEST ${MULTI_TARGET_TESTS}) foreach(MULTI_TARGET_TEST ${MULTI_TARGET_TESTS})
get_filename_component(BASE_NAME ${MULTI_TARGET_TEST} NAME_WE) get_filename_component(BASE_NAME ${MULTI_TARGET_TEST} NAME_WE)
...@@ -221,14 +221,14 @@ function(test_header NAME HEADER) ...@@ -221,14 +221,14 @@ function(test_header NAME HEADER)
endfunction() endfunction()
function(test_headers PREFIX) function(test_headers PREFIX)
file(GLOB HEADERS ${CONFIGURE_DEPENDS} ${ARGN}) file(GLOB HEADERS CONFIGURE_DEPENDS ${ARGN})
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
get_filename_component(BASE_NAME ${HEADER} NAME_WE) get_filename_component(BASE_NAME ${HEADER} NAME_WE)
test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp) test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp)
target_link_libraries(header_${TEST_NAME} migraphx_all_targets) target_link_libraries(header_${TEST_NAME} migraphx migraphx_onnx migraphx_tf migraphx_all_targets)
endforeach() endforeach()
endfunction() endfunction()
......
...@@ -36,7 +36,7 @@ endfunction() ...@@ -36,7 +36,7 @@ endfunction()
function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR) function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME}) set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
target_link_libraries(${NAME} migraphx_c migraphx) target_link_libraries(${NAME} migraphx_c)
target_include_directories(${NAME} PUBLIC ../include) target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
......
...@@ -30,7 +30,7 @@ void expect_equal(const char* x, const char* y) ...@@ -30,7 +30,7 @@ void expect_equal(const char* x, const char* y)
abort(); abort();
} }
int main() int main(void)
{ {
char name[1024]; char name[1024];
migraphx_operation_t op; migraphx_operation_t op;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment