"vscode:/vscode.git/clone" did not exist on "b0c2e607b8bab54ee3e23f2a285e50cbe5ba7101"
Commit 3f322644 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-gsg

parents 53aee707 09aaa63e
...@@ -75,12 +75,19 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr ...@@ -75,12 +75,19 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr
using miopen_problem = MIGRAPHX_MANAGE_PTR(miopenProblem_t, miopenDestroyProblem); using miopen_problem = MIGRAPHX_MANAGE_PTR(miopenProblem_t, miopenDestroyProblem);
using miopen_solution = MIGRAPHX_MANAGE_PTR(miopenSolution_t, miopenDestroySolution); using miopen_solution = MIGRAPHX_MANAGE_PTR(miopenSolution_t, miopenDestroySolution);
inline miopen_solution find_solution(miopenHandle_t handle, miopenProblem_t problem) inline miopen_solution
find_solution(miopenHandle_t handle, miopenProblem_t problem, bool tune = false)
{ {
miopenSolution_t solution; miopenSolution_t solution;
size_t found = 0; size_t found = 0;
auto status = miopenFindSolutions(handle, problem, nullptr, &solution, &found, 1); miopen_find_options fo = nullptr;
auto result = miopen_solution{solution}; if(tune)
{
fo = make_obj<miopen_find_options>(&miopenCreateFindOptions);
miopenSetFindOptionTuning(fo.get(), 1);
}
auto status = miopenFindSolutions(handle, problem, fo.get(), &solution, &found, 1);
auto result = miopen_solution{solution};
if(status != miopenStatusSuccess or found == 0) if(status != miopenStatusSuccess or found == 0)
MIGRAPHX_THROW("MIOpen miopenFindSolutions failed"); MIGRAPHX_THROW("MIOpen miopenFindSolutions failed");
return result; return result;
......
...@@ -56,7 +56,6 @@ struct oper ...@@ -56,7 +56,6 @@ struct oper
return name.substr(pos_ns + 2); return name.substr(pos_ns + 2);
} }
} }
return "unknown_operator_name"; return "unknown_operator_name";
} }
}; };
......
...@@ -37,7 +37,6 @@ struct target ...@@ -37,7 +37,6 @@ struct target
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options& options) const; std::vector<pass> get_passes(migraphx::context& gctx, const compile_options& options) const;
migraphx::context get_context() const; migraphx::context get_context() const;
argument copy_to(const argument& arg) const; argument copy_to(const argument& arg) const;
argument copy_from(const argument& arg) const; argument copy_from(const argument& arg) const;
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
......
...@@ -127,7 +127,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -127,7 +127,7 @@ struct reduce_compiler : compiler<reduce_compiler>
vec = vectorize::elements(ctx, faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
if(relements > block_size * 256) if(relements >= block_size * 256)
algo = "block_large"; algo = "block_large";
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
......
...@@ -25,10 +25,6 @@ ...@@ -25,10 +25,6 @@
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP #define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC #ifndef MIGRAPHX_USE_HIPRTC
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/math_functions.h> #include <hip/math_functions.h>
......
...@@ -48,12 +48,20 @@ __device__ void generic_binary_layernorm( ...@@ -48,12 +48,20 @@ __device__ void generic_binary_layernorm(
{ {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>; using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
block::template run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = vec_type<value_type>{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);
auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) { auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements}; auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing higher values
// before overflow in low precision
auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input); })(input);
auto mean_x = means[0]; auto mean_x = means[0];
......
...@@ -410,9 +410,9 @@ struct block_large ...@@ -410,9 +410,9 @@ struct block_large
}; };
template <class Size, class F> template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f) static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{ {
return {f}; return {{}, {f}};
} }
template <class Op, class T, class Read, class N, class... Ts> template <class Op, class T, class Read, class N, class... Ts>
...@@ -483,9 +483,9 @@ struct lane ...@@ -483,9 +483,9 @@ struct lane
}; };
template <class Size, class F> template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f) static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{ {
return {f}; return {{}, {f}};
} }
template <class Op, class T, class Read, class N, class U, class... Us> template <class Op, class T, class Read, class N, class U, class... Us>
......
...@@ -83,8 +83,7 @@ struct miopen_apply ...@@ -83,8 +83,7 @@ struct miopen_apply
auto& ctx = get_context(); auto& ctx = get_context();
int8_x4_format = get_int8_x4_format(ctx); int8_x4_format = get_int8_x4_format(ctx);
compute_fp32 = get_compute_fp32_flag(); compute_fp32 = get_compute_fp32_flag();
offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
...@@ -112,6 +111,7 @@ struct miopen_apply ...@@ -112,6 +111,7 @@ struct miopen_apply
add_loop_op(); add_loop_op();
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_select_module_op();
} }
void copy_params() const void copy_params() const
...@@ -359,6 +359,20 @@ struct miopen_apply ...@@ -359,6 +359,20 @@ struct miopen_apply
return mod->replace_instruction(ins, gpu_out); return mod->replace_instruction(ins, gpu_out);
}); });
} }
/**
* Adds dynamic allocation for submodule output parameter.
*/
void add_select_module_op()
{
apply_map.emplace("select_module", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto output = insert_allocation(ins, s);
std::vector<instruction_ref> inputs = ins->inputs();
inputs.push_back(output);
return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
});
}
}; };
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); } void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
......
...@@ -76,7 +76,6 @@ namespace gpu { ...@@ -76,7 +76,6 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass struct id_pass
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
...@@ -93,6 +92,7 @@ pass enable_pass(bool enabled, pass p) ...@@ -93,6 +92,7 @@ pass enable_pass(bool enabled, pass p)
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{ {
auto& ctx = any_cast<context>(gctx); auto& ctx = any_cast<context>(gctx);
ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end()); std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type); unsupported_types.erase(shape::type_t::float_type);
unsupported_types.erase(shape::type_t::half_type); unsupported_types.erase(shape::type_t::half_type);
......
...@@ -46,8 +46,6 @@ struct target ...@@ -46,8 +46,6 @@ struct target
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
}; };
MIGRAPHX_REGISTER_TARGET(target);
} // namespace ref } // namespace ref
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -24,4 +24,6 @@ ...@@ -24,4 +24,6 @@
// clang-format off // clang-format off
#define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@ #define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@
#define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@ #define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@
#define MIGRAPHX_VERSION_PATCH @PROJECT_VERSION_PATCH@
#define MIGRAPHX_VERSION_TWEAK @PROJECT_VERSION_TWEAK@
// clang-format on // clang-format on
...@@ -106,19 +106,11 @@ function(add_test_executable TEST_NAME) ...@@ -106,19 +106,11 @@ function(add_test_executable TEST_NAME)
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU") if(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread) set_target_properties(${TEST_NAME} PROPERTIES COMPILE_FLAGS -pthread LINK_FLAGS -pthread)
endif() endif()
set(TEST_COMMAND ${TEST_NAME})
separate_arguments(MIOPEN_TEST_FLAGS_ARGS UNIX_COMMAND ${MIOPEN_TEST_FLAGS})
if(MIOPEN_TEST_ALL)
set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} --all ${MIOPEN_TEST_FLAGS_ARGS})
else()
set(TEST_COMMAND ${TEST_NAME} ${MIOPEN_TEST_FLOAT_ARG} ${MIOPEN_TEST_FLAGS_ARGS})
endif()
add_test_command(${TEST_NAME} ${TEST_COMMAND}) add_test_command(${TEST_NAME} ${TEST_COMMAND})
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME}) add_dependencies(check ${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx migraphx_ref migraphx_onnx) target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
...@@ -171,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS}) ...@@ -171,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
set(TEST_NAME test_${BASE_NAME}) set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${ONNX_TEST}) add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME}) rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) target_link_libraries(${TEST_NAME} migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
...@@ -182,7 +174,7 @@ endforeach() ...@@ -182,7 +174,7 @@ endforeach()
set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf) set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf)
add_executable(test_tf tf/tf_test.cpp) add_executable(test_tf tf/tf_test.cpp)
rocm_clang_tidy_check(test_tf) rocm_clang_tidy_check(test_tf)
target_link_libraries(test_tf migraphx_tf migraphx_ref) target_link_libraries(test_tf migraphx_tf)
target_include_directories(test_tf PUBLIC include) target_include_directories(test_tf PUBLIC include)
add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${TEST_TF_DIR}) add_test(NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY ${TEST_TF_DIR})
add_dependencies(tests test_tf) add_dependencies(tests test_tf)
...@@ -216,10 +208,7 @@ function(test_headers PREFIX) ...@@ -216,10 +208,7 @@ function(test_headers PREFIX)
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
get_filename_component(BASE_NAME ${HEADER} NAME_WE) get_filename_component(BASE_NAME ${HEADER} NAME_WE)
test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp) test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp)
target_link_libraries(header_${TEST_NAME} migraphx_all_targets)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(header_${TEST_NAME} migraphx_gpu)
endif()
endforeach() endforeach()
endfunction() endfunction()
......
...@@ -56,8 +56,10 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR}) ...@@ -56,8 +56,10 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR}) add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
# GPU-based tests # GPU-based tests
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_gpu migraphx_gpu) target_link_libraries(test_api_gpu hip::host)
add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_custom_op_gpu migraphx_gpu) target_link_libraries(test_api_custom_op_gpu hip::host)
endif() endif()
...@@ -35,6 +35,7 @@ TEST_CASE(load_and_run) ...@@ -35,6 +35,7 @@ TEST_CASE(load_and_run)
auto shapes_before = p.get_output_shapes(); auto shapes_before = p.get_output_shapes();
migraphx::compile_options options; migraphx::compile_options options;
options.set_offload_copy(); options.set_offload_copy();
options.set_exhaustive_tune_flag();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes(); auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == 1);
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/fpga/target.hpp>
#include <migraphx/target_assignments.hpp> #include <migraphx/target_assignments.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
......
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/gpu/kernel.hpp> #include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
...@@ -235,7 +235,7 @@ TEST_CASE(code_object_hip) ...@@ -235,7 +235,7 @@ TEST_CASE(code_object_hip)
auto y = mm->add_parameter("output", input); auto y = mm->add_parameter("output", input);
mm->add_instruction(co, x, y); mm->add_instruction(co, x, y);
migraphx::compile_options options; migraphx::compile_options options;
p.compile(migraphx::gpu::target{}, options); p.compile(migraphx::make_target("gpu"), options);
auto result = auto result =
migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
...@@ -261,7 +261,7 @@ TEST_CASE(compile_code_object_hip) ...@@ -261,7 +261,7 @@ TEST_CASE(compile_code_object_hip)
auto x = mm->add_literal(input_literal); auto x = mm->add_literal(input_literal);
auto y = mm->add_parameter("output", input); auto y = mm->add_parameter("output", input);
mm->add_instruction(co, x, y); mm->add_instruction(co, x, y);
p.compile(migraphx::gpu::target{}, migraphx::compile_options{}); p.compile(migraphx::make_target("gpu"), migraphx::compile_options{});
auto result = auto result =
migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
...@@ -284,7 +284,7 @@ TEST_CASE(compile_pointwise) ...@@ -284,7 +284,7 @@ TEST_CASE(compile_pointwise)
auto x = mm->add_literal(input_literal); auto x = mm->add_literal(input_literal);
auto y = mm->add_parameter("output", input); auto y = mm->add_parameter("output", input);
mm->add_instruction(co, x, y); mm->add_instruction(co, x, y);
p.compile(migraphx::gpu::target{}, migraphx::compile_options{}); p.compile(migraphx::make_target("gpu"), migraphx::compile_options{});
auto result = auto result =
migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
...@@ -35,7 +36,7 @@ void gpu_literal_test() ...@@ -35,7 +36,7 @@ void gpu_literal_test()
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_literal(lit); mm->add_literal(lit);
p.compile(migraphx::gpu::target{}); p.compile(migraphx::make_target("gpu"));
auto scratch = p.get_parameter("scratch"); auto scratch = p.get_parameter("scratch");
if(scratch == mm->end()) if(scratch == mm->end())
{ {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <migraphx/gpu/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -57,7 +57,7 @@ TEST_CASE(host_same_buffer_copy) ...@@ -57,7 +57,7 @@ TEST_CASE(host_same_buffer_copy)
pp["a"] = migraphx::argument(ss, a_vec.data()); pp["a"] = migraphx::argument(ss, a_vec.data());
pp["b"] = migraphx::argument(ss, b_vec.data()); pp["b"] = migraphx::argument(ss, b_vec.data());
std::vector<float> gpu_result; std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::gpu::target{}; migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::compile_options options; migraphx::compile_options options;
options.offload_copy = true; options.offload_copy = true;
p.compile(gpu_t, options); p.compile(gpu_t, options);
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/write_literals.hpp> #include <migraphx/gpu/write_literals.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -121,7 +121,7 @@ migraphx::argument run_gpu(migraphx::program p, const migraphx::parameter_map& i ...@@ -121,7 +121,7 @@ migraphx::argument run_gpu(migraphx::program p, const migraphx::parameter_map& i
migraphx::argument run_ref(migraphx::program p, const migraphx::parameter_map& inputs) migraphx::argument run_ref(migraphx::program p, const migraphx::parameter_map& inputs)
{ {
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
return p.eval(inputs).front(); return p.eval(inputs).front();
} }
......
...@@ -27,8 +27,7 @@ ...@@ -27,8 +27,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
...@@ -39,8 +38,8 @@ ...@@ -39,8 +38,8 @@
TEST_CASE(gpu_target_copy) TEST_CASE(gpu_target_copy)
{ {
migraphx::target gpu_t = migraphx::gpu::target{}; migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::target ref_t = migraphx::ref::target{}; migraphx::target ref_t = migraphx::make_target("ref");
migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L); auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L);
...@@ -104,11 +103,11 @@ TEST_CASE(int8_quantization) ...@@ -104,11 +103,11 @@ TEST_CASE(int8_quantization)
m["a"] = migraphx::generate_argument(sa); m["a"] = migraphx::generate_argument(sa);
m["b"] = migraphx::generate_argument(sb); m["b"] = migraphx::generate_argument(sb);
std::vector<float> ref_result; std::vector<float> ref_result;
migraphx::target ref_t = migraphx::ref::target{}; migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, m, ref_result); run_prog(p, ref_t, m, ref_result);
std::vector<float> gpu_result; std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::gpu::target{}; migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result); run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify_range(ref_result, gpu_result));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment