"profiler/vscode:/vscode.git/clone" did not exist on "9697ad4e0cee715eee8b558c218adc5062ed25b7"
Unverified Commit 7a7040aa authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Dynamically plug-in backend target libs (#1608)

Fixes #1595
parent 9ef6801e
...@@ -56,8 +56,10 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR}) ...@@ -56,8 +56,10 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR}) add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
# GPU-based tests # GPU-based tests
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_gpu migraphx_gpu) target_link_libraries(test_api_gpu hip::host)
add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_custom_op_gpu migraphx_gpu) target_link_libraries(test_api_custom_op_gpu hip::host)
endif() endif()
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/context.hpp>
#include <migraphx/ref/context.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
TEST_CASE(context)
{
migraphx::context ctx = migraphx::ref::context{};
migraphx::value v = ctx.to_value();
EXPECT(v.empty());
migraphx::context cpu_ctx = migraphx::ref::context{};
cpu_ctx.from_value(v);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/fpga/target.hpp>
#include <migraphx/target_assignments.hpp> #include <migraphx/target_assignments.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
......
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/gpu/kernel.hpp> #include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
...@@ -235,7 +235,7 @@ TEST_CASE(code_object_hip) ...@@ -235,7 +235,7 @@ TEST_CASE(code_object_hip)
auto y = mm->add_parameter("output", input); auto y = mm->add_parameter("output", input);
mm->add_instruction(co, x, y); mm->add_instruction(co, x, y);
migraphx::compile_options options; migraphx::compile_options options;
p.compile(migraphx::gpu::target{}, options); p.compile(migraphx::make_target("gpu"), options);
auto result = auto result =
migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
...@@ -261,7 +261,7 @@ TEST_CASE(compile_code_object_hip) ...@@ -261,7 +261,7 @@ TEST_CASE(compile_code_object_hip)
auto x = mm->add_literal(input_literal); auto x = mm->add_literal(input_literal);
auto y = mm->add_parameter("output", input); auto y = mm->add_parameter("output", input);
mm->add_instruction(co, x, y); mm->add_instruction(co, x, y);
p.compile(migraphx::gpu::target{}, migraphx::compile_options{}); p.compile(migraphx::make_target("gpu"), migraphx::compile_options{});
auto result = auto result =
migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
...@@ -284,7 +284,7 @@ TEST_CASE(compile_pointwise) ...@@ -284,7 +284,7 @@ TEST_CASE(compile_pointwise)
auto x = mm->add_literal(input_literal); auto x = mm->add_literal(input_literal);
auto y = mm->add_parameter("output", input); auto y = mm->add_parameter("output", input);
mm->add_instruction(co, x, y); mm->add_instruction(co, x, y);
p.compile(migraphx::gpu::target{}, migraphx::compile_options{}); p.compile(migraphx::make_target("gpu"), migraphx::compile_options{});
auto result = auto result =
migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
...@@ -35,7 +36,7 @@ void gpu_literal_test() ...@@ -35,7 +36,7 @@ void gpu_literal_test()
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_literal(lit); mm->add_literal(lit);
p.compile(migraphx::gpu::target{}); p.compile(migraphx::make_target("gpu"));
auto scratch = p.get_parameter("scratch"); auto scratch = p.get_parameter("scratch");
if(scratch == mm->end()) if(scratch == mm->end())
{ {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <migraphx/gpu/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -57,7 +57,7 @@ TEST_CASE(host_same_buffer_copy) ...@@ -57,7 +57,7 @@ TEST_CASE(host_same_buffer_copy)
pp["a"] = migraphx::argument(ss, a_vec.data()); pp["a"] = migraphx::argument(ss, a_vec.data());
pp["b"] = migraphx::argument(ss, b_vec.data()); pp["b"] = migraphx::argument(ss, b_vec.data());
std::vector<float> gpu_result; std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::gpu::target{}; migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::compile_options options; migraphx::compile_options options;
options.offload_copy = true; options.offload_copy = true;
p.compile(gpu_t, options); p.compile(gpu_t, options);
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/write_literals.hpp> #include <migraphx/gpu/write_literals.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -121,7 +121,7 @@ migraphx::argument run_gpu(migraphx::program p, const migraphx::parameter_map& i ...@@ -121,7 +121,7 @@ migraphx::argument run_gpu(migraphx::program p, const migraphx::parameter_map& i
migraphx::argument run_ref(migraphx::program p, const migraphx::parameter_map& inputs) migraphx::argument run_ref(migraphx::program p, const migraphx::parameter_map& inputs)
{ {
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
return p.eval(inputs).front(); return p.eval(inputs).front();
} }
......
...@@ -27,8 +27,7 @@ ...@@ -27,8 +27,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
...@@ -39,8 +38,8 @@ ...@@ -39,8 +38,8 @@
TEST_CASE(gpu_target_copy) TEST_CASE(gpu_target_copy)
{ {
migraphx::target gpu_t = migraphx::gpu::target{}; migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::target ref_t = migraphx::ref::target{}; migraphx::target ref_t = migraphx::make_target("ref");
migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L); auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L);
...@@ -104,11 +103,11 @@ TEST_CASE(int8_quantization) ...@@ -104,11 +103,11 @@ TEST_CASE(int8_quantization)
m["a"] = migraphx::generate_argument(sa); m["a"] = migraphx::generate_argument(sa);
m["b"] = migraphx::generate_argument(sb); m["b"] = migraphx::generate_argument(sb);
std::vector<float> ref_result; std::vector<float> ref_result;
migraphx::target ref_t = migraphx::ref::target{}; migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, m, ref_result); run_prog(p, ref_t, m, ref_result);
std::vector<float> gpu_result; std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::gpu::target{}; migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result); run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify_range(ref_result, gpu_result));
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/register_target.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
...@@ -133,7 +134,7 @@ TEST_CASE(test_stream_sync) ...@@ -133,7 +134,7 @@ TEST_CASE(test_stream_sync)
auto mult_out = mm->add_instruction(migraphx::make_op("dot"), x, y); auto mult_out = mm->add_instruction(migraphx::make_op("dot"), x, y);
mm->add_instruction(migraphx::make_op("add"), mult_out, test_val); mm->add_instruction(migraphx::make_op("add"), mult_out, test_val);
p.compile(migraphx::gpu::target{}); p.compile(migraphx::make_target("gpu"));
// Run network and then verify with kernel // Run network and then verify with kernel
auto args = p.eval({{"x", ginput}, {"output", goutput}}, {pstream.get(), true}); auto args = p.eval({{"x", ginput}, {"output", goutput}}, {pstream.get(), true});
......
...@@ -22,12 +22,11 @@ ...@@ -22,12 +22,11 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp> #include <migraphx/marker.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_target.hpp>
#include "test.hpp" #include "test.hpp"
struct mock_marker struct mock_marker
...@@ -64,7 +63,7 @@ TEST_CASE(marker) ...@@ -64,7 +63,7 @@ TEST_CASE(marker)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("add"), one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
mock_marker temp_marker; mock_marker temp_marker;
p.mark({}, temp_marker); p.mark({}, temp_marker);
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
......
This diff is collapsed.
...@@ -22,10 +22,9 @@ ...@@ -22,10 +22,9 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(perf_report) TEST_CASE(perf_report)
...@@ -37,7 +36,7 @@ TEST_CASE(perf_report) ...@@ -37,7 +36,7 @@ TEST_CASE(perf_report)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("add"), one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
p.perf_report(ss, 2, {}); p.perf_report(ss, 2, {});
std::string output = ss.str(); std::string output = ss.str();
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <sstream> #include <sstream>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -139,10 +139,10 @@ TEST_CASE(program_copy) ...@@ -139,10 +139,10 @@ TEST_CASE(program_copy)
migraphx::program p2{}; migraphx::program p2{};
p2 = p1; p2 = p1;
p2.compile(migraphx::ref::target{}); p2.compile(migraphx::make_target("ref"));
EXPECT(p1 != p2); EXPECT(p1 != p2);
p1.compile(migraphx::ref::target{}); p1.compile(migraphx::make_target("ref"));
EXPECT(p1 == p2); EXPECT(p1 == p2);
EXPECT(p1.get_parameter_names() == p2.get_parameter_names()); EXPECT(p1.get_parameter_names() == p2.get_parameter_names());
...@@ -153,7 +153,7 @@ TEST_CASE(program_copy) ...@@ -153,7 +153,7 @@ TEST_CASE(program_copy)
auto p2(p1); auto p2(p1);
EXPECT(p1 == p2); EXPECT(p1 == p2);
p1.compile(migraphx::ref::target{}); p1.compile(migraphx::make_target("ref"));
EXPECT(p1 != p2); EXPECT(p1 != p2);
p2 = p1; p2 = p1;
...@@ -168,8 +168,8 @@ TEST_CASE(program_copy) ...@@ -168,8 +168,8 @@ TEST_CASE(program_copy)
p2 = p1; p2 = p1;
EXPECT(p1 == p2); EXPECT(p1 == p2);
p1.compile(migraphx::ref::target{}); p1.compile(migraphx::make_target("ref"));
p2.compile(migraphx::ref::target{}); p2.compile(migraphx::make_target("ref"));
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -190,8 +190,8 @@ TEST_CASE(program_copy) ...@@ -190,8 +190,8 @@ TEST_CASE(program_copy)
p2 = p1; p2 = p1;
EXPECT(p2 == p1); EXPECT(p2 == p1);
p1.compile(migraphx::ref::target{}); p1.compile(migraphx::make_target("ref"));
p2.compile(migraphx::ref::target{}); p2.compile(migraphx::make_target("ref"));
EXPECT(p2 == p1); EXPECT(p2 == p1);
} }
} }
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
...@@ -487,7 +487,7 @@ TEST_CASE(op_capture) ...@@ -487,7 +487,7 @@ TEST_CASE(op_capture)
{ {
auto p = create_program_float(); auto p = create_program_float();
auto op_capture_p = create_program_op(); auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{}; migraphx::target t = migraphx::make_target("ref");
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes( migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}}); p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
...@@ -562,7 +562,7 @@ TEST_CASE(op_capture_subgraph) ...@@ -562,7 +562,7 @@ TEST_CASE(op_capture_subgraph)
{ {
auto p = create_program(); auto p = create_program();
auto op_capture_p = create_program_op(); auto op_capture_p = create_program_op();
migraphx::target t = migraphx::ref::target{}; migraphx::target t = migraphx::make_target("ref");
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes( migraphx::run_passes(
p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}}); p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, &param_index}});
...@@ -1010,7 +1010,7 @@ TEST_CASE(target_copy) ...@@ -1010,7 +1010,7 @@ TEST_CASE(target_copy)
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
m["x"] = migraphx::generate_argument(s); m["x"] = migraphx::generate_argument(s);
std::vector<float> ref_result; std::vector<float> ref_result;
migraphx::target ref_t = migraphx::ref::target{}; migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, m, ref_result); run_prog(p, ref_t, m, ref_result);
std::vector<float> orig_result; std::vector<float> orig_result;
...@@ -1074,7 +1074,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1074,7 +1074,7 @@ TEST_CASE(int8_quantization_dot)
m["a"] = migraphx::generate_argument(sa, get_hash(std::string("a"))); m["a"] = migraphx::generate_argument(sa, get_hash(std::string("a")));
m["b"] = migraphx::generate_argument(sb, get_hash(std::string("b"))); m["b"] = migraphx::generate_argument(sb, get_hash(std::string("b")));
std::vector<float> quant_result; std::vector<float> quant_result;
migraphx::target ref_t = migraphx::ref::target{}; migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, m, quant_result, true); run_prog(p, ref_t, m, quant_result, true);
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
...@@ -1119,7 +1119,7 @@ TEST_CASE(int8_quantization_conv) ...@@ -1119,7 +1119,7 @@ TEST_CASE(int8_quantization_conv)
{ {
auto p = create_program(); auto p = create_program();
std::vector<float> quant_result; std::vector<float> quant_result;
migraphx::target ref_t = migraphx::ref::target{}; migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, quant_result, true); run_prog(p, ref_t, quant_result, true);
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
...@@ -1261,13 +1261,13 @@ TEST_CASE(test_op_capture) ...@@ -1261,13 +1261,13 @@ TEST_CASE(test_op_capture)
auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {}; auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {};
migraphx::program capture_p = p; migraphx::program capture_p = p;
migraphx::target t = migraphx::ref::target{}; migraphx::target t = migraphx::make_target("ref");
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes(capture_p, migraphx::run_passes(capture_p,
{migraphx::capture_arguments_pass{{"dot"}, calc, &param_index}}); {migraphx::capture_arguments_pass{{"dot"}, calc, &param_index}});
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
capture_p.compile(migraphx::ref::target{}); capture_p.compile(migraphx::make_target("ref"));
auto cap_res = capture_p.eval({}).back(); auto cap_res = capture_p.eval({}).back();
auto res = p.eval({}).back(); auto res = p.eval({}).back();
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -54,7 +54,7 @@ TEST_CASE(add_two_literals) ...@@ -54,7 +54,7 @@ TEST_CASE(add_two_literals)
mm->add_instruction(migraphx::make_op("add"), one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
// compile the program on the reference device // compile the program on the reference device
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
// evaulate the program and retreive the result // evaulate the program and retreive the result
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -78,7 +78,7 @@ TEST_CASE(add_parameters) ...@@ -78,7 +78,7 @@ TEST_CASE(add_parameters)
// add the "add" instruction between the "x" parameter and "two" to the module // add the "add" instruction between the "x" parameter and "two" to the module
mm->add_instruction(migraphx::make_op("add"), x, two); mm->add_instruction(migraphx::make_op("add"), x, two);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
// create a parameter_map object for passing a value to the "x" parameter // create a parameter_map object for passing a value to the "x" parameter
std::vector<int> data = {4}; std::vector<int> data = {4};
...@@ -111,7 +111,7 @@ TEST_CASE(handling_tensors) ...@@ -111,7 +111,7 @@ TEST_CASE(handling_tensors)
input, input,
weights); weights);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
// Allocated buffers by the user // Allocated buffers by the user
std::vector<float> a = { std::vector<float> a = {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -76,7 +76,7 @@ void dot_2d_test() ...@@ -76,7 +76,7 @@ void dot_2d_test()
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}}; migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), al, bl); mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
...@@ -127,7 +127,7 @@ void dot_4d_test() ...@@ -127,7 +127,7 @@ void dot_4d_test()
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), al, bl); mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
...@@ -164,7 +164,7 @@ TEST_CASE(dot_3D_test) ...@@ -164,7 +164,7 @@ TEST_CASE(dot_3D_test)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
mm->add_instruction(migraphx::make_op("dot"), l1, l2); mm->add_instruction(migraphx::make_op("dot"), l1, l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -240,7 +240,7 @@ TEST_CASE(dot_3D_C_test0) ...@@ -240,7 +240,7 @@ TEST_CASE(dot_3D_C_test0)
migraphx::make_op("dot"), migraphx::make_op("dot"),
alpha, alpha,
beta); beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -307,7 +307,7 @@ TEST_CASE(dot_3D_C_test1) ...@@ -307,7 +307,7 @@ TEST_CASE(dot_3D_C_test1)
migraphx::make_op("dot"), migraphx::make_op("dot"),
alpha, alpha,
beta); beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -349,7 +349,7 @@ TEST_CASE(dot_4D_test1) ...@@ -349,7 +349,7 @@ TEST_CASE(dot_4D_test1)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
mm->add_instruction(migraphx::make_op("dot"), l1, l2); mm->add_instruction(migraphx::make_op("dot"), l1, l2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -403,7 +403,7 @@ TEST_CASE(dot_4D_alpha_beta_test) ...@@ -403,7 +403,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta); migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta);
auto m3_beta = mm->add_instruction(migraphx::make_op("mul"), b_beta, l3); auto m3_beta = mm->add_instruction(migraphx::make_op("mul"), b_beta, l3);
mm->add_instruction(migraphx::make_op("add"), m3_beta, m12_alpha); mm->add_instruction(migraphx::make_op("add"), m3_beta, m12_alpha);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -455,7 +455,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test) ...@@ -455,7 +455,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
migraphx::make_op("dot"), migraphx::make_op("dot"),
alpha, alpha,
beta); beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -525,7 +525,7 @@ TEST_CASE(dot_2D_C_test0) ...@@ -525,7 +525,7 @@ TEST_CASE(dot_2D_C_test0)
-0.835966, -0.835966,
5.74736, 5.74736,
4.22063}; 4.22063};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -563,7 +563,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -563,7 +563,7 @@ TEST_CASE(dot_vv_inner_product)
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), ual, ubl); mm->add_instruction(migraphx::make_op("dot"), ual, ubl);
std::vector<float> gold = {-1.43461}; std::vector<float> gold = {-1.43461};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -600,7 +600,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -600,7 +600,7 @@ TEST_CASE(dot_vv_inner_product)
migraphx::add_apply_alpha_beta( migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, ubl}, migraphx::make_op("dot"), alpha); *mm, std::vector<migraphx::instruction_ref>{ual, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-0.4590752}; std::vector<float> gold = {-0.4590752};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -638,7 +638,7 @@ TEST_CASE(dot_vm) ...@@ -638,7 +638,7 @@ TEST_CASE(dot_vm)
mm->add_instruction(migraphx::make_op("dot"), ual, bl); mm->add_instruction(migraphx::make_op("dot"), ual, bl);
std::vector<float> gold = {-3.78111, -3.40007, -2.1972, -3.31448, -3.80326}; std::vector<float> gold = {-3.78111, -3.40007, -2.1972, -3.31448, -3.80326};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -675,7 +675,7 @@ TEST_CASE(dot_vm) ...@@ -675,7 +675,7 @@ TEST_CASE(dot_vm)
*mm, std::vector<migraphx::instruction_ref>{ual, bl}, migraphx::make_op("dot"), alpha); *mm, std::vector<migraphx::instruction_ref>{ual, bl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163}; std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -722,7 +722,7 @@ TEST_CASE(dot_vm) ...@@ -722,7 +722,7 @@ TEST_CASE(dot_vm)
1.38484, 1.38484,
-2.45019, -2.45019,
-1.35064}; -1.35064};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -770,7 +770,7 @@ TEST_CASE(dot_vm) ...@@ -770,7 +770,7 @@ TEST_CASE(dot_vm)
0.290817, 0.290817,
-0.514539, -0.514539,
-0.283635}; -0.283635};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -809,7 +809,7 @@ TEST_CASE(dot_mv) ...@@ -809,7 +809,7 @@ TEST_CASE(dot_mv)
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, ubl); mm->add_instruction(migraphx::make_op("dot"), al, ubl);
std::vector<float> gold = {1.31982, 1.19022, -1.96062}; std::vector<float> gold = {1.31982, 1.19022, -1.96062};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -847,7 +847,7 @@ TEST_CASE(dot_mv) ...@@ -847,7 +847,7 @@ TEST_CASE(dot_mv)
migraphx::add_apply_alpha_beta( migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{al, ubl}, migraphx::make_op("dot"), alpha); *mm, std::vector<migraphx::instruction_ref>{al, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {0.395946, 0.357067, -0.588187}; std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -891,7 +891,7 @@ TEST_CASE(dot_mv) ...@@ -891,7 +891,7 @@ TEST_CASE(dot_mv)
2.87146, 2.87146,
3.29447, 3.29447,
0.765651}; 0.765651};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -945,7 +945,7 @@ TEST_CASE(dot_mm1) ...@@ -945,7 +945,7 @@ TEST_CASE(dot_mm1)
-0.557691, 6.13527, -2.91743, 2.37836, -6.42584, 1.14979, -0.557691, 6.13527, -2.91743, 2.37836, -6.42584, 1.14979,
0.77227, 0.349659, 2.92759, 2.32384, -2.90664, 0.0527679, 0.77227, 0.349659, 2.92759, 2.32384, -2.90664, 0.0527679,
-0.547761, -0.155467, 0.964619, 2.09133, -4.44281, -1.3864}; -0.547761, -0.155467, 0.964619, 2.09133, -4.44281, -1.3864};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -998,7 +998,7 @@ TEST_CASE(dot_mm1) ...@@ -998,7 +998,7 @@ TEST_CASE(dot_mm1)
-0.710558, 0.259424, -0.342345, -1.80522, -0.580476, 0.277368, -3.95582, 0.614823, -0.710558, 0.259424, -0.342345, -1.80522, -0.580476, 0.277368, -3.95582, 0.614823,
-0.415107, 0.305138, 0.435993, -0.107089, -0.767885, -4.00837, 1.09921, -2.02129, -0.415107, 0.305138, 0.435993, -0.107089, -0.767885, -4.00837, 1.09921, -2.02129,
0.109717, 0.618422, 0.438342, 0.29602, 2.00928, 0.420871}; 0.109717, 0.618422, 0.438342, 0.29602, 2.00928, 0.420871};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1043,7 +1043,7 @@ TEST_CASE(dot_mm2) ...@@ -1043,7 +1043,7 @@ TEST_CASE(dot_mm2)
1.83096922, 2.5512663, -1.07926588, -2.12749134, 0.33014536, -0.80393025, 1.83096922, 2.5512663, -1.07926588, -2.12749134, 0.33014536, -0.80393025,
0.60740202, 0.95217761, -1.06087445, -4.75868152, -3.6687713, -1.26539821}; 0.60740202, 0.95217761, -1.06087445, -4.75868152, -3.6687713, -1.26539821};
mm->add_instruction(migraphx::make_op("dot"), al, bbl); mm->add_instruction(migraphx::make_op("dot"), al, bbl);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1085,7 +1085,7 @@ TEST_CASE(dot_mm2) ...@@ -1085,7 +1085,7 @@ TEST_CASE(dot_mm2)
1.02442564e-01, -1.87659303e+00, -4.67302454e-01, 9.16189968e-01, -1.33537175e-01, 1.02442564e-01, -1.87659303e+00, -4.67302454e-01, 9.16189968e-01, -1.33537175e-01,
8.27398578e-01, 1.94406914e+00, -2.39250915e-01, -1.77062701e+00, -6.46239534e-01, 8.27398578e-01, 1.94406914e+00, -2.39250915e-01, -1.77062701e+00, -6.46239534e-01,
-7.95202750e-01}; -7.95202750e-01};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1137,7 +1137,7 @@ TEST_CASE(dot_mm2) ...@@ -1137,7 +1137,7 @@ TEST_CASE(dot_mm2)
-0.61459168, -0.52561056, 0.3309648, -0.46185697, -1.60586695, -0.98590829, -0.61459168, -0.52561056, 0.3309648, -0.46185697, -1.60586695, -0.98590829,
0.63012062, -0.25606052, -0.69419352, -1.78299913, -0.38572706, 1.92249442, 0.63012062, -0.25606052, -0.69419352, -1.78299913, -0.38572706, 1.92249442,
0.3884186, -0.48153048, 0.84932351, 0.67234919, -1.07821322, -0.01208216}; 0.3884186, -0.48153048, 0.84932351, 0.67234919, -1.07821322, -0.01208216};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1185,7 +1185,7 @@ TEST_CASE(dot_mm2) ...@@ -1185,7 +1185,7 @@ TEST_CASE(dot_mm2)
1.38307367, 0.42677257, 0.83759966, -0.34827442, -1.45067092, 2.09599671, 1.38307367, 0.42677257, 0.83759966, -0.34827442, -1.45067092, 2.09599671,
1.92882983, -0.30996324, 2.19736278, 2.32389426, 2.36741832, 1.62253915, 1.92882983, -0.30996324, 2.19736278, 2.32389426, 2.36741832, 1.62253915,
0.26698225, -0.00741609, -2.53680983, -0.0679954, 0.04499683, 0.85354276}; 0.26698225, -0.00741609, -2.53680983, -0.0679954, 0.04499683, 0.85354276};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1202,7 +1202,7 @@ TEST_CASE(dot_dyn_2D_test) ...@@ -1202,7 +1202,7 @@ TEST_CASE(dot_dyn_2D_test)
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape); auto bp = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), ap, bp); mm->add_instruction(migraphx::make_op("dot"), ap, bp);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
...@@ -1256,7 +1256,7 @@ TEST_CASE(dot_dyn_4D_test) ...@@ -1256,7 +1256,7 @@ TEST_CASE(dot_dyn_4D_test)
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape); auto bl = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), al, bl); mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
...@@ -1321,7 +1321,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1321,7 +1321,7 @@ TEST_CASE(quant_dot_2args_multi4)
370, 392, 414, 436, 458, 496, 534, 572, 610, 648, 686, 370, 392, 414, 436, 458, 496, 534, 572, 610, 648, 686,
724, 762, 688, 742, 796, 850, 904, 958, 1012, 1066}; 724, 762, 688, 742, 796, 850, 904, 958, 1012, 1066};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1349,7 +1349,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1349,7 +1349,7 @@ TEST_CASE(quant_dot_2args_multi4)
580, 608, 636, 664, 692, 544, 576, 608, 640, 672, 704, 580, 608, 636, 664, 692, 544, 576, 608, 640, 672, 704,
736, 768, 592, 628, 664, 700, 736, 772, 808, 844}; 736, 768, 592, 628, 664, 700, 736, 772, 808, 844};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1377,7 +1377,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1377,7 +1377,7 @@ TEST_CASE(quant_dot_2args_multi4)
302, 390, 478, 566, 654, 62, 214, 366, 518, 670, 822, 302, 390, 478, 566, 654, 62, 214, 366, 518, 670, 822,
974, 1126, 86, 302, 518, 734, 950, 1166, 1382, 1598}; 974, 1126, 86, 302, 518, 734, 950, 1166, 1382, 1598};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1407,7 +1407,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1407,7 +1407,7 @@ TEST_CASE(quant_dot_2args_multi4)
398, 510, 622, 734, 846, 68, 196, 324, 452, 580, 708, 398, 510, 622, 734, 846, 68, 196, 324, 452, 580, 708,
836, 964, 74, 218, 362, 506, 650, 794, 938, 1082}; 836, 964, 74, 218, 362, 506, 650, 794, 938, 1082};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1435,7 +1435,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1435,7 +1435,7 @@ TEST_CASE(quant_dot_2args_general)
std::vector<int> gold = { std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462}; 70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1462,7 +1462,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1462,7 +1462,7 @@ TEST_CASE(quant_dot_2args_general)
std::vector<int> gold = { std::vector<int> gold = {
210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374}; 210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1490,7 +1490,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1490,7 +1490,7 @@ TEST_CASE(quant_dot_2args_general)
std::vector<int> gold = { std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340}; 28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1519,7 +1519,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1519,7 +1519,7 @@ TEST_CASE(quant_dot_2args_general)
std::vector<int> gold = { std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410}; 126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1551,7 +1551,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1551,7 +1551,7 @@ TEST_CASE(quant_dot_3args_general)
std::vector<int> gold = { std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115}; 982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1579,7 +1579,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1579,7 +1579,7 @@ TEST_CASE(quant_dot_3args_general)
std::vector<int> gold = { std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462}; 70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1610,7 +1610,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1610,7 +1610,7 @@ TEST_CASE(quant_dot_3args_general)
std::vector<int> gold = { std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585}; 1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1641,7 +1641,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1641,7 +1641,7 @@ TEST_CASE(quant_dot_3args_general)
std::vector<int> gold = { std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605}; 286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1674,7 +1674,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1674,7 +1674,7 @@ TEST_CASE(quant_dot_3args_general)
std::vector<int> gold = { std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170}; 844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1710,7 +1710,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1710,7 +1710,7 @@ TEST_CASE(quant_dot_3args_batch)
5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282, 5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282,
10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008}; 10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
...@@ -1748,7 +1748,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1748,7 +1748,7 @@ TEST_CASE(quant_dot_3args_batch)
12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507, 12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507,
24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039}; 24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -72,7 +72,7 @@ static auto run_prog(int64_t iter_num, bool cond, int64_t ini_val) ...@@ -72,7 +72,7 @@ static auto run_prog(int64_t iter_num, bool cond, int64_t ini_val)
}; };
auto p = create_program(); auto p = create_program();
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["iter_num"] = migraphx::argument(si, &iter_num); pp["iter_num"] = migraphx::argument(si, &iter_num);
pp["ccond"] = migraphx::argument(sc, &cond); pp["ccond"] = migraphx::argument(sc, &cond);
......
...@@ -25,7 +25,8 @@ ...@@ -25,7 +25,8 @@
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -41,7 +42,7 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -41,7 +42,7 @@ TEST_CASE(argmax_test_nonstd_shape)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans); mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans);
auto p_uncompiled = p; auto p_uncompiled = p;
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back(); auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
...@@ -60,7 +61,7 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -60,7 +61,7 @@ TEST_CASE(argmin_test_nonstd_shape)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl);
mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans); mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans);
auto p_uncompiled = p; auto p_uncompiled = p;
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto res_gold = p_uncompiled.eval({}).back(); auto res_gold = p_uncompiled.eval({}).back();
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
...@@ -82,7 +83,7 @@ TEST_CASE(isnan_broadcast_test) ...@@ -82,7 +83,7 @@ TEST_CASE(isnan_broadcast_test)
auto l1 = mm->add_instruction( auto l1 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", s1.lens()}}), l0); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", s1.lens()}}), l0);
mm->add_instruction(migraphx::make_op("isnan"), l1); mm->add_instruction(migraphx::make_op("isnan"), l1);
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
...@@ -104,7 +105,7 @@ TEST_CASE(squeeze_transpose_test) ...@@ -104,7 +105,7 @@ TEST_CASE(squeeze_transpose_test)
auto* mm_uncompiled = p_uncompiled.get_main_module(); auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end())); std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}});
...@@ -124,7 +125,7 @@ TEST_CASE(squeeze_multibroadcast_test) ...@@ -124,7 +125,7 @@ TEST_CASE(squeeze_multibroadcast_test)
auto* mm_uncompiled = p_uncompiled.get_main_module(); auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end())); std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}});
...@@ -144,7 +145,7 @@ TEST_CASE(squeeze_slice_test) ...@@ -144,7 +145,7 @@ TEST_CASE(squeeze_slice_test)
auto* mm_uncompiled = p_uncompiled.get_main_module(); auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end())); std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}});
...@@ -164,7 +165,7 @@ TEST_CASE(unsqueeze_transpose_test) ...@@ -164,7 +165,7 @@ TEST_CASE(unsqueeze_transpose_test)
auto* mm_uncompiled = p_uncompiled.get_main_module(); auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end())); std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}});
...@@ -184,7 +185,7 @@ TEST_CASE(unsqueeze_multibroadcast_test) ...@@ -184,7 +185,7 @@ TEST_CASE(unsqueeze_multibroadcast_test)
auto* mm_uncompiled = p_uncompiled.get_main_module(); auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end())); std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}});
...@@ -204,7 +205,7 @@ TEST_CASE(unsqueeze_slice_test) ...@@ -204,7 +205,7 @@ TEST_CASE(unsqueeze_slice_test)
auto* mm_uncompiled = p_uncompiled.get_main_module(); auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end())); std::prev(mm_uncompiled->end()));
p.compile(migraphx::ref::target{}); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back(); auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}}); EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}});
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment