Unverified Commit 40fbef9b authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/tf/tf_parser.hpp> #include <migraphx/tf/tf_parser.hpp>
#include <migraphx/tf/op_parser.hpp>
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <unordered_map> #include <unordered_map>
...@@ -62,5 +63,7 @@ program parse_tf(const std::string& name, const tf_options& options) ...@@ -62,5 +63,7 @@ program parse_tf(const std::string& name, const tf_options& options)
return std::move(parser.prog); return std::move(parser.prog);
} }
std::vector<std::string> get_tf_operators() { return tf::get_op_parsers(); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/hash.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -519,6 +520,38 @@ std::ostream& operator<<(std::ostream& os, const value& d) ...@@ -519,6 +520,38 @@ std::ostream& operator<<(std::ostream& os, const value& d)
return os; return os;
} }
template <class T>
std::size_t value_hash(const std::string& key, const T& x)
{
std::size_t h = hash_value(key);
hash_combine(h, x);
return h;
}
std::size_t value_hash(const std::string& key, std::nullptr_t) { return hash_value(key); }
std::size_t value_hash(const std::string& key, const std::vector<value>& x)
{
std::size_t h = hash_value(key);
for(const auto& v : x)
hash_combine(h, v);
return h;
}
std::size_t value_hash(const std::string& key, const value::binary& x)
{
std::size_t h = hash_value(key);
for(const auto& v : x)
hash_combine(h, v);
return h;
}
std::size_t value::hash() const
{
std::size_t h = 0;
this->visit_value([&](const auto& a) { h = value_hash(this->get_key(), a); });
return h;
}
void value::debug_print(bool show_type) const void value::debug_print(bool show_type) const
{ {
if(show_type) if(show_type)
......
...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name, ...@@ -35,7 +35,7 @@ bool verify_args(const std::string& name,
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
double error; double error;
passed = verify_range(ref, target, tolerance, &error); passed = verify::verify_range(ref, target, tolerance, &error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
...@@ -45,27 +45,27 @@ bool verify_args(const std::string& name, ...@@ -45,27 +45,27 @@ bool verify_args(const std::string& name,
std::cout << "ref:" << ref << std::endl; std::cout << "ref:" << ref << std::endl;
if(target.size() < 32) if(target.size() < 32)
std::cout << "target:" << target << std::endl; std::cout << "target:" << target << std::endl;
if(range_zero(ref)) if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl; std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target)) if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl; std::cout << "Target data is all zeros" << std::endl;
auto mxdiff = max_diff(ref, target); auto mxdiff = verify::max_diff(ref, target);
std::cout << "Max diff: " << mxdiff << std::endl; std::cout << "Max diff: " << mxdiff << std::endl;
auto idx = mismatch_idx(ref, target, float_equal); auto idx = verify::mismatch_idx(ref, target, float_equal);
if(idx < range_distance(ref)) if(idx < verify::range_distance(ref))
{ {
std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
<< std::endl; << std::endl;
} }
auto ref_nan_idx = find_idx(ref, not_finite); auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0) if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl; << ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite); auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0) if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": " std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl; << target[target_nan_idx] << std::endl;
...@@ -73,27 +73,27 @@ bool verify_args(const std::string& name, ...@@ -73,27 +73,27 @@ bool verify_args(const std::string& name,
} }
else else
{ {
if(range_zero(ref)) if(verify::range_zero(ref))
std::cout << "Ref data is all zeros" << std::endl; std::cout << "Ref data is all zeros" << std::endl;
if(range_zero(target)) if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl; std::cout << "Target data is all zeros" << std::endl;
// auto mxdiff = max_diff(ref, target); // auto mxdiff = max_diff(ref, target);
// std::cout << "Max diff: " << mxdiff << std::endl; // std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(ref, target, float_equal); // auto idx = mismatch_idx(ref, target, float_equal);
// if(idx < range_distance(ref)) // if(idx < verify::range_distance(ref))
// { // {
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] // std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// << std::endl; // << std::endl;
// } // }
auto ref_nan_idx = find_idx(ref, not_finite); auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0) if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
<< ref[ref_nan_idx] << std::endl; << ref[ref_nan_idx] << std::endl;
auto target_nan_idx = find_idx(target, not_finite); auto target_nan_idx = find_idx(target, verify::not_finite);
if(target_nan_idx >= 0) if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": " std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl; << target[target_nan_idx] << std::endl;
......
...@@ -24,8 +24,6 @@ ...@@ -24,8 +24,6 @@
cmake_policy(SET CMP0057 NEW) cmake_policy(SET CMP0057 NEW)
include(CTest)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
include(ProcessorCount) include(ProcessorCount)
ProcessorCount(N) ProcessorCount(N)
...@@ -114,7 +112,7 @@ function(add_test_executable TEST_NAME) ...@@ -114,7 +112,7 @@ function(add_test_executable TEST_NAME)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
file(GLOB TESTS ${CONFIGURE_DEPENDS} *.cpp) file(GLOB TESTS CONFIGURE_DEPENDS *.cpp)
foreach(TEST ${TESTS}) foreach(TEST ${TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -124,7 +122,7 @@ endforeach() ...@@ -124,7 +122,7 @@ endforeach()
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
# gpu tests # gpu tests
file(GLOB GPU_TESTS ${CONFIGURE_DEPENDS} gpu/*.cpp) file(GLOB GPU_TESTS CONFIGURE_DEPENDS gpu/*.cpp)
foreach(TEST ${GPU_TESTS}) foreach(TEST ${GPU_TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -134,13 +132,16 @@ if(MIGRAPHX_ENABLE_GPU) ...@@ -134,13 +132,16 @@ if(MIGRAPHX_ENABLE_GPU)
COST 10 COST 10
RESOURCE_LOCK gpu RESOURCE_LOCK gpu
) )
if(MIGRAPHX_USE_HIPRTC)
target_compile_definitions(test_gpu_${BASE_NAME} PUBLIC -DMIGRAPHX_USE_HIPRTC)
endif()
target_link_libraries(test_gpu_${BASE_NAME} migraphx_gpu migraphx_kernels) target_link_libraries(test_gpu_${BASE_NAME} migraphx_gpu migraphx_kernels)
endforeach() endforeach()
endif() endif()
if(MIGRAPHX_ENABLE_FPGA) if(MIGRAPHX_ENABLE_FPGA)
# fpga tests # fpga tests
file(GLOB FPGA_TESTS ${CONFIGURE_DEPENDS} fpga/*.cpp) file(GLOB FPGA_TESTS CONFIGURE_DEPENDS fpga/*.cpp)
foreach(TEST ${FPGA_TESTS}) foreach(TEST ${FPGA_TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -187,6 +188,25 @@ if(MIGRAPHX_ENABLE_PYTHON) ...@@ -187,6 +188,25 @@ if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py) add_subdirectory(py)
endif() endif()
# multitarget test
if(MIGRAPHX_ENABLE_GPU AND MIGRAPHX_ENABLE_CPU AND MIGRAPHX_ENABLE_FPGA)
set(TEST_MULTI_TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR}/multi_target)
file(GLOB MULTI_TARGET_TESTS CONFIGURE_DEPENDS ${TEST_MULTI_TARGET_DIR}/*.cpp)
foreach(MULTI_TARGET_TEST ${MULTI_TARGET_TESTS})
get_filename_component(BASE_NAME ${MULTI_TARGET_TEST} NAME_WE)
set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${MULTI_TARGET_TEST})
rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx migraphx_onnx migraphx_tf migraphx_all_targets)
target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_MULTI_TARGET_DIR})
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
endforeach()
endif()
function(test_header NAME HEADER) function(test_header NAME HEADER)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
"#include <${HEADER}>\nint main() {}\n" "#include <${HEADER}>\nint main() {}\n"
...@@ -201,14 +221,14 @@ function(test_header NAME HEADER) ...@@ -201,14 +221,14 @@ function(test_header NAME HEADER)
endfunction() endfunction()
function(test_headers PREFIX) function(test_headers PREFIX)
file(GLOB HEADERS ${CONFIGURE_DEPENDS} ${ARGN}) file(GLOB HEADERS CONFIGURE_DEPENDS ${ARGN})
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
get_filename_component(BASE_NAME ${HEADER} NAME_WE) get_filename_component(BASE_NAME ${HEADER} NAME_WE)
test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp) test_header(header_${TEST_NAME} ${PREFIX}/${BASE_NAME}.hpp)
target_link_libraries(header_${TEST_NAME} migraphx_all_targets) target_link_libraries(header_${TEST_NAME} migraphx migraphx_onnx migraphx_tf migraphx_all_targets)
endforeach() endforeach()
endfunction() endfunction()
...@@ -225,3 +245,4 @@ if(MIGRAPHX_ENABLE_FPGA) ...@@ -225,3 +245,4 @@ if(MIGRAPHX_ENABLE_FPGA)
test_headers(migraphx/fpga ${CMAKE_SOURCE_DIR}/src/targets/fpga/include/migraphx/fpga/*.hpp) test_headers(migraphx/fpga ${CMAKE_SOURCE_DIR}/src/targets/fpga/include/migraphx/fpga/*.hpp)
endif() endif()
...@@ -36,7 +36,7 @@ endfunction() ...@@ -36,7 +36,7 @@ endfunction()
function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR) function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME}) set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
target_link_libraries(${NAME} migraphx_c migraphx) target_link_libraries(${NAME} migraphx_c)
target_include_directories(${NAME} PUBLIC ../include) target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
...@@ -48,6 +48,7 @@ add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) ...@@ -48,6 +48,7 @@ add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR}) add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(dynamic_shape test_dynamic_shape.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
...@@ -30,7 +30,7 @@ void expect_equal(const char* x, const char* y) ...@@ -30,7 +30,7 @@ void expect_equal(const char* x, const char* y)
abort(); abort();
} }
int main() int main(void)
{ {
char name[1024]; char name[1024];
migraphx_operation_t op; migraphx_operation_t op;
......
...@@ -99,7 +99,7 @@ TEST_CASE(run_sigmoid_custom_op) ...@@ -99,7 +99,7 @@ TEST_CASE(run_sigmoid_custom_op)
EXPECT(bool{result == migraphx::argument(s, expected_result.data())}); EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
} }
extern "C" void migraphx_test_private_disable_exception_catch(bool b); extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool);
TEST_CASE(run_sigmoid_with_incorrect_shape) TEST_CASE(run_sigmoid_with_incorrect_shape)
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(create_dynamic_dimensions)
{
migraphx::dynamic_dimension dd0{1, 4};
EXPECT(not dd0.is_fixed());
migraphx::dynamic_dimension dd1{4, 4};
EXPECT(dd1.is_fixed());
migraphx::optimals opts{1, 2, 4};
migraphx::dynamic_dimension dd2{1, 4, opts};
migraphx::dynamic_dimensions dyn_dims0{dd0, dd1, dd2};
CHECK(bool{dyn_dims0[0] == dd0});
CHECK(bool{dyn_dims0[1] == dd1});
CHECK(bool{dyn_dims0[2] == dd2});
CHECK(bool{dyn_dims0[2] != dd0});
EXPECT(dyn_dims0.size() == 3);
}
TEST_CASE(create_dynamic_shape)
{
migraphx::dynamic_dimensions dyn_dims(migraphx::dynamic_dimension{1, 4},
migraphx::dynamic_dimension{78, 92},
migraphx::dynamic_dimension{1, 4, {1, 4}});
migraphx::shape dyn_shape{migraphx_shape_float_type, dyn_dims};
CHECK(bool{dyn_shape.dynamic()});
CHECK(bool{dyn_shape.dyn_dims()[0] == migraphx::dynamic_dimension{1, 4}});
migraphx::shape static_shape{migraphx_shape_float_type, {3, 8}};
EXPECT(not static_shape.dynamic());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -35,7 +34,6 @@ TEST_CASE(load_and_run) ...@@ -35,7 +34,6 @@ TEST_CASE(load_and_run)
auto shapes_before = p.get_output_shapes(); auto shapes_before = p.get_output_shapes();
migraphx::compile_options options; migraphx::compile_options options;
options.set_offload_copy(); options.set_offload_copy();
options.set_exhaustive_tune_flag();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes(); auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == 1);
...@@ -72,6 +70,105 @@ hip_ptr get_hip_buffer(size_t size) ...@@ -72,6 +70,105 @@ hip_ptr get_hip_buffer(size_t size)
return hip_ptr{ptr}; return hip_ptr{ptr};
} }
// TODO: placeholder until we have a way to copy tuple arguments to/from device through c++ api
// TEST_CASE(dynamic_batch_load_and_run)
//{
// migraphx::onnx_options o_options;
// migraphx::dynamic_dimensions dyn_dims = {{1, 4, {2, 4}}, {3, 3}, {4, 4}, {4, 4}};
// o_options.set_dyn_input_parameter_shape("0", dyn_dims);
// dyn_dims = {{2, 2}, {3, 3}, {3, 3}, {3, 3}};
// o_options.set_dyn_input_parameter_shape("1", dyn_dims);
// auto p = migraphx::parse_onnx("conv_dynamic_batch_test.onnx", o_options);
// migraphx::compile_options c_options;
// c_options.set_split_single_dyn_dim();
// p.compile(migraphx::target("gpu"), c_options);
// auto out_shapes = p.get_output_shapes();
// CHECK(out_shapes.size() == 1);
// EXPECT(out_shapes[0].dynamic());
//
// std::vector<float> a(0.12, 2*3*4*4);
// std::vector<float> c(0.75, 2*3*3*3);
//
// auto param_shapes = p.get_parameter_shapes();
// int batch_size = 2;
// std::unordered_map<std::string, migraphx::argument> arg_map;
//
// arg_map["0"] = migraphx::argument(param_shapes["0"].to_static(batch_size), a.data());
// arg_map["1"] = migraphx::argument(param_shapes["1"].to_static(batch_size), c.data());
//
// migraphx::program_parameters pp;
// std::vector<hip_ptr> buffs;
// std::vector<migraphx::argument> args;
//
// // copy to GPU and create parameter map
// for(auto&& name : param_shapes.names())
// {
// if(arg_map.find(name) != arg_map.end())
// {
// args.push_back(arg_map.at(name));
// }
// else
// {
// migraphx::shape static_shape = param_shapes[name].to_static(batch_size);
// auto output_arg = migraphx::argument(static_shape);
// args.push_back(output_arg);
// }
// buffs.push_back(get_hip_buffer(args.rbegin()->get_shape().bytes()));
// auto err = hipMemcpy(buffs.rbegin()->get(),
// args.rbegin()->data(),
// args.rbegin()->get_shape().bytes(),
// hipMemcpyHostToDevice);
// EXPECT(err == hipSuccess);
// pp.add(name, migraphx::argument(args.rbegin()->get_shape(), buffs.rbegin()->get()));
// }
//
// auto output = p.eval(pp)[0];
//
// // copy output back to host
// auto host_arg = migraphx::argument(output.get_shape());
// auto err = hipMemcpy(
// host_arg.data(), output.data(), output.get_shape().bytes(), hipMemcpyDeviceToHost);
// EXPECT(err == hipSuccess);
//}
TEST_CASE(dynamic_batch_load_and_run_offload)
{
migraphx::onnx_options o_options;
migraphx::dynamic_dimensions dyn_dims = {migraphx::dynamic_dimension{1, 4, {2, 4}},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{4, 4},
migraphx::dynamic_dimension{4, 4}};
o_options.set_dyn_input_parameter_shape("0", dyn_dims);
dyn_dims = {migraphx::dynamic_dimension{2, 2},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{3, 3}};
o_options.set_dyn_input_parameter_shape("1", dyn_dims);
auto p = migraphx::parse_onnx("conv_dynamic_batch_test.onnx", o_options);
auto shapes_before = p.get_output_shapes();
migraphx::compile_options c_options;
c_options.set_offload_copy();
p.compile(migraphx::target("gpu"), c_options);
auto out_shapes = p.get_output_shapes();
EXPECT(out_shapes.size() == 1);
EXPECT(out_shapes[0].dynamic());
// batch size = 2
std::vector<float> a(2 * 3 * 4 * 4, 0.12);
std::vector<float> c(2 * 3 * 3 * 3, 0.75);
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
pp.add("0",
migraphx::argument(migraphx::shape(migraphx_shape_float_type, {2, 3, 4, 4}), a.data()));
pp.add("1",
migraphx::argument(migraphx::shape(migraphx_shape_float_type, {2, 3, 3, 3}), c.data()));
auto outputs = p.eval(pp);
EXPECT(shapes_before.size() == outputs.size());
EXPECT(bool{outputs.front().get_shape() ==
migraphx::shape(migraphx_shape_float_type, {2, 2, 2, 2})});
}
TEST_CASE(load_and_run_async) TEST_CASE(load_and_run_async)
{ {
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
......
...@@ -193,6 +193,15 @@ TEST_CASE(value_argument) ...@@ -193,6 +193,15 @@ TEST_CASE(value_argument)
EXPECT(a4 == a2); EXPECT(a4 == a2);
} }
TEST_CASE(value_empty_argument)
{
migraphx::argument a5;
EXPECT(a5.empty());
auto v3 = migraphx::to_value(a5);
auto a6 = migraphx::from_value<migraphx::argument>(v3);
EXPECT(a6 == a5);
}
TEST_CASE(value_tuple) TEST_CASE(value_tuple)
{ {
auto a1 = make_tuple(3, 3.0, make_tuple(3, 4)); auto a1 = make_tuple(3, 3.0, make_tuple(3, 4));
......
...@@ -41,7 +41,7 @@ TEST_CASE(simple_test) ...@@ -41,7 +41,7 @@ TEST_CASE(simple_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
...@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop) ...@@ -57,7 +57,7 @@ TEST_CASE(simple_test_nop)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
...@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2) ...@@ -73,7 +73,7 @@ TEST_CASE(simple_test_nop2)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
...@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1) ...@@ -88,8 +88,8 @@ TEST_CASE(duplicate_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
...@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2) ...@@ -104,9 +104,9 @@ TEST_CASE(duplicate_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(minus_op{}, one, two); mm->add_instruction(migraphx::make_op("sub"), one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
...@@ -121,11 +121,11 @@ TEST_CASE(depth_test) ...@@ -121,11 +121,11 @@ TEST_CASE(depth_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto x1 = mm->add_instruction(sum_op{}, one, two); auto x1 = mm->add_instruction(migraphx::make_op("add"), one, two);
auto x2 = mm->add_instruction(sum_op{}, one, two); auto x2 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(migraphx::make_op("sub"), x1, x2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4));
...@@ -141,7 +141,7 @@ TEST_CASE(undefined_test) ...@@ -141,7 +141,7 @@ TEST_CASE(undefined_test)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1); EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice) ...@@ -232,7 +232,6 @@ TEST_CASE(reused_twice)
auto count = std::distance(mm->begin(), mm->end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4); EXPECT(std::distance(mm->begin(), mm->end()) == 4);
} }
...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated) ...@@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated)
EXPECT(p == create_program()); EXPECT(p == create_program());
} }
TEST_CASE(tuple_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(tuple_op{}, one, two);
mm->add_return({one, two});
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -45,7 +45,7 @@ TEST_CASE(simple_test) ...@@ -45,7 +45,7 @@ TEST_CASE(simple_test)
auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one); auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two); auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity); mm->add_instruction(migraphx::make_op("add"), one_identity, two_identity);
run_pass(p); run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end) ...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
...@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency) ...@@ -81,8 +81,8 @@ TEST_CASE(simple_test_end_dependency)
auto one = mm->add_literal(1.0); auto one = mm->add_literal(1.0);
auto two = mm->add_literal(2.0); auto two = mm->add_literal(2.0);
auto three = mm->add_literal(3.0); auto three = mm->add_literal(3.0);
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, ans, three); mm->add_instruction(migraphx::make_op("add"), ans, three);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/make_op.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -49,7 +50,7 @@ struct id_target ...@@ -49,7 +50,7 @@ struct id_target
struct id_ctx_op struct id_ctx_op
{ {
std::string name() const { return "id_ctx_op"; } std::string name() const { return ""; }
migraphx::argument migraphx::argument
compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
...@@ -156,7 +157,7 @@ TEST_CASE(literal_test1) ...@@ -156,7 +157,7 @@ TEST_CASE(literal_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -168,8 +169,8 @@ TEST_CASE(literal_test2) ...@@ -168,8 +169,8 @@ TEST_CASE(literal_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, sum1, two); mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{5}); EXPECT(result == migraphx::literal{5});
...@@ -182,7 +183,7 @@ TEST_CASE(print_test) ...@@ -182,7 +183,7 @@ TEST_CASE(print_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, x, two); mm->add_instruction(migraphx::make_op("add"), x, two);
std::stringstream ss; std::stringstream ss;
ss << p; ss << p;
...@@ -197,7 +198,7 @@ TEST_CASE(param_test) ...@@ -197,7 +198,7 @@ TEST_CASE(param_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type}); auto y = mm->add_parameter("y", {migraphx::shape::int32_type});
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
auto result = p.eval({{"x", migraphx::literal{1}.get_argument()}, auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}}) {"y", migraphx::literal{2}.get_argument()}})
.back(); .back();
...@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test) ...@@ -227,7 +228,7 @@ TEST_CASE(param_error_shape_test)
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}}); auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(test::throws<migraphx::exception>( EXPECT(test::throws<migraphx::exception>(
[&] { [&] {
p.eval({ p.eval({
...@@ -245,7 +246,7 @@ TEST_CASE(get_param1) ...@@ -245,7 +246,7 @@ TEST_CASE(get_param1)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
EXPECT(bool{p.get_parameter("x") == x}); EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y}); EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
...@@ -257,7 +258,7 @@ TEST_CASE(get_param2) ...@@ -257,7 +258,7 @@ TEST_CASE(get_param2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
} }
...@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes) ...@@ -268,7 +269,7 @@ TEST_CASE(get_param_shapes)
migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(migraphx::make_op("add"), x, y);
auto m = p.get_parameter_shapes(); auto m = p.get_parameter_shapes();
EXPECT(m.count("nonexistent") == 0); EXPECT(m.count("nonexistent") == 0);
EXPECT(m.at("x") == s); EXPECT(m.at("x") == s);
...@@ -281,8 +282,8 @@ TEST_CASE(replace_test) ...@@ -281,8 +282,8 @@ TEST_CASE(replace_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->replace_instruction(sum, minus_op{}, two, one); mm->replace_instruction(sum, migraphx::make_op("sub"), two, one);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test) ...@@ -296,8 +297,8 @@ TEST_CASE(replace_ins_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->replace_instruction(sum, minus); mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2) ...@@ -312,8 +313,8 @@ TEST_CASE(replace_ins_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(pass_op{}, minus); mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum); mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test) ...@@ -329,8 +330,8 @@ TEST_CASE(replace_op_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one); auto sum = mm->add_instruction(migraphx::make_op("add"), two, one);
sum->replace(minus_op{}); sum->replace(migraphx::make_op("sub"));
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw) ...@@ -344,7 +345,7 @@ TEST_CASE(replace_op_recompute_shape_throw)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); })); EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
} }
...@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test) ...@@ -354,11 +355,11 @@ TEST_CASE(insert_replace_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two); auto sum1 = mm->add_instruction(migraphx::make_op("add"), one, two);
mm->add_instruction(sum_op{}, sum1, two); mm->add_instruction(migraphx::make_op("add"), sum1, two);
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two); auto sum0 = mm->insert_instruction(sum1, migraphx::make_op("add"), two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two); mm->replace_instruction(sum1, migraphx::make_op("sub"), sum0, two);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -372,8 +373,8 @@ TEST_CASE(remove_test1) ...@@ -372,8 +373,8 @@ TEST_CASE(remove_test1)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(migraphx::make_op("add"), one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one); auto removed = mm->add_instruction(migraphx::make_op("sub"), sum, one);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -388,8 +389,8 @@ TEST_CASE(remove_test2) ...@@ -388,8 +389,8 @@ TEST_CASE(remove_test2)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto removed = mm->add_instruction(minus_op{}, two, one); auto removed = mm->add_instruction(migraphx::make_op("sub"), two, one);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == mm->end()}); EXPECT(bool{p.validate() == mm->end()});
...@@ -404,7 +405,7 @@ TEST_CASE(target_test) ...@@ -404,7 +405,7 @@ TEST_CASE(target_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(migraphx::make_op("add"), one, two);
p.compile(id_target{}); p.compile(id_target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
...@@ -460,7 +461,7 @@ TEST_CASE(eval_context1) ...@@ -460,7 +461,7 @@ TEST_CASE(eval_context1)
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back(); std::ignore = p.eval({}).back();
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
} }
...@@ -475,7 +476,7 @@ TEST_CASE(eval_context2) ...@@ -475,7 +476,7 @@ TEST_CASE(eval_context2)
mm->add_instruction(id_ctx_op{}, one, two); mm->add_instruction(id_ctx_op{}, one, two);
p.compile(t); p.compile(t);
EXPECT(is_shared(t.ctx, p.get_context())); EXPECT(is_shared(t.ctx, p.get_context()));
p.eval({}).back(); std::ignore = p.eval({}).back();
// id_ctx_op will modify the context // id_ctx_op will modify the context
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
} }
...@@ -492,8 +493,8 @@ TEST_CASE(eval_context3) ...@@ -492,8 +493,8 @@ TEST_CASE(eval_context3)
p.compile(t); p.compile(t);
// Finalizer will modify the context // Finalizer will modify the context
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
auto ctx = p.get_context(); auto ctx = p.get_context();
p.eval({}).back(); std::ignore = p.eval({}).back();
EXPECT(is_shared(ctx, p.get_context())); EXPECT(is_shared(ctx, p.get_context()));
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
} }
......
...@@ -329,4 +329,36 @@ TEST_CASE(all_scalar_input) ...@@ -329,4 +329,36 @@ TEST_CASE(all_scalar_input)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(no_input)
{
migraphx::program p;
{
auto* mm = p.get_main_module();
migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
std::vector<int> indices{3, 800, 800};
auto a0 = mm->add_literal(migraphx::literal{s_indices, indices});
auto a1 = mm->add_literal(migraphx::literal{g_shape, {1}});
int axis = 0;
auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
mm->add_return({out});
}
run_pass(p);
// This should NOT create a pointwise module if there are no inputs here.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
std::vector<int> indices{3, 800, 800};
auto a0 = mm->add_literal(migraphx::literal{s_indices, indices});
auto a1 = mm->add_literal(migraphx::literal{g_shape, {1}});
int axis = 0;
auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
mm->add_return({out});
}
EXPECT(p == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
void run_prog(migraphx::program p,
const migraphx::target& t,
migraphx::parameter_map& m_in,
std::vector<float>& res)
{
p.compile(t);
migraphx::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(m_in.count(x.first) > 0)
{
m[x.first] = t.copy_to(m_in[x.first]);
}
else
{
m[x.first] = t.allocate(x.second);
}
}
auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
}
// This test ensures that the codegen path doesn't round up literals,
// otherwise there are accuracy differences compared to ref.
// The values being passed in are 0.5 * (1/0.00787402),
// and after rounding must equal 63, not 64.
TEST_CASE(mul_literal_round_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1}};
auto l0 = mm->add_parameter("a", s0);
auto l1 = mm->add_literal(1 / 0.00787402f);
auto mul = mm->add_instruction(migraphx::make_op("mul"), l0, l1);
auto round = mm->add_instruction(migraphx::make_op("round"), mul);
mm->add_return({round});
migraphx::parameter_map m;
std::vector<float> a = {0.5f};
m["a"] = migraphx::argument{s0, a.data()};
std::vector<float> ref_result;
migraphx::target ref_t = migraphx::make_target("ref");
run_prog(p, ref_t, m, ref_result);
std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
}
template <class F>
migraphx::instruction_ref add_mlir(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
std::vector<std::string> arg_names,
F f)
{
assert(inputs.size() == arg_names.size() && "One interior parameter name given per input.");
auto* mm = p.get_main_module();
auto* pm = p.create_module(name);
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
for(size_t i = 0, e = inputs.size(); i < e; ++i)
{
params.push_back(pm->add_parameter(arg_names[i], inputs[i]->get_shape()));
}
auto values = f(pm, params);
auto root = std::get<0>(values);
auto r = std::get<1>(values);
pm->add_return({r});
return mm->add_instruction(
migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root->get_operator())}}),
inputs,
{pm});
}
TEST_CASE(dot_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto fused =
add_mlir(p2,
"mlir_main:pointwise0",
{x, a, b},
{"x1", "y0", "y1"},
[=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]);
auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]);
return std::make_tuple(dot, add);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_abs)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto abs = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("abs"));
mm->add_return({abs});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto fused = add_mlir(
p2, "mlir_main:pointwise0", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("quant_dot"), inputs[0], inputs[1]);
auto abs = pm->add_instruction(migraphx::make_op("abs"), dot);
return std::make_tuple(dot, abs);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_tanh_fails)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh});
}
migraphx::program p2(p1);
// This pass should do nothing as int32_t tanh isn't supported.
run_pass(p1);
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[])
{
if(migraphx::gpu::mlir_enabled())
test::run(argc, argv);
return 0;
}
...@@ -206,8 +206,16 @@ TEST_CASE(compile_warnings) ...@@ -206,8 +206,16 @@ TEST_CASE(compile_warnings)
EXPECT(not compile("").empty()); EXPECT(not compile("").empty());
EXPECT(not compile("-Wunused-parameter -Wno-error").empty()); EXPECT(not compile("-Wunused-parameter -Wno-error").empty());
EXPECT(not compile("-Wno-unused-parameter -Werror").empty()); EXPECT(not compile("-Wno-unused-parameter -Werror").empty());
#ifdef MIGRAPHX_USE_HIPRTC
if(not migraphx::enabled(migraphx::gpu::MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
{
EXPECT(test::throws([&] { compile("-Werror=unused-parameter"); }));
EXPECT(test::throws([&] { compile("-Wunused-parameter -Werror"); }));
}
#else
EXPECT(test::throws([&] { compile("-Werror=unused-parameter"); })); EXPECT(test::throws([&] { compile("-Werror=unused-parameter"); }));
EXPECT(test::throws([&] { compile("-Wunused-parameter -Werror"); })); EXPECT(test::throws([&] { compile("-Wunused-parameter -Werror"); }));
#endif
} }
TEST_CASE(code_object_hip) TEST_CASE(code_object_hip)
...@@ -356,4 +364,69 @@ TEST_CASE(compile_math) ...@@ -356,4 +364,69 @@ TEST_CASE(compile_math)
}); });
} }
// NOLINTNEXTLINE
const std::string assert_template = R"__migraphx__(
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/types.hpp>
using namespace migraphx;
extern "C" {
__global__ void kernel(void*)
{
static_assert(numeric_max<${type}>() == ${max}, "");
static_assert(numeric_lowest<${type}>() == ${min}, "");
}
}
int main() {}
)__migraphx__";
TEST_CASE(assert_type_min_max)
{
std::vector<std::string> data_types;
migraphx::gpu::hip_compile_options options;
for(auto&& t : migraphx::shape::types())
{
if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t))
continue;
auto name = migraphx::shape::cpp_type(t);
if(t == migraphx::shape::half_type)
name.insert(0, "migraphx::");
migraphx::shape::visit(t, [&](auto as) {
std::string min = "";
std::string max = "";
// Note 9223372036854775808 is a constant literal that is outside the range of long
// long type For the same reason, 18446744073709551616 needs postfix ULL to be parsed
// correctly
if(t == migraphx::shape::int64_type)
{
min = "(" + std::to_string(as.min() + 1) + "LL - 1)";
max = std::to_string(as.max());
}
else if(t == migraphx::shape::uint64_type)
{
min = std::to_string(as.min());
max = std::to_string(as.max()) + "ULL";
}
else
{
min = std::to_string(as.min());
max = std::to_string(as.max());
}
auto src = migraphx::interpolate_string(assert_template,
{{"type", name}, {"max", max}, {"min", min}});
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
options.global = 1024;
options.local = 1024;
options.inputs = {input};
options.output = input;
options.params = "-Wno-float-equal";
auto co = migraphx::gpu::compile_hip_code_object(src, options);
});
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -64,7 +64,7 @@ TEST_CASE(host_same_buffer_copy) ...@@ -64,7 +64,7 @@ TEST_CASE(host_same_buffer_copy)
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> results_vector(ss.elements(), -1); std::vector<float> results_vector(ss.elements(), -1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c_vec, results_vector)); EXPECT(migraphx::verify::verify_range(c_vec, results_vector));
} }
TEST_CASE(arguments_lifetime) TEST_CASE(arguments_lifetime)
......
...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir) ...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front())); inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::context ctx; migraphx::gpu::context ctx;
migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs), inputs); migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs, {}), inputs);
return p; return p;
} }
...@@ -187,12 +187,39 @@ module { ...@@ -187,12 +187,39 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
TEST_CASE(quant_dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
return %1 : tensor<1x5x3xi32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::int8_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::int8_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::int32_type, {1, 5, 3}});
auto conv = m.add_instruction(migraphx::make_op("quant_dot"), arg0, arg1);
auto add = m.add_instruction(migraphx::make_op("add"), conv, arg2);
m.add_return({add});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_add) TEST_CASE(dot_add)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} { func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32> %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32> return %1 : tensor<1x5x3xf32>
} }
...@@ -246,4 +273,57 @@ module { ...@@ -246,4 +273,57 @@ module {
EXPECT(verify_mlir(m)); EXPECT(verify_mlir(m));
} }
TEST_CASE(dot_convert)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
return %1 : tensor<1x5x3xf16>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto trunc = m.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), dot);
m.add_return({trunc});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_where)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::bool_type, {1, 5, 3}});
auto arg3 = m.add_parameter("arg3", {migraphx::shape::float_type, {1, 5, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto where = m.add_instruction(migraphx::make_op("where"), arg2, dot, arg3);
m.add_return({where});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
...@@ -51,7 +52,7 @@ TEST_CASE(gpu_target_copy) ...@@ -51,7 +52,7 @@ TEST_CASE(gpu_target_copy)
std::vector<int8_t> val_final; std::vector<int8_t> val_final;
ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); }); ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify_range(val_orig, val_final)); EXPECT(migraphx::verify::verify_range(val_orig, val_final));
} }
TEST_CASE(int8_quantization) TEST_CASE(int8_quantization)
...@@ -110,7 +111,16 @@ TEST_CASE(int8_quantization) ...@@ -110,7 +111,16 @@ TEST_CASE(int8_quantization)
migraphx::target gpu_t = migraphx::make_target("gpu"); migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result); run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(ref_result, gpu_result)); // Note: the tolerance for mlir_enabled result is temporarily bumped
// higher because the lowering pipeline between mlir fallback and
// regular non-mlir pipeline diverged. MLIR fallback uses the
// rewrite_quantization at the very end of the pipeline, whereas
// the regular pipeline uses the rewrite_quantization in the much
// earlier stage.
if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5));
else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment