Commit 30c8ff61 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

integrating auto_cont changes

parents 4ff8a292 7aee6388
...@@ -23,14 +23,24 @@ ...@@ -23,14 +23,24 @@
##################################################################################### #####################################################################################
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
add_library(migraphx_py py_loader.cpp)
target_include_directories(migraphx_py PRIVATE include)
target_link_libraries(migraphx_py PUBLIC migraphx)
rocm_install_targets(TARGETS migraphx_py INCLUDE include)
if(MIGRAPHX_ENABLE_PYTHON) if(MIGRAPHX_ENABLE_PYTHON)
include(PythonModules) include(PythonModules)
add_custom_target(migraphx_py)
foreach(PYTHON_VERSION ${PYTHON_VERSIONS}) foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
py_add_module(migraphx_py_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx) py_add_module(migraphx_pybind_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) target_link_libraries(migraphx_pybind_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
rocm_install_targets(TARGETS migraphx_pybind_${PYTHON_VERSION})
add_dependencies(migraphx_py migraphx_pybind_${PYTHON_VERSION})
add_library(migraphx_py_${PYTHON_VERSION} py.cpp)
target_include_directories(migraphx_py_${PYTHON_VERSION} PRIVATE include)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PUBLIC migraphx)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE pybind11::pybind11 python${PYTHON_VERSION}::runtime)
rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION}) rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION})
add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION}) add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION})
endforeach() endforeach()
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Build a program from the Python script stored at `filename`.
/// @param filename path of the Python source file to run.
/// @return the program produced by running the script.
/// NOTE(review): failures are presumably reported via MIGRAPHX_THROW — confirm against the
/// implementation.
program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/file_buffer.hpp>
#include <pybind11/embed.h>
namespace py = pybind11;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Forward declaration of the loader entry point so that python_path() below can take its
// address before the definition appears.
// The function returns the C++ type `program` while having C linkage, so clang's
// -Wreturn-type-c-linkage diagnostic is silenced around the declaration.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
// extern "C" is used to disable name mangling, but the function will still be called from C++
extern "C" program migraphx_load_py(const std::string& filename);
#ifdef __clang__
#pragma clang diagnostic pop
#endif
// Returns the parent directory of the path that dynamic_loader reports for
// migraphx_load_py (presumably the shared library that defines it — confirm against
// dynamic_loader::path). The string is computed on first call and reused afterwards.
const std::string& python_path()
{
    static const std::string library_dir =
        dynamic_loader::path(&migraphx_load_py).parent_path().string();
    return library_dir;
}
// Executes the Python script stored in `file` inside the `__main__` namespace of the
// already-running embedded interpreter and returns that namespace as a dict.
//
// Before the script runs:
//   * python_path() is placed at the front of `sys.path`, and
//   * `sys` and `migraphx` are imported into the namespace, so the script can use both
//     without importing them itself.
//
// The directory is inserted through the Python API instead of being pasted into generated
// Python source: a path containing a quote or a backslash would otherwise produce broken
// (or unintended) Python code. Running the user's script as its own exec call also keeps
// the line numbers Python reports in sync with the file on disk.
static py::dict run_file(const std::string& file)
{
    py::object scope = py::module_::import("__main__").attr("__dict__");

    // Equivalent to `sys.path.insert(0, <python_path>)`, with no string escaping involved.
    py::module_::import("sys").attr("path").attr("insert")(0, python_path());

    // Names the script is allowed to rely on.
    py::exec("import sys\nimport migraphx\n", scope);

    // The user's script, unmodified.
    py::exec(read_string(file), scope);
    return scope.cast<py::dict>();
}
// Entry point resolved by name ("migraphx_load_py") from the loader library.
// Starts an embedded Python interpreter for the duration of the call, runs the script in
// `filename`, and returns the first variable of the resulting namespace that holds a
// migraphx::program. Throws when the script leaves no such variable behind.
extern "C" program migraphx_load_py(const std::string& filename)
{
    // The interpreter lives exactly as long as this call.
    py::scoped_interpreter interpreter{};
    py::dict globals = run_file(filename);
    for(const auto& entry : globals)
    {
        if(not py::isinstance<migraphx::program>(entry.second))
            continue;
        return entry.second.cast<migraphx::program>();
    }
    MIGRAPHX_THROW("No program variable found");
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/py.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static std::vector<fs::path> find_available_python_versions()
{
std::vector<fs::path> result;
auto path = dynamic_loader::path(&load_py).parent_path();
for(const auto& entry : fs::directory_iterator{path})
{
auto p = entry.path();
if(not fs::is_regular_file(p))
continue;
if(not contains(p.stem().string(), "migraphx_py_"))
continue;
result.push_back(p);
}
std::sort(result.begin(), result.end(), std::greater<>{});
return result;
}
// Walks the candidate loader libraries in the order find_available_python_versions()
// returns them and hands back the first one that can be loaded.
// Throws when none of the candidates loads.
static dynamic_loader load_py_lib()
{
    for(const auto& candidate : find_available_python_versions())
    {
        auto loaded = dynamic_loader::try_load(candidate);
        if(not loaded.has_value())
            continue;
        return *loaded;
    }
    MIGRAPHX_THROW("Cant find a viable version of python");
}
// Gives access to the Python loader library. The library is located and loaded on the
// first call only; every call returns a copy of that same loader object.
static dynamic_loader py_lib()
{
    static dynamic_loader loaded_library = load_py_lib();
    return loaded_library;
}
// Public entry point: forwards `filename` to the "migraphx_load_py" function exported by
// the dynamically loaded Python library. The function is looked up once, on first use.
program load_py(const std::string& filename)
{
    using load_signature = program(const std::string&);
    static auto entry_point = py_lib().get_function<load_signature>("migraphx_load_py");
    return entry_point(filename);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -48,7 +48,7 @@ include(Embed) ...@@ -48,7 +48,7 @@ include(Embed)
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES}) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(migraphx_device ${DEVICE_GPU_SRCS})
...@@ -89,7 +89,7 @@ rocm_clang_tidy_check(kernel_file_check) ...@@ -89,7 +89,7 @@ rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp) file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(NOT WIN32) if(WIN32)
# TODO: re-enable when CK is ported to Windows # TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp) list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
endif() endif()
......
...@@ -167,7 +167,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -167,7 +167,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
[](auto&& p) { [](auto&& p) {
auto&& name = p.first; auto&& name = p.first;
auto&& c = p.second; auto&& c = p.second;
auto path = fs::path{"migraphx"} / "kernels" / name; auto path = name;
return src_file{path, c}; return src_file{path, c};
}); });
srcs.push_back(src_file{fs::path{"main.cpp"}, srcs.push_back(src_file{fs::path{"main.cpp"},
......
...@@ -140,8 +140,11 @@ void gemm_impl(context& ctx, ...@@ -140,8 +140,11 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
rocblas_gemm_flags flag = rocblas_gemm_flags flag = rocblas_gemm_flags_none;
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none; #if ROCBLAS_VERSION_MAJOR < 3
if(int8_x4_format)
flag = rocblas_gemm_flags_pack_int8x4;
#endif
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/gpu/tuning_config.hpp>
#include <functional> #include <functional>
namespace migraphx { namespace migraphx {
...@@ -68,12 +69,6 @@ struct compiler_replace ...@@ -68,12 +69,6 @@ struct compiler_replace
} }
}; };
struct tuning_config
{
value problem;
std::vector<value> solutions;
};
using compiler_compile = using compiler_compile =
std::function<compiler_replace(context&, instruction_ref, operation, const value&)>; std::function<compiler_replace(context&, instruction_ref, operation, const value&)>;
using compiler_compile_op = using compiler_compile_op =
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/gpu/config.hpp> #include <migraphx/gpu/config.hpp>
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/tuning_config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -36,16 +37,19 @@ struct module; ...@@ -36,16 +37,19 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m); MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m);
MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& ctx, MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& ctx,
module m, module m,
const std::vector<instruction_ref>& inputs); const std::vector<instruction_ref>& inputs,
const value& solution);
MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
instruction_ref ins, instruction_ref ins,
code_object_op co, code_object_op co,
const std::vector<instruction_ref>& inputs); const std::vector<instruction_ref>& inputs);
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(module m,
const std::vector<shape>& inputs);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
#define MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Describes a tuning task: one problem plus the candidate solutions to evaluate for it.
// Member order is part of the interface for aggregate initialization.
struct tuning_config
{
// Identifies the problem being tuned. NOTE(review): presumably used as a lookup key for
// tuning results — confirm with the consumers of this struct.
value problem;
// Candidate solutions for `problem`; each entry is an opaque value. NOTE(review):
// presumably each one is later handed to a compiler as its `solution` argument — confirm.
std::vector<value> solutions;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_TUNING_CONFIG_HPP
...@@ -36,11 +36,12 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -36,11 +36,12 @@ struct mlir_compiler : compiler<mlir_compiler>
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; } operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const compiler_replace
compile(context& ctx, instruction_ref ins, const operation&, const value& solution) const
{ {
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
assert(smod->get_parameter_names().size() == ins->inputs().size() - 1); assert(smod->get_parameter_names().size() == ins->inputs().size() - 1);
return insert(compile_mlir(ctx, *smod, ins->inputs())); return insert(compile_mlir(ctx, *smod, ins->inputs(), solution));
} }
compiler_replace insert(code_object_op co) const compiler_replace insert(code_object_op co) const
...@@ -50,6 +51,16 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -50,6 +51,16 @@ struct mlir_compiler : compiler<mlir_compiler>
m.replace_instruction(ins, mlir); m.replace_instruction(ins, mlir);
}}; }};
} }
optional<tuning_config>
get_tuning_config(context&, instruction_ref ins, const operation&, bool exhaustive) const
{
if(not exhaustive)
return nullopt;
auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(*smod, shapes);
}
}; };
} // namespace gpu } // namespace gpu
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/perfdb.hpp> #include <migraphx/gpu/perfdb.hpp>
#include <migraphx/gpu/tuning_config.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <deque> #include <deque>
...@@ -134,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD ...@@ -134,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockD
using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy); using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable, using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable,
mlirRockTuningTableDestroy); mlirRockTuningTableDestroy);
using mlir_tuning_space = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningSpace,
mlirRockTuningSpaceDestroy);
using mlir_tuning_param = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningParam,
mlirRockTuningParamDestroy);
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; } std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
...@@ -616,18 +621,30 @@ struct mlir_program ...@@ -616,18 +621,30 @@ struct mlir_program
} }
} }
code_object_op compile() MIGRAPHX_TIDY_CONST void run_high_level_pipeline() MIGRAPHX_TIDY_CONST
{ {
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm_front.get()); mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get())); mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()));
}
// 2nd pipeline to call void run_backend_pipeline() MIGRAPHX_TIDY_CONST
get_module_tuned(); {
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str()); mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get())); mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get()));
}
code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST
{
// 1st pipeline to call
run_high_level_pipeline();
if(solution.is_null())
get_module_tuned();
else
set_tuning(solution);
// 2nd pipeline to call
run_backend_pipeline();
code_object_op op{}; code_object_op op{};
op.symbol_name = sym_name; op.symbol_name = sym_name;
...@@ -658,6 +675,33 @@ struct mlir_program ...@@ -658,6 +675,33 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program"); MIGRAPHX_THROW("Failed to compile mlir program");
} }
void set_tuning(const value& v)
{
auto str = v.to<std::string>();
// We need to make a copy of the buffer since mlirRockTuningSetFromStr may modify the string
std::vector<char> buffer(str.begin(), str.end());
buffer.push_back(0);
if(not mlirRockTuningSetFromStr(mmodule.get(), buffer.data()))
MIGRAPHX_THROW("Failed setting tuning key: " + str);
}
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
{
tuning_config tc;
run_high_level_pipeline();
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get())};
for(auto i : range(mlirRockTuningGetNumParamsFull(params.get())))
{
mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get()))
MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i));
tc.solutions.push_back(std::string{mlirRockTuningGetParamStr(param.get())});
}
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
tc.problem = std::string{mlirRockTuningGetKey(tuning_table.get(), mmodule.get())};
return tc;
}
std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); } std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }
// This function appends to tuning cfg file that could be // This function appends to tuning cfg file that could be
...@@ -749,14 +793,14 @@ std::string dump_mlir(const module& m) ...@@ -749,14 +793,14 @@ std::string dump_mlir(const module& m)
return mlir_print(&mlirOperationPrint, mod_op); return mlir_print(&mlirOperationPrint, mod_op);
} }
void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs) void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
{ {
auto names = m.get_parameter_names(); auto names = m.get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
for(auto i : range(names.size())) for(auto i : range(names.size()))
{ {
const auto& name = names[i]; const auto& name = names[i];
const auto& input = inputs[i]->get_shape(); const auto& input = inputs[i];
auto param = m.get_parameter(name); auto param = m.get_parameter(name);
if(input.standard()) if(input.standard())
continue; continue;
...@@ -794,9 +838,12 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs) ...@@ -794,9 +838,12 @@ void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
} }
} }
code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs) code_object_op compile_mlir(const context&,
module m,
const std::vector<instruction_ref>& inputs,
const value& solution)
{ {
adjust_param_shapes(m, inputs); adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace) if(trace)
...@@ -808,7 +855,8 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct ...@@ -808,7 +855,8 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
auto co = mp.compile(); auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front(); co.output = m.get_output_shapes().front();
return co; return co;
} }
...@@ -829,6 +877,16 @@ instruction_ref insert_mlir(module& m, ...@@ -829,6 +877,16 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
tuning_config get_tuning_config_mlir(module m, const std::vector<shape>& inputs)
{
adjust_param_shapes(m, inputs);
mlir_program mp;
mp.find_target();
mp.parse(m);
return mp.get_tuning_config();
}
#else #else
std::string dump_mlir(const module&) { return {}; } std::string dump_mlir(const module&) { return {}; }
...@@ -840,11 +898,11 @@ void use(T&) ...@@ -840,11 +898,11 @@ void use(T&)
// Disabling clang-tidy warning on non-real usage. // Disabling clang-tidy warning on non-real usage.
// NOLINTBEGIN(performance-unnecessary-value-param) // NOLINTBEGIN(performance-unnecessary-value-param)
code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&) code_object_op
compile_mlir(const context&, module, const std::vector<instruction_ref>&, const value&)
{ {
return {}; return {};
} }
// NOLINTEND(performance-unnecessary-value-param)
instruction_ref instruction_ref
// cppcheck-suppress funcArgNamesDifferent // cppcheck-suppress funcArgNamesDifferent
...@@ -854,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins ...@@ -854,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return m.end(); return m.end();
} }
tuning_config get_tuning_config_mlir(module, const std::vector<shape>&) { return {}; }
// NOLINTEND(performance-unnecessary-value-param)
#endif #endif
} // namespace gpu } // namespace gpu
......
...@@ -76,7 +76,7 @@ namespace gpu { ...@@ -76,7 +76,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
#ifdef _WIN32 #ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif #endif
...@@ -134,23 +134,22 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -134,23 +134,22 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}), enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{},
dead_code_elimination{},
optimize_module{}, optimize_module{},
fuse_pointwise{}, fuse_pointwise{},
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{}, dead_code_elimination{},
#ifdef _WIN32 #ifndef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif #endif
dead_code_elimination{}, dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}), enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
// enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}), enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
// dead_code_elimination{}, // dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}}, eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -36,7 +36,7 @@ endfunction() ...@@ -36,7 +36,7 @@ endfunction()
function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR) function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME}) set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
target_link_libraries(${NAME} migraphx_c migraphx) target_link_libraries(${NAME} migraphx_c)
target_include_directories(${NAME} PUBLIC ../include) target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR}) add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME}) add_dependencies(tests ${NAME})
......
...@@ -99,7 +99,7 @@ TEST_CASE(run_sigmoid_custom_op) ...@@ -99,7 +99,7 @@ TEST_CASE(run_sigmoid_custom_op)
EXPECT(bool{result == migraphx::argument(s, expected_result.data())}); EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
} }
extern "C" void migraphx_test_private_disable_exception_catch(bool b); extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool);
TEST_CASE(run_sigmoid_with_incorrect_shape) TEST_CASE(run_sigmoid_with_incorrect_shape)
{ {
......
...@@ -148,11 +148,13 @@ TEST_CASE(two_transpose_gather) ...@@ -148,11 +148,13 @@ TEST_CASE(two_transpose_gather)
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data); migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td); auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td);
auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd); auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
auto bd = auto csd = m2.add_instruction(migraphx::make_op("contiguous"), sd);
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd); auto bd = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), csd);
auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd); auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd);
auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind); auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
m2.add_return({r}); auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
...@@ -174,8 +176,11 @@ TEST_CASE(standard_reshape) ...@@ -174,8 +176,11 @@ TEST_CASE(standard_reshape)
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data); auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add); auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca); // extra contiguous coming from reshape logic which has "requires_std_shape" attribute
m2.add_return({r}); auto cb = m2.add_instruction(migraphx::make_op("contiguous"), ca);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), cb);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
......
...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir) ...@@ -84,7 +84,7 @@ migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front())); inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::context ctx; migraphx::gpu::context ctx;
migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs), inputs); migraphx::gpu::insert_mlir(*mm, mm->end(), compile_mlir(ctx, mmlir, inputs, {}), inputs);
return p; return p;
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/replace_allocate.hpp> #include <migraphx/replace_allocate.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -49,6 +50,7 @@ void run_passes(migraphx::module& m, migraphx::gpu::context& ctx) ...@@ -49,6 +50,7 @@ void run_passes(migraphx::module& m, migraphx::gpu::context& ctx)
migraphx::run_passes(m, migraphx::run_passes(m,
{migraphx::auto_contiguous{}, {migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false}, migraphx::gpu::lowering{&ctx, false},
migraphx::eliminate_contiguous{"gpu::contiguous"},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
migraphx::replace_allocate{migraphx::gpu::gpu_allocation_model{}}, migraphx::replace_allocate{migraphx::gpu::gpu_allocation_model{}},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
...@@ -104,13 +106,9 @@ TEST_CASE(quant_dot) ...@@ -104,13 +106,9 @@ TEST_CASE(quant_dot)
auto beta_broadcast = m.add_instruction( auto beta_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta); migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta);
auto beta_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto beta_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction( auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_contiguous, mul_alloc); auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc);
auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output); auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
m.add_return({gemm_add}); m.add_return({gemm_add});
...@@ -158,53 +156,43 @@ TEST_CASE(quant_dot_trans) ...@@ -158,53 +156,43 @@ TEST_CASE(quant_dot_trans)
auto tl1 = auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}}; migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
auto alloca = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
auto tl2 = auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
auto allocb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb);
auto alpha_broadcast = m.add_instruction( auto alpha_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha); migraphx::make_op("multibroadcast", {{"out_lens", tl1->get_shape().lens()}}), alpha);
auto alpha_alloc = m.add_instruction(migraphx::make_op( auto alpha_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", "hip::allocate",
{{"shape", {{"shape",
migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, {3, 2, 5, 8}))}})); migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, {3, 2, 5, 8}))}}));
auto alpha_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
// alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert // alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
// back result to int8 // back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = auto tl1_convert =
m.add_instruction(make_precompile_op(migraphx::make_op( m.add_instruction(make_precompile_op(migraphx::make_op(
"convert", {{"target_type", alpha->get_shape().type()}})), "convert", {{"target_type", alpha->get_shape().type()}})),
conta, tl1,
tl1_convert_alloc); alpha_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op( auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(alpha_alloc->get_shape())}}));
auto tl1_alpha_int32 = auto tl1_alpha_int32 =
m.add_instruction(make_precompile_op("mul"), alpha_contiguous, tl1_convert, mul_alloc); m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc);
// convert mul_res to int8 // convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( auto tl1_alpha_int8_alloc = m.add_instruction(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto tl1_alpha_int8 = auto tl1_alpha_int8 =
m.add_instruction(make_precompile_op(migraphx::make_op( m.add_instruction(make_precompile_op(migraphx::make_op(
"convert", {{"target_type", conta->get_shape().type()}})), "convert", {{"target_type", tl1->get_shape().type()}})),
tl1_alpha_int32, tl1_alpha_int32,
tl1_alpha_int8_alloc); tl1_alpha_int8_alloc);
auto packb = contb; auto packb = tl2;
if(int8_x4) if(int8_x4)
{ {
auto allocpb = m.add_instruction( auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb); packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), tl2, allocpb);
} }
auto gemm = m.add_instruction( auto gemm = m.add_instruction(
...@@ -301,13 +289,9 @@ TEST_CASE(quant_dot_pad) ...@@ -301,13 +289,9 @@ TEST_CASE(quant_dot_pad)
auto beta_broadcast = auto beta_broadcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta);
auto beta_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto beta_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction( auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_contiguous, mul_alloc); auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc);
auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output); auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
m.add_return({gemm_add}); m.add_return({gemm_add});
return m; return m;
...@@ -357,15 +341,10 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -357,15 +341,10 @@ TEST_CASE(quant_dot_trans_pad)
auto tl1 = auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}}; migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta);
auto tl2 = auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
migraphx::instruction_ref ptb{}; migraphx::instruction_ref ptb{};
if(int8_x4) if(int8_x4)
...@@ -373,42 +352,37 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -373,42 +352,37 @@ TEST_CASE(quant_dot_trans_pad)
ptb = m.add_instruction( ptb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
} }
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb); auto pb = tl2;
auto pb = contb;
if(int8_x4) if(int8_x4)
{ {
pb = m.add_instruction( pb = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}), migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}),
contb, tl2,
ptb); ptb);
} }
auto alpha_broadcast = m.add_instruction( auto alpha_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha); migraphx::make_op("multibroadcast", {{"out_lens", tl1->get_shape().lens()}}), alpha);
auto alpha_alloc = m.add_instruction( auto alpha_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", migraphx::make_op("hip::allocate",
{{"shape", {{"shape",
migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, migraphx::to_value(migraphx::shape(migraphx::shape::int32_type,
conta->get_shape().lens()))}})); tl1->get_shape().lens()))}}));
auto alpha_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
// alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert // alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
// back result to int8 // back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = auto tl1_convert =
m.add_instruction(make_precompile_op(migraphx::make_op( m.add_instruction(make_precompile_op(migraphx::make_op(
"convert", {{"target_type", alpha->get_shape().type()}})), "convert", {{"target_type", alpha->get_shape().type()}})),
conta, tl1,
tl1_convert_alloc); alpha_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op( auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}})); "hip::allocate", {{"shape", migraphx::to_value(alpha_alloc->get_shape())}}));
auto tl1_alpha_int32 = auto tl1_alpha_int32 =
m.add_instruction(make_precompile_op("mul"), alpha_contiguous, tl1_convert, mul_alloc); m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc);
// convert mul_res to int8 // convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( auto tl1_alpha_int8_alloc = m.add_instruction(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
migraphx::instruction_ref pta{}; migraphx::instruction_ref pta{};
if(int8_x4) if(int8_x4)
...@@ -417,9 +391,8 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -417,9 +391,8 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
} }
auto tl1_alpha_int8 = auto tl1_alpha_int8 = m.add_instruction(
m.add_instruction(make_precompile_op(migraphx::make_op( make_precompile_op(migraphx::make_op("convert", {{"target_type", ts1.type()}})),
"convert", {{"target_type", conta->get_shape().type()}})),
tl1_alpha_int32, tl1_alpha_int32,
tl1_alpha_int8_alloc); tl1_alpha_int8_alloc);
......
...@@ -384,7 +384,7 @@ bool throws(F f, const std::string& msg = "") ...@@ -384,7 +384,7 @@ bool throws(F f, const std::string& msg = "")
} }
template <class T, class U> template <class T, class U>
auto near(T px, U py, double ptol = 1e-6f) auto within_abs(T px, U py, double ptol = 1e-6f)
{ {
return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })( return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })(
px, py, ptol); px, py, ptol);
......
...@@ -82,9 +82,9 @@ TEST_CASE(generate_module) ...@@ -82,9 +82,9 @@ TEST_CASE(generate_module)
auto f = compile_module<float(float, float)>(m); auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(2, 2), 2)); EXPECT(test::within_abs(f(2, 2), 2));
EXPECT(test::near(f(10, 6), 4)); EXPECT(test::within_abs(f(10, 6), 4));
EXPECT(test::near(f(1, 2), std::sqrt(3))); EXPECT(test::within_abs(f(1, 2), std::sqrt(3)));
} }
TEST_CASE(generate_module_with_literals) TEST_CASE(generate_module_with_literals)
...@@ -99,9 +99,9 @@ TEST_CASE(generate_module_with_literals) ...@@ -99,9 +99,9 @@ TEST_CASE(generate_module_with_literals)
auto f = compile_module<float(float, float)>(m); auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(1, 2), 2)); EXPECT(test::within_abs(f(1, 2), 2));
EXPECT(test::near(f(9, 6), 4)); EXPECT(test::within_abs(f(9, 6), 4));
EXPECT(test::near(f(0, 2), std::sqrt(3))); EXPECT(test::within_abs(f(0, 2), std::sqrt(3)));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment