Commit 5af9aac0 authored by charlie

Merge branch 'dyn_batch_pass' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_test_runner

parents 7b2516e0 05e81ed3
@@ -345,15 +345,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("is_compiled", &migraphx::program::is_compiled)
         .def(
             "compile",
-            [](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
+            [](migraphx::program& p,
+               const migraphx::target& t,
+               bool offload_copy,
+               bool fast_math,
+               bool exhaustive_tune) {
                 migraphx::compile_options options;
                 options.offload_copy = offload_copy;
                 options.fast_math = fast_math;
+                options.exhaustive_tune = exhaustive_tune;
                 p.compile(t, options);
             },
             py::arg("t"),
             py::arg("offload_copy") = true,
-            py::arg("fast_math") = true)
+            py::arg("fast_math") = true,
+            py::arg("exhaustive_tune") = false)
         .def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
         .def(
             "create_module",
...
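(Aside: a minimal C++ sketch of the option exposed by the binding above, assuming the usual program/target setup; the helper name and header paths below are illustrative and not part of this change.)

#include <migraphx/program.hpp>
#include <migraphx/target.hpp>

// Compile a program with the new exhaustive_tune option enabled.
// `p` and `t` are assumed to be a loaded migraphx::program and a registered target.
void compile_with_tuning(migraphx::program& p, const migraphx::target& t)
{
    migraphx::compile_options options;
    options.offload_copy    = true;  // matches the Python-side default above
    options.fast_math       = true;  // matches the Python-side default above
    options.exhaustive_tune = true;  // new flag; the binding defaults it to false
    p.compile(t, options);
}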
@@ -104,19 +104,17 @@ void replace_allocate::apply(module& m) const
             continue;
         auto s = ins->get_shape();
-        if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
+        if(not main_offload_copy and not(m.use_local_alloc) and model.needs_out_params() and
+           contains(mod_output_names, ins))
         {
             auto out_param = m.add_parameter(mod_output_names[ins], s);
             m.replace_instruction(ins, out_param);
-            continue;
         }
-        m.replace_instruction(
-            ins,
-            m.insert_instruction(ins,
-                                 make_op(model.name(), migraphx::value{{"shape", to_value(s)}})));
+        else
+        {
+            m.replace_instruction(ins,
+                                  make_op(model.name(), migraphx::value{{"shape", to_value(s)}}));
+        }
     }
 }
...
@@ -483,6 +483,17 @@ std::string shape::type_string() const { return name(this->type()); }
 bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
+bool shape::any_of_dynamic() const
+{
+    if(this->dynamic())
+    {
+        return true;
+    }
+    return std::any_of(this->sub_shapes().cbegin(), this->sub_shapes().cend(), [](auto s) {
+        return s.any_of_dynamic();
+    });
+}
 const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
 std::vector<std::size_t> shape::min_lens() const
...
@@ -31,6 +31,7 @@
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/transpose.hpp>
 #include <migraphx/matcher.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/serialize.hpp>
@@ -340,12 +341,18 @@ struct find_inner_broadcast
                        std::back_inserter(inputs),
                        [](auto i) { return i->inputs().front(); });
         if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
-               return i->get_shape() != inputs.front()->get_shape();
+               return i->get_shape() != inputs.front()->get_shape() and
+                      i->get_shape().elements() != 1;
           }))
             return;
-        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
-        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
+        auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
+            return not i->get_shape().scalar();
+        });
+        if(b_it == broadcasts.end())
+            b_it = broadcasts.begin();
+        auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, (*b_it)->get_operator(), op);
     }
 };
@@ -975,7 +982,7 @@ struct find_neg_unit_ops
         auto ins = r.result;
         auto c_in = r.instructions["x"];
-        auto neg = m.add_instruction(make_op("neg"), c_in);
+        auto neg = m.insert_instruction(ins, make_op("neg"), c_in);
         m.replace_instruction(ins, neg);
     }
 };
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool has_one_dyn_dim(std::unordered_map<std::string, shape> param_shapes,
std::string& dyn_param_str,
int& dyn_index,
int& min_dim,
int& max_dim)
{
// true if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension
if(std::none_of(
param_shapes.cbegin(), param_shapes.cend(), [](auto ps) { return ps.second.dynamic(); }))
return false;
int num_dynamic = 0;
std::string out_str;
int tmp_min = -1;
int tmp_max = -1;
int tmp_ind = -1;
for(auto ps : param_shapes)
{
if(ps.second.dynamic())
{
num_dynamic += 1;
if(num_dynamic > 1)
{
return false;
}
int num_nf = 0;
auto dds = ps.second.dyn_dims();
for(int i = 0; i < dds.size(); ++i)
{
const auto& dd = dds.at(i);
if(not dd.is_fixed())
{
num_nf += 1;
tmp_min = dd.min;
tmp_max = dd.max;
tmp_ind = i;
}
}
if(num_nf == 1)
{
out_str = ps.first;
}
else
{
return false;
}
}
}
min_dim = tmp_min;
max_dim = tmp_max;
dyn_index = tmp_ind;
dyn_param_str = out_str;
return true;
}
/**
 * Make all the batch sizes in the range for now; probably won't work for `if` and `loop`
 * instructions, depending on how the submodules for those work.
 * Create additional submodules for optimal values if not already done.
 * Insert a select_module instruction at the top and replace the return, bypassing the other
 * instructions. Unused instructions should be removed by dead_code_elimination.
 */
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
std::string dyn_param_name;
int dyn_index;
int min_dim;
int max_dim;
if(has_one_dyn_dim(param_shapes, dyn_param_name, dyn_index, min_dim, max_dim))
{
const auto& dyn_param = mm->get_parameter(dyn_param_name);
auto dyn_param_shape = mm->get_parameter_shape(dyn_param_name);
std::vector<module_ref> submodules;
// create submodules for each dimension size
for(int dim_size = min_dim; dim_size <= max_dim; ++dim_size)
{
auto submod = mpm.create_module("batch_" + std::to_string(dim_size));
// instruction map for new static submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dyn_index) = dim_size;
auto static_param = submod->add_parameter(
dyn_param_name, migraphx::shape{dyn_param_shape.type(), static_lens});
map_ins[dyn_param] = static_param;
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
submodules.push_back(submod);
}
// redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(),
param_names.cend(),
std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); });
migraphx::shape out_attr = migraphx::shape{mm->get_output_shapes()};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs,
submodules);
mm->replace_return({sm_ins});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
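(Aside: a sketch of how this pass would typically be scheduled, following the comment above that dead_code_elimination removes the bypassed instructions; the driver function is illustrative and assumes a program whose parameter already carries exactly one non-fixed dynamic dimension.)

#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/dead_code_elimination.hpp>

// Split the single dynamic dimension into per-size submodules, then drop the
// instructions that the new select_module return bypasses.
void split_dynamic_batch(migraphx::program& prog)
{
    migraphx::run_passes(prog,
                         {migraphx::split_single_dyn_dim{}, migraphx::dead_code_elimination{}});
}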
-#####################################################################################
+# ####################################################################################
 # The MIT License (MIT)
 #
 # Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
@@ -20,9 +20,9 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
-#####################################################################################
+# ####################################################################################
-list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc)
+list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip)
 find_package(miopen)

 # rocblas
@@ -33,6 +33,8 @@ if(NOT TARGET MIOpen)
     message(SEND_ERROR "Cant find miopen")
 endif()

+set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
+
 include(Embed)
 file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
     ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
@@ -46,6 +48,7 @@ add_library(compile_for_gpu INTERFACE)
 target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
 target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
 check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
 if(HAS_HIP_LAMBDA_HOST_DEVICE)
     message(STATUS "Enable -fhip-lambda-host-device")
     target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
@@ -60,11 +63,13 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
 target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)

 add_library(kernel_file_check EXCLUDE_FROM_ALL)
 foreach(KERNEL_FILE ${KERNEL_FILES})
     get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE)
     file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
     target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
 endforeach()
+target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
 target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
 target_link_libraries(kernel_file_check compile_for_gpu)
@@ -125,6 +130,7 @@ function(register_migraphx_gpu_ops PREFIX)
     register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
 endforeach()
 endfunction()
 register_migraphx_gpu_ops(hip_
     argmax
     argmin
@@ -164,29 +170,8 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
 rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_gpu)
-# look for offload bundler
-get_filename_component(CMAKE_CXX_COMPILER_PATH "${CMAKE_CXX_COMPILER}" PATH)
-if(CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
-    find_program(MIGRAPHX_OFFLOADBUNDLER_BIN clang-offload-bundler
-        HINTS ${CMAKE_CXX_COMPILER_PATH}
-        PATH_SUFFIXES bin
-        PATHS /opt/rocm/llvm
-    )
-else()
-    find_program(MIGRAPHX_EXTRACT_KERNEL extractkernel
-        PATH_SUFFIXES bin
-        HINTS ${CMAKE_CXX_COMPILER_PATH}
-        PATHS
-            /opt/rocm/hip
-            /opt/rocm/hcc
-            /opt/rocm
-    )
-endif()
-message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}")
-message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
 set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
 if(MIGRAPHX_ENABLE_MLIR)
     # Find package rocMLIR
     find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
@@ -195,36 +180,38 @@ if(MIGRAPHX_ENABLE_MLIR)
     target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
 endif()
-set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
 if(MIGRAPHX_USE_HIPRTC)
+    message(STATUS "MIGraphX is using hipRTC")
     target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
 else()
+    message(STATUS "MIGraphX is using HIP Clang")
     # Get flags needed to compile hip
     include(TargetFlags)
     target_flags(HIP_COMPILER_FLAGS hip::device)
     # Remove cuda arch flags
     string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
     string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
     # Skip library paths since hip will incorrectly treat it as a source file
     string(APPEND HIP_COMPILER_FLAGS " ")
     foreach(_unused RANGE 2)
         string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
     endforeach()
     message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
     target_compile_definitions(migraphx_gpu PRIVATE
         "-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
         "-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
-        "-DMIGRAPHX_OFFLOADBUNDLER_BIN=${MIGRAPHX_OFFLOADBUNDLER_BIN}"
-        "-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
-        "-DMIGRAPHX_USE_HIPRTC=0"
     )
     if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
         execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
         string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
         target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
     endif()
 endif()

 # Check miopen find mode api
@@ -233,8 +220,7 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
 check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
 check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
-# TODO: Set default to HAS_FIND_2_API
-set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
+set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
 if(MIGRAPHX_USE_FIND_2_API)
     target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
@@ -250,8 +236,6 @@ else()
     message(STATUS "MIOpen does not have find mode api")
 endif()
-# Workaround broken rocblas headers
-target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
 target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
 target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
@@ -262,4 +246,3 @@ rocm_install_targets(
     INCLUDE
         ${CMAKE_CURRENT_SOURCE_DIR}/include
 )
@@ -29,10 +29,9 @@
 #include <cassert>
 #include <iostream>
-#if MIGRAPHX_USE_HIPRTC
+#ifdef MIGRAPHX_USE_HIPRTC
 #include <hip/hiprtc.h>
 #include <migraphx/manage_ptr.hpp>
-#include <migraphx/env.hpp>
 #else
 #include <migraphx/compile_src.hpp>
 #include <migraphx/process.hpp>
@@ -48,9 +47,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
-#if MIGRAPHX_USE_HIPRTC
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC)
+#ifdef MIGRAPHX_USE_HIPRTC
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);

 std::string hiprtc_error(hiprtcResult err, const std::string& msg)
 {
@@ -144,24 +144,28 @@ struct hiprtc_program
                        std::back_inserter(c_options),
                        [](const std::string& s) { return s.c_str(); });
         auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
-        std::cerr << log() << std::endl;
+        auto prog_log = log();
+        if(not prog_log.empty())
+        {
+            std::cerr << prog_log << std::endl;
+        }
         if(result != HIPRTC_SUCCESS)
             MIGRAPHX_HIPRTC_THROW(result, "Compilation failed.");
     }

-    std::string log()
+    std::string log() const
     {
         std::size_t n = 0;
         MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
-        if(n < 2)
+        if(n == 0)
             return {};
-        std::vector<char> buffer(n);
+        std::string buffer(n, '\0');
         MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
-        assert(buffer.back() == 0);
-        return {buffer.begin(), buffer.end() - 1};
+        assert(buffer.back() != 0);
+        return buffer;
     }

-    std::vector<char> get_code_obj()
+    std::vector<char> get_code_obj() const
     {
         std::size_t n = 0;
         MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
@@ -176,6 +180,17 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
 {
     hiprtc_program prog(srcs);
     auto options = split_string(params, ' ');
+    options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
+    // remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
+    if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
+    {
+        options.push_back("-DMIGRAPHX_HAS_DPP=0");
+        options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
+        options.push_back("-Wno-reserved-identifier");
+        options.push_back("-Wno-gnu-line-marker");
+        options.push_back("-Wno-old-style-cast");
+    }
     if(enabled(MIGRAPHX_GPU_DEBUG{}))
         options.push_back("-DMIGRAPHX_DEBUG");
     if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
@@ -183,7 +198,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
        }))
         options.push_back("-std=c++17");
     options.push_back("-fno-gpu-rdc");
-    options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
+    options.push_back("-O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
     options.push_back("-Wno-cuda-compat");
     options.push_back("--offload-arch=" + arch);
     prog.compile(options);
@@ -192,12 +207,6 @@
 #else // MIGRAPHX_USE_HIPRTC

-bool is_hcc_compiler()
-{
-    static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "hcc");
-    return result;
-}
 bool is_hip_clang_compiler()
 {
     static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
@@ -221,7 +230,7 @@ std::vector<std::vector<char>>
 compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
 {
     assert(not srcs.empty());
-    if(not is_hcc_compiler() and not is_hip_clang_compiler())
+    if(not is_hip_clang_compiler())
         MIGRAPHX_THROW("Unknown hip compiler: " +
                        std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
@@ -231,16 +240,9 @@
     if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
         params += " -g";
     params += " -c";
-    if(is_hcc_compiler())
-    {
-        params += " -amdgpu-target=" + arch;
-    }
-    else if(is_hip_clang_compiler())
-    {
-        params += " --offload-arch=" + arch;
-        params += " --cuda-device-only";
-        params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
-    }
+    params += " --offload-arch=" + arch;
+    params += " --cuda-device-only";
+    params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
     if(enabled(MIGRAPHX_GPU_DEBUG{}))
         params += " -DMIGRAPHX_DEBUG";
@@ -255,24 +257,6 @@
     if(has_compiler_launcher())
         compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
 #endif
-    if(is_hcc_compiler())
-        compiler.process = [&](const fs::path& obj_path) -> fs::path {
-            process{MIGRAPHX_STRINGIZE(MIGRAPHX_EXTRACT_KERNEL) + std::string{" -i "} +
-                    obj_path.string()}
-                .cwd(obj_path.parent_path());
-            for(const auto& entry : fs::directory_iterator{obj_path.parent_path()})
-            {
-                const auto& hsaco_path = entry.path();
-                if(not fs::is_regular_file(hsaco_path))
-                    continue;
-                if(hsaco_path.extension() != ".hsaco")
-                    continue;
-                return hsaco_path;
-            }
-            MIGRAPHX_THROW("Missing hsaco");
-        };
     if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
     {
         for(const auto& src : srcs)
@@ -292,6 +276,8 @@
     return {compiler.compile(srcs)};
 }
+#endif // MIGRAPHX_USE_HIPRTC
+
 std::string enum_params(std::size_t count, std::string param)
 {
     std::vector<std::string> items(count);
@@ -299,8 +285,6 @@ std::string enum_params(std::size_t count, std::string param)
     return join_strings(items, ",");
 }
-#endif // MIGRAPHX_USE_HIPRTC
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -29,7 +29,6 @@
 #include <migraphx/context.hpp>
 #include <migraphx_kernels.hpp>
 #include <migraphx/stringutils.hpp>
-#include <hip/hip_runtime_api.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -80,6 +79,7 @@ std::string generate_args_hpp(const std::vector<shape>& inputs)
 #include <migraphx/kernels/args.hpp>
 #include <migraphx/kernels/tensor_view.hpp>
+#include <migraphx/kernels/types.hpp>

 namespace migraphx {
...
@@ -112,14 +112,8 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
 #ifdef MIGRAPHX_USE_CLANG_TIDY
 #define MIGRAPHX_DEVICE_SHARED
 #else
-// Workaround hcc's broken tile_static macro
-#ifdef tile_static
-#undef tile_static
-#define MIGRAPHX_DEVICE_SHARED __attribute__((tile_static))
-#else
 #define MIGRAPHX_DEVICE_SHARED __shared__
 #endif
-#endif

 } // namespace device
 } // namespace gpu
...
@@ -36,6 +36,7 @@ namespace gpu {
 namespace device {

 #ifdef MIGRAPHX_NO_DPP
+
 template <index_int N,
           class Op,
           class T,
@@ -62,6 +63,7 @@ __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
     }
     return buffer[0];
 }
+
 #else

 constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
@@ -96,11 +98,7 @@ __device__ T dpp_mov(T& x)
     input.data = x;
     for(index_int i = 0; i < n; i++)
     {
-#if defined(__HCC__)
-        output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
-#else
         output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
-#endif
     }
     return output.data;
 }
@@ -310,4 +308,4 @@ void reduce(hipStream_t stream,
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
-#endif
+#endif // MIGRAPHX_NO_DPP
@@ -553,11 +553,13 @@ struct find_gemm_pointwise
 {
     auto matcher() const
     {
-        return precompile_name("pointwise")(
-            match::nargs(3),
-            match::either_arg(0, 1)(
-                match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
-                match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
+        auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
+        auto binary_op = match::all_of(
+            match::nargs(3),
+            match::either_arg(0, 1)(
+                match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op));
+        auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
+        return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
     }

     // TODO: Move to matcher.hpp
@@ -589,15 +591,29 @@
         return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
     }

+    static auto match_mul(const std::string& input)
+    {
+        auto mul = match_mul_const(match_param(input), "alpha");
+        return match::name("@return")(match::args(mul));
+    }
+
     static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }

     template <class Gemm>
     static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
     {
         auto names = pm->get_parameter_names();
-        if(names.size() != 2)
-            return false;
         std::sort(names.begin(), names.end());
+        if(names.size() == 1)
+        {
+            auto mr = match::match_instruction(*pm, std::prev(pm->end()), match_mul(names[input]));
+            if(mr.result == pm->end())
+                return false;
+            gemm.alpha *= get_float(mr.instructions["alpha"]);
+            return true;
+        }
+        else if(names.size() == 2)
+        {
             unsigned output = input == 0 ? 1 : 0;
             auto mr = match::match_instruction(
                 *pm, std::prev(pm->end()), match_add(names[input], names[output]));
@@ -614,24 +630,35 @@
             }
             return true;
         }
+        else
+        {
+            return false;
+        }
+    }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto gemm_ins = r.instructions["gemm"];
-        auto c_ins = r.instructions["c"];
         auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
         // Already fused gemm
         if(not float_equal(gemm.beta, 0))
             return;
+        if(ins->inputs().size() == 3)
             gemm.beta = 1;
         if(not update_gemm(
                gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
             return;
+        auto inputs = gemm_ins->inputs();
+        inputs.pop_back();
+        if(ins->inputs().size() == 3)
+        {
+            auto c_ins = r.instructions["c"];
             // const-fold input if not standard shape since rocblas can't handle it
             if(not c_ins->get_shape().standard())
             {
@@ -639,11 +666,9 @@
                 auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
                 c_ins = m.add_literal(l.get_shape(), l.data());
             }
-        auto inputs = gemm_ins->inputs();
-        inputs.pop_back();
             inputs.push_back(c_ins);
+        }
         inputs.push_back(ins->inputs().back());
         m.replace_instruction(ins, gemm, inputs);
...
@@ -30,6 +30,7 @@
 #include <migraphx/gpu/hip.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/gpu/device_name.hpp>
 #include <unordered_map>
 #include <memory>
@@ -215,6 +216,10 @@ struct context
         return *current_device;
     }

+    bool get_exhaustive_tune_flag() const { return exhaustive_tune; }
+
+    void set_exhaustive_tune_flag(bool t) { exhaustive_tune = t; }
+
     hip_device::stream& get_stream() { return get_current_device().get_stream(); }
     hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); }
@@ -273,7 +278,8 @@
         auto v_streams = v.at("streams");
         std::size_t n_streams = v_streams.without_key().to<std::size_t>();
-        this->current_device = std::make_shared<hip_device>(0, n_streams);
+        auto device = get_device_id();
+        this->current_device = std::make_shared<hip_device>(device, n_streams);
     }

     void wait_for(any_ptr queue)
@@ -336,6 +342,7 @@
     // TODO: Make this a vector to support multiple devices
     std::shared_ptr<hip_device> current_device;
     std::vector<shared<hip_event_ptr>> events;
+    bool exhaustive_tune = false;
     bool measure_perf = false;
     // for event perf timing
     shared<hip_event_ptr> start_event = nullptr;
...
@@ -175,7 +175,8 @@ struct miopen_convolution
         auto* miopen_stream_handle = ctx.get_stream().get_miopen();
-        solution_ptr = find_solution(miopen_stream_handle, conv_problem.get());
+        solution_ptr = find_solution(
+            miopen_stream_handle, conv_problem.get(), ctx.get_exhaustive_tune_flag());
         auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
         if(status != miopenStatusSuccess)
             MIGRAPHX_THROW("MIOpen" + op.name() + " : failed to get solution's workspace size");
@@ -233,7 +234,7 @@
                                 &perf,
                                 workspace.implicit(),
                                 workspace_size,
-                                false);
+                                ctx.get_exhaustive_tune_flag());
         if(status != miopenStatusSuccess)
             MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
         algo = perf.fwd_algo;
...
@@ -75,11 +75,18 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr
 using miopen_problem = MIGRAPHX_MANAGE_PTR(miopenProblem_t, miopenDestroyProblem);
 using miopen_solution = MIGRAPHX_MANAGE_PTR(miopenSolution_t, miopenDestroySolution);

-inline miopen_solution find_solution(miopenHandle_t handle, miopenProblem_t problem)
+inline miopen_solution
+find_solution(miopenHandle_t handle, miopenProblem_t problem, bool tune = false)
 {
     miopenSolution_t solution;
     size_t found = 0;
-    auto status = miopenFindSolutions(handle, problem, nullptr, &solution, &found, 1);
+    miopen_find_options fo = nullptr;
+    if(tune)
+    {
+        fo = make_obj<miopen_find_options>(&miopenCreateFindOptions);
+        miopenSetFindOptionTuning(fo.get(), 1);
+    }
+    auto status = miopenFindSolutions(handle, problem, fo.get(), &solution, &found, 1);
     auto result = miopen_solution{solution};
     if(status != miopenStatusSuccess or found == 0)
         MIGRAPHX_THROW("MIOpen miopenFindSolutions failed");
...
@@ -30,14 +30,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct module;
+struct module_pass_manager;

 namespace gpu {

 struct prefuse_ops
 {
     std::string name() const { return "gpu::prefuse_ops"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& mpm) const;
 };

 } // namespace gpu
...
@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler>
         options.virtual_inputs = reduce_dims(inputs);
         auto faxis = find_fast_axis({options.virtual_inputs.front()});
         vectorize vec{};
-        // Vectorize if the axis is a reduction axis
-        if(options.virtual_inputs.back().lens()[faxis] == 1)
-        {
-            vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
-        }
-        auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
         auto nelements = options.virtual_inputs.back().elements();
         auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
         if(algo == "block")
         {
+            // Vectorize if the axis is a reduction axis
+            if(options.virtual_inputs.back().lens()[faxis] == 1)
+                vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
+            auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
             auto block_size = compute_block_size(relements, 256);
+            if(relements > block_size * 256)
+                algo = "block_large";
             options.set_launch_params(
                 v, compute_global_for(ctx, nelements * block_size, 256), block_size);
         }
@@ -157,15 +157,24 @@
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
         value v = value::object{};
-        auto reduce_elements = get_reduce_elements(ins->inputs());
         if(op.name() == "reduce_sum")
         {
             v["reduction"] = "op::sum{}";
         }
         else if(op.name() == "reduce_mean")
         {
+            auto reduce_elements = get_reduce_elements(ins->inputs());
+            auto reduce_type = ins->inputs().front()->get_shape().type();
             v["reduction"] = "op::sum{}";
-            v["write"] = "op::mean{" + std::to_string(reduce_elements) + "}";
+            std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
+            // Use float accumulator when reduction size is too large for half
+            if(reduce_type == shape::half_type and reduce_elements > 16384)
+                v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
+            else if(contains({shape::float_type, shape::half_type, shape::double_type},
+                             reduce_type))
+                v["read"] = mean;
+            else
+                v["write"] = mean;
         }
         else if(op.name() == "reduce_max")
         {
...
@@ -105,7 +105,7 @@ constexpr auto array_for_each(T& x, Ts&... xs)
     }
     else
     {
-        using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
+        using vec_type = remove_reference_t<decltype(array2vec(x))>;
         f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
     }
 }
...
@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
 #define MIGRAPHX_WARN(...)
 #endif

+#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
+    static_assert(__VA_ARGS__);         \
+    if constexpr(__VA_ARGS__)
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
@@ -72,7 +72,7 @@ __device__ T dpp_mov(T& x)
     }
     return output.data;
 }
-#endif
+#endif // MIGRAPHX_HAS_DPP

 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
@@ -187,6 +187,14 @@ constexpr auto fold(F f)
     return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
 }

+template <class... Fs>
+constexpr auto compose(Fs... fs)
+{
+    return fold([](auto f, auto g) {
+        return [=](auto&&... xs) { return f(g(static_cast<decltype(xs)>(xs)...)); };
+    })(fs...);
+}
+
 template <class... Ts>
 constexpr auto pack(Ts... xs)
 {
...
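(Aside: a standalone sketch of what the new compose helper does; the two-argument version below mirrors the fold-based helper added above, and the lambdas are purely illustrative.)

#include <cassert>

// compose(f, g)(x) evaluates f(g(x)): the rightmost callable runs first. The reduce
// compiler above uses the same idea to chain op::convert_to<float>{} into op::mean<N>{}
// for large half-precision reductions.
template <class F, class G>
constexpr auto compose(F f, G g)
{
    return [=](auto&&... xs) { return f(g(static_cast<decltype(xs)>(xs)...)); };
}

int main()
{
    constexpr auto add_one = [](auto x) { return x + 1; };
    constexpr auto square  = [](auto x) { return x * x; };
    assert(compose(add_one, square)(3) == 10); // add_one(square(3)) == 10
    return 0;
}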