Commit 5fbda037 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents bc8eb7c9 409fd18c
...@@ -59,6 +59,12 @@ else() ...@@ -59,6 +59,12 @@ else()
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
endif() endif()
if(WIN32) # CK is not yet ported to Windows
option(MIGRAPHX_USE_COMPOSABLEKERNEL "Enable MIGraphX to use composable kernel JIT library" OFF)
else()
option(MIGRAPHX_USE_COMPOSABLEKERNEL "Enable MIGraphX to use composable kernel JIT library" ON)
endif()
# By default build shared libraries # By default build shared libraries
option(BUILD_SHARED_LIBS "Create shared libraries" ON) option(BUILD_SHARED_LIBS "Create shared libraries" ON)
...@@ -84,7 +90,7 @@ include(ROCMSetupVersion) ...@@ -84,7 +90,7 @@ include(ROCMSetupVersion)
option(BUILD_DEV "Build for development purpose only" OFF) option(BUILD_DEV "Build for development purpose only" OFF)
rocm_setup_version(VERSION 2.8.0) rocm_setup_version(VERSION 2.9.0)
math(EXPR MIGRAPHX_SO_MAJOR_VERSION "(${PROJECT_VERSION_MAJOR} * 1000 * 1000) + (${PROJECT_VERSION_MINOR} * 1000) + ${PROJECT_VERSION_PATCH}") math(EXPR MIGRAPHX_SO_MAJOR_VERSION "(${PROJECT_VERSION_MAJOR} * 1000 * 1000) + (${PROJECT_VERSION_MINOR} * 1000) + ${PROJECT_VERSION_PATCH}")
set(MIGRAPHX_SO_VERSION ${MIGRAPHX_SO_MAJOR_VERSION}.0) set(MIGRAPHX_SO_VERSION ${MIGRAPHX_SO_MAJOR_VERSION}.0)
......
...@@ -32,7 +32,7 @@ def rocmtestnode(Map conf) { ...@@ -32,7 +32,7 @@ def rocmtestnode(Map conf) {
rm -rf build rm -rf build
mkdir build mkdir build
cd build cd build
cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On -DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT ${flags} .. cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On -DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT -DMIGRAPHX_DISABLE_VIRTUAL_ENV=ON ${flags} ..
git diff git diff
git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1) git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1)
make -j\$(nproc) generate VERBOSE=1 make -j\$(nproc) generate VERBOSE=1
......
...@@ -21,12 +21,12 @@ ...@@ -21,12 +21,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0 nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212 live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.6.0 ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@507bb94ce7873786486d296ec81d2eadaab49003 -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@507bb94ce7873786486d296ec81d2eadaab49003 -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file
...@@ -267,10 +267,9 @@ target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json) ...@@ -267,10 +267,9 @@ target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx) migraphx_generate_export_header(migraphx)
if(NOT WIN32) if(NOT WIN32)
find_package(PkgConfig) find_package(SQLite3 REQUIRED)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
endif() endif()
target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3) target_link_libraries(migraphx PRIVATE SQLite::SQLite3)
if(NOT WIN32) if(NOT WIN32)
find_package(msgpackc-cxx QUIET) find_package(msgpackc-cxx QUIET)
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -41,16 +41,16 @@ namespace op { ...@@ -41,16 +41,16 @@ namespace op {
* Dynamic allocate: * Dynamic allocate:
* One input: `allocate(output_dims)` * One input: `allocate(output_dims)`
* `output_dims` are the output buffer dimensions and has a static shape. * `output_dims` are the output buffer dimensions and has a static shape.
* Either `this.s` or `this.buf_type` must be set to calculate the dynamic output shape at compute * Either `this.s` or `this.buf_type` (but not both) must be set to calculate the dynamic output
* time. If `this.buf_type` is set, the compute_shape() of allocate at compile time will have * shape at compute time. If `this.buf_type` is set, the compute_shape() of allocate at compile time
* dynamic_dimensions from {0, max_int} with rank = output_dims.ndim(). If `this.s` is set then the * will have dynamic_dimensions from {0, max_int} with rank = output_dims.ndim(). If `this.s` is set
* compute_shape() will output `this.s`; `this.s` should be a dynamic shape. * then the compute_shape() will output `this.s`; `this.s` should be a dynamic shape.
*/ */
struct allocate struct allocate
{ {
shape s{}; optional<shape> s;
// for dynamic allocate to set the buffer type // for dynamic allocate to set the buffer type
shape::type_t buf_type = shape::half_type; optional<shape::type_t> buf_type;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -62,26 +62,38 @@ struct allocate ...@@ -62,26 +62,38 @@ struct allocate
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
if(s != shape()) if(s.has_value())
{ {
if(buf_type.has_value())
{
MIGRAPHX_THROW("ALLOCATE: shape and buf_type attributes both set");
}
if(inputs.size() == 1) if(inputs.size() == 1)
{ {
migraphx::check_shapes{inputs, *this, false}.only_dims(1); migraphx::check_shapes{inputs, *this, false}.only_dims(1);
} }
else else
{ {
if(s->dynamic())
{
MIGRAPHX_THROW("ALLOCATE: dynamic shape attribute and no input");
}
migraphx::check_shapes{inputs, *this, false}.has(0); migraphx::check_shapes{inputs, *this, false}.has(0);
} }
return s; return s.value();
} }
else else
{ {
if(not buf_type.has_value())
{
MIGRAPHX_THROW("ALLOCATE: shape and buf_type attributes both not set");
}
migraphx::check_shapes{inputs, *this, false}.has(1).only_dims(1); migraphx::check_shapes{inputs, *this, false}.has(1).only_dims(1);
const auto& out_dims = inputs.at(0); const auto& out_dims = inputs.at(0);
std::size_t max_val = std::numeric_limits<std::size_t>::max(); std::size_t max_val = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0), std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0),
shape::dynamic_dimension{0, max_val}); shape::dynamic_dimension{0, max_val});
return {buf_type, dyn_dims}; return {buf_type.value(), dyn_dims};
} }
} }
argument compute(const shape& output_shape, const std::vector<argument>& args) const argument compute(const shape& output_shape, const std::vector<argument>& args) const
...@@ -94,7 +106,11 @@ struct allocate ...@@ -94,7 +106,11 @@ struct allocate
{ {
std::vector<std::size_t> output_dims(output_shape.ndim()); std::vector<std::size_t> output_dims(output_shape.ndim());
args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); }); args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); });
return argument{shape{buf_type, output_dims}}; if(s)
{
return argument{shape{s->type(), output_dims}};
}
return argument{shape{buf_type.value(), output_dims}};
} }
} }
}; };
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <vector> #include <vector>
namespace migraphx { namespace migraphx {
...@@ -68,6 +69,19 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype( ...@@ -68,6 +69,19 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype(
os << x; os << x;
} }
template <class T>
auto stream_write_value_impl(rank<1>, std::ostream& os, const optional<T>& x)
{
if(x.has_value())
{
os << *x;
}
else
{
os << "nullopt";
}
}
template <class T> template <class T>
void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r) void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r)
{ {
......
...@@ -936,7 +936,7 @@ void program::perf_report(std::ostream& os, ...@@ -936,7 +936,7 @@ void program::perf_report(std::ostream& os,
os << std::endl; os << std::endl;
os << "Batch size: " << batch << std::endl; os << "Batch size: " << batch << std::endl;
os << "Rate: " << rate * batch << "inferences/sec" << std::endl; os << "Rate: " << rate * batch << " inferences/sec" << std::endl;
os << "Total time: " << total_time << "ms" << std::endl; os << "Total time: " << total_time << "ms" << std::endl;
os << "Total instructions time: " << total_instruction_time << "ms" << std::endl; os << "Total instructions time: " << total_instruction_time << "ms" << std::endl;
os << "Overhead time: " << overhead_time << "ms" os << "Overhead time: " << overhead_time << "ms"
......
...@@ -34,8 +34,7 @@ message(STATUS "MIGraphX is using MIOpen") ...@@ -34,8 +34,7 @@ message(STATUS "MIGraphX is using MIOpen")
find_package(rocblas REQUIRED) find_package(rocblas REQUIRED)
message(STATUS "MIGraphX build with rocBLAS") message(STATUS "MIGraphX build with rocBLAS")
if(NOT WIN32) if(MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library) find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif() endif()
...@@ -48,10 +47,10 @@ endif() ...@@ -48,10 +47,10 @@ endif()
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
if(WIN32) if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM KERNEL_FILES list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif() endif()
...@@ -103,8 +102,7 @@ rocm_clang_tidy_check(kernel_file_check) ...@@ -103,8 +102,7 @@ rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp) file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(WIN32) if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS list(REMOVE_ITEM JIT_GPU_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
...@@ -306,8 +304,7 @@ endif() ...@@ -306,8 +304,7 @@ endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device) target_link_libraries(migraphx_gpu PRIVATE migraphx_device)
if(NOT WIN32) if(MIGRAPHX_USE_COMPOSABLEKERNEL)
# TODO: re-enable when CK is ported to Windows
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
endif() endif()
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
cmake_policy(SET CMP0057 NEW) cmake_policy(SET CMP0057 NEW)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
rocm_test_link_libraries(Threads::Threads migraphx migraphx_ref migraphx_onnx migraphx_tf) rocm_test_link_libraries(Threads::Threads migraphx migraphx_onnx migraphx_tf)
rocm_test_include_directories(include) rocm_test_include_directories(include)
set(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS Off CACHE BOOL "") set(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS Off CACHE BOOL "")
......
...@@ -152,6 +152,9 @@ TEST_CASE(int_quant_dot_tanh_fails) ...@@ -152,6 +152,9 @@ TEST_CASE(int_quant_dot_tanh_fails)
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
{ {
if(migraphx::gpu::mlir_enabled())
{
test::run(argc, argv); test::run(argc, argv);
}
return 0; return 0;
} }
...@@ -88,7 +88,7 @@ TEST_CASE(allocate_static) ...@@ -88,7 +88,7 @@ TEST_CASE(allocate_static)
expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}})); expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}}));
} }
TEST_CASE(allocate_static_input_error) TEST_CASE(allocate_static_input)
{ {
migraphx::shape input{migraphx::shape::int64_type, {3}}; migraphx::shape input{migraphx::shape::int64_type, {3}};
migraphx::shape out_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape out_shape{migraphx::shape::float_type, {2, 3, 4}};
...@@ -120,8 +120,22 @@ TEST_CASE(allocate_dyn_no_input_error) ...@@ -120,8 +120,22 @@ TEST_CASE(allocate_dyn_no_input_error)
{ {
migraphx::shape shape_attr{migraphx::shape::float_type, migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}}; {{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
expect_shape(shape_attr, throws_shape(migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}));
migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}})); }
TEST_CASE(allocate_shape_and_buf_type_error)
{
migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
throws_shape(migraphx::make_op(
"allocate",
{{"shape", migraphx::to_value(shape_attr)}, {"buf_type", migraphx::shape::half_type}}));
}
TEST_CASE(allocate_no_attr_error)
{
migraphx::shape input{migraphx::shape::int64_type, {4}};
throws_shape(migraphx::make_op("allocate"), input);
} }
TEST_CASE(argmax_axis0) TEST_CASE(argmax_axis0)
......
...@@ -28,6 +28,7 @@ set(VENV_ONNX ${CMAKE_BINARY_DIR}/test/py/venv-onnx) ...@@ -28,6 +28,7 @@ set(VENV_ONNX ${CMAKE_BINARY_DIR}/test/py/venv-onnx)
set(REQUIREMENTS ${CMAKE_CURRENT_SOURCE_DIR}/requirements.txt) set(REQUIREMENTS ${CMAKE_CURRENT_SOURCE_DIR}/requirements.txt)
set(REQUIREMENTS_ONNX ${CMAKE_CURRENT_SOURCE_DIR}/requirements-onnx.txt) set(REQUIREMENTS_ONNX ${CMAKE_CURRENT_SOURCE_DIR}/requirements-onnx.txt)
set(PYTHON_VERSION_TO_DISABLE_ONNX 3.6) set(PYTHON_VERSION_TO_DISABLE_ONNX 3.6)
option(MIGRAPHX_DISABLE_VIRTUAL_ENV "Disable python virtual environments" OFF)
function(add_py_venv_fixture FIXTURE_NAME VIRTUAL_ENV_DIR REQUIREMENTS_FILE) function(add_py_venv_fixture FIXTURE_NAME VIRTUAL_ENV_DIR REQUIREMENTS_FILE)
...@@ -61,23 +62,31 @@ function(add_py_test NAME SCRIPT FIXTURE_NAME VENV_DIR) ...@@ -61,23 +62,31 @@ function(add_py_test NAME SCRIPT FIXTURE_NAME VENV_DIR)
"PYTHONMALLOC=debug" "PYTHONMALLOC=debug"
"MALLOC_CHECK_=3" "MALLOC_CHECK_=3"
) )
if(MIGRAPHX_DISABLE_VIRTUAL_ENV)
set(PYTHON_EXECUTABLE ${PYTHON_${PYTHON_VERSION}_EXECUTABLE})
else()
set(PYTHON_EXECUTABLE ${VENV_DIR}/${PYTHON_VERSION}/bin/python) set(PYTHON_EXECUTABLE ${VENV_DIR}/${PYTHON_VERSION}/bin/python)
endif()
if(NOT (${FIXTURE_NAME} STREQUAL "onnx" AND ${PYTHON_VERSION} STREQUAL ${PYTHON_VERSION_TO_DISABLE_ONNX})) if(NOT (${FIXTURE_NAME} STREQUAL "onnx" AND ${PYTHON_VERSION} STREQUAL ${PYTHON_VERSION_TO_DISABLE_ONNX}))
add_test( add_test(
NAME test_py_${PYTHON_VERSION}_${NAME} NAME test_py_${PYTHON_VERSION}_${NAME}
COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN}) COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN})
set_tests_properties(test_py_${PYTHON_VERSION}_${NAME} PROPERTIES FIXTURES_REQUIRED ${FIXTURE_NAME}_${PYTHON_VERSION}_VENV)
add_custom_target(test_py_${PYTHON_VERSION}_${NAME} add_custom_target(test_py_${PYTHON_VERSION}_${NAME}
COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN} COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN}
COMMENT "${PYTHON_EXECUTABLE} ${SCRIPT}") COMMENT "${PYTHON_EXECUTABLE} ${SCRIPT}")
if(NOT MIGRAPHX_DISABLE_VIRTUAL_ENV)
set_tests_properties(test_py_${PYTHON_VERSION}_${NAME} PROPERTIES FIXTURES_REQUIRED ${FIXTURE_NAME}_${PYTHON_VERSION}_VENV)
endif()
endif() endif()
endforeach() endforeach()
endfunction() endfunction()
add_dependencies(tests migraphx_py) add_dependencies(tests migraphx_py)
add_dependencies(check migraphx_py) add_dependencies(check migraphx_py)
add_py_venv_fixture(common ${VENV} ${REQUIREMENTS}) if(NOT MIGRAPHX_DISABLE_VIRTUAL_ENV)
add_py_venv_fixture(onnx ${VENV_ONNX} ${REQUIREMENTS_ONNX}) add_py_venv_fixture(common ${VENV} ${REQUIREMENTS})
add_py_venv_fixture(onnx ${VENV_ONNX} ${REQUIREMENTS_ONNX})
endif()
add_py_test(ref test_cpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(ref test_cpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(save_load test_save_load.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(save_load test_save_load.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
......
...@@ -22,4 +22,4 @@ ...@@ -22,4 +22,4 @@
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
numpy==1.21.6 numpy==1.19.5
\ No newline at end of file \ No newline at end of file
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <test.hpp> #include <test.hpp>
TEST_CASE(allocate_dyn) TEST_CASE(allocate_dyn0)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -47,3 +47,21 @@ TEST_CASE(allocate_dyn) ...@@ -47,3 +47,21 @@ TEST_CASE(allocate_dyn)
migraphx::shape sresult{migraphx::shape::float_type, {2, 3, 4, 4}}; migraphx::shape sresult{migraphx::shape::float_type, {2, 3, 4, 4}};
result.visit([&](auto output) { EXPECT(output.get_shape() == sresult); }); result.visit([&](auto output) { EXPECT(output.get_shape() == sresult); });
} }
TEST_CASE(allocate_dyn1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {4}};
migraphx::shape out_shape{migraphx::shape::float_type, {2, 3, 4, 4}};
auto out_dims = mm->add_parameter("out_dims", s);
mm->add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(out_shape)}}),
out_dims);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int64_t> data = {2, 3, 4, 4};
params["out_dims"] = migraphx::argument(s, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) { EXPECT(output.get_shape() == out_shape); });
}
...@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target) ...@@ -41,11 +41,7 @@ TEST_CASE(make_invalid_target)
TEST_CASE(targets) TEST_CASE(targets)
{ {
// GCC doesn't load libmigraphx_ref unless necesssary even though it is linked to the test.
// Force it to load by making ref target
#if defined(__GNUC__) && !defined(__clang__)
auto ref_target = migraphx::make_target("ref"); auto ref_target = migraphx::make_target("ref");
#endif
auto ts = migraphx::get_targets(); auto ts = migraphx::get_targets();
EXPECT(ts.size() >= 1); EXPECT(ts.size() >= 1);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment