Commit 673ca71c authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into conv-add

parents 4bfe1662 91cc7242
......@@ -1092,11 +1092,23 @@ struct find_split_reshape
return;
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
if(i->outputs().size() == 1)
{
auto cont = i->outputs().front();
return cont->outputs().size() != 1;
}
return false;
}))
{
return;
}
std::vector<instruction_ref> vec_rsp(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
assert(i->outputs().size() == 1);
auto cont = i->outputs().front();
assert(cont->outputs().size() == 1);
return cont->outputs().front();
});
......
......@@ -763,16 +763,23 @@ struct find_transpose_slice
// Compute axis before transpose to use for unsqueeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqeeze
// Make unsqueeze
std::vector<int64_t> steps(sdistance.size());
std::transform(
slice.axes.begin(),
slice.axes.end(),
sdistance.begin(),
steps.begin(),
[&](const auto ax, const auto sdis) { return ins->get_shape().lens().at(ax) / sdis; });
auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", steps}}), ins->inputs());
// Make transpose
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i > preaxis)
if(i >= preaxis)
return i + 1;
return i;
});
perm.insert(perm.begin(), preaxis + 1);
perm.insert(perm.begin(), preaxis);
auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and squeeze
......
#####################################################################################
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
......@@ -20,7 +20,7 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
# ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc)
find_package(miopen)
......@@ -33,6 +33,8 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
......@@ -46,9 +48,10 @@ add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE)
message(STATUS "Enable -fhip-lambda-host-device")
target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
message(STATUS "Enable -fhip-lambda-host-device")
target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
endif()
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
......@@ -60,11 +63,13 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
add_library(kernel_file_check EXCLUDE_FROM_ALL)
foreach(KERNEL_FILE ${KERNEL_FILES})
get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
endforeach()
target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu)
......@@ -125,6 +130,7 @@ function(register_migraphx_gpu_ops PREFIX)
register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
endforeach()
endfunction()
register_migraphx_gpu_ops(hip_
argmax
argmin
......@@ -146,47 +152,41 @@ register_migraphx_gpu_ops(miopen_
lrn
pooling
)
register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu)
# look for offload bundler
get_filename_component(CMAKE_CXX_COMPILER_PATH "${CMAKE_CXX_COMPILER}" PATH)
if(CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
find_program(MIGRAPHX_OFFLOADBUNDLER_BIN clang-offload-bundler
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATH_SUFFIXES bin
PATHS /opt/rocm/llvm
)
else()
if(NOT CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
find_program(MIGRAPHX_EXTRACT_KERNEL extractkernel
PATH_SUFFIXES bin
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATHS
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
)
endif()
message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}")
message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR)
# Find package rocMLIR
find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
......@@ -195,36 +195,39 @@ if(MIGRAPHX_ENABLE_MLIR)
target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
if(MIGRAPHX_USE_HIPRTC)
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
message(STATUS "MIGraphX is using hipRTC")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
else()
# Get flags needed to compile hip
include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ")
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "MIGraphX is using HIP Clang")
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_OFFLOADBUNDLER_BIN=${MIGRAPHX_OFFLOADBUNDLER_BIN}"
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
"-DMIGRAPHX_USE_HIPRTC=0"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
# Get flags needed to compile hip
include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ")
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif()
# Check miopen find mode api
......@@ -236,7 +239,7 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
# TODO: Set default to HAS_FIND_2_API
set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
if(MIGRAPHX_USE_FIND_2_API)
if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else()
......@@ -258,8 +261,7 @@ target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
add_subdirectory(driver)
rocm_install_targets(
TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE
TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
)
......@@ -29,10 +29,9 @@
#include <cassert>
#include <iostream>
#if MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/env.hpp>
#else
#include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp>
......@@ -48,9 +47,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#if MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
......@@ -143,25 +143,29 @@ struct hiprtc_program
options.end(),
std::back_inserter(c_options),
[](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
std::cerr << log() << std::endl;
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log();
if(not prog_log.empty())
{
std::cerr << prog_log << std::endl;
}
if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Compilation failed.");
}
std::string log()
std::string log() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n < 2)
if(n == 0)
return {};
std::vector<char> buffer(n);
std::string buffer(n, '\0');
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() == 0);
return {buffer.begin(), buffer.end() - 1};
assert(buffer.back() != 0);
return buffer;
}
std::vector<char> get_code_obj()
std::vector<char> get_code_obj() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
......@@ -176,6 +180,17 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
hiprtc_program prog(srcs);
auto options = split_string(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
{
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}
if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
......@@ -183,7 +198,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
}))
options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc");
options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-Wno-cuda-compat");
options.push_back("--offload-arch=" + arch);
prog.compile(options);
......@@ -292,6 +307,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)};
}
#endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param)
{
std::vector<std::string> items(count);
......@@ -299,8 +316,6 @@ std::string enum_params(std::size_t count, std::string param)
return join_strings(items, ",");
}
#endif // MIGRAPHX_USE_HIPRTC
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -29,7 +29,6 @@
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -80,6 +79,7 @@ std::string generate_args_hpp(const std::vector<shape>& inputs)
#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/types.hpp>
namespace migraphx {
......
......@@ -36,6 +36,7 @@ namespace gpu {
namespace device {
#ifdef MIGRAPHX_NO_DPP
template <index_int N,
class Op,
class T,
......@@ -62,6 +63,7 @@ __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
}
return buffer[0];
}
#else
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
......@@ -96,11 +98,7 @@ __device__ T dpp_mov(T& x)
input.data = x;
for(index_int i = 0; i < n; i++)
{
#if defined(__HCC__)
output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
#else
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
#endif
}
return output.data;
}
......@@ -310,4 +308,4 @@ void reduce(hipStream_t stream,
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#endif // MIGRAPHX_NO_DPP
......@@ -553,11 +553,13 @@ struct find_gemm_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(
auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
auto binary_op = match::all_of(
match::nargs(3),
match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op));
auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
}
// TODO: Move to matcher.hpp
......@@ -589,61 +591,84 @@ struct find_gemm_pointwise
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
}
static auto match_mul(const std::string& input)
{
auto mul = match_mul_const(match_param(input), "alpha");
return match::name("@return")(match::args(mul));
}
static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{
auto names = pm->get_parameter_names();
if(names.size() != 2)
return false;
std::sort(names.begin(), names.end());
unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false;
if(contains(mr.instructions, "alpha_mul"))
if(names.size() == 1)
{
auto mr = match::match_instruction(*pm, std::prev(pm->end()), match_mul(names[input]));
if(mr.result == pm->end())
return false;
gemm.alpha *= get_float(mr.instructions["alpha"]);
else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]);
else if(contains(mr.instructions, "gamma_mul"))
return true;
}
else if(names.size() == 2)
{
gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]);
unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false;
if(contains(mr.instructions, "alpha_mul"))
gemm.alpha *= get_float(mr.instructions["alpha"]);
else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]);
else if(contains(mr.instructions, "gamma_mul"))
{
gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]);
}
return true;
}
else
{
return false;
}
return true;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto c_ins = r.instructions["c"];
auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
// Already fused gemm
if(not float_equal(gemm.beta, 0))
return;
gemm.beta = 1;
if(ins->inputs().size() == 3)
gemm.beta = 1;
if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
// const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
auto inputs = gemm_ins->inputs();
inputs.pop_back();
inputs.push_back(c_ins);
if(ins->inputs().size() == 3)
{
auto c_ins = r.instructions["c"];
// const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
inputs.push_back(c_ins);
}
inputs.push_back(ins->inputs().back());
m.replace_instruction(ins, gemm, inputs);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const gather_kernel = R"__migraphx__(
#include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
gather<${axis}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct gather_compiler : compiler<gather_compiler>
{
std::vector<std::string> names() const { return {"gather"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
const auto& out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "gather_kernel";
options.virtual_inputs = inputs;
auto axis = v.at("axis").to<std::string>();
auto src = interpolate_string(gather_kernel, {{"axis", axis}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -156,16 +156,25 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
value v = value::object{};
auto reduce_elements = get_reduce_elements(ins->inputs());
value v = value::object{};
if(op.name() == "reduce_sum")
{
v["reduction"] = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
v["reduction"] = "op::sum{}";
v["write"] = "op::mean{" + std::to_string(reduce_elements) + "}";
auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type();
v["reduction"] = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type},
reduce_type))
v["read"] = mean;
else
v["write"] = mean;
}
else if(op.name() == "reduce_max")
{
......
......@@ -105,7 +105,7 @@ constexpr auto array_for_each(T& x, Ts&... xs)
}
else
{
using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
using vec_type = remove_reference_t<decltype(array2vec(x))>;
f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
}
}
......
......@@ -72,7 +72,7 @@ __device__ T dpp_mov(T& x)
}
return output.data;
}
#endif
#endif // MIGRAPHX_HAS_DPP
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
......@@ -187,6 +187,14 @@ constexpr auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
}
template <class... Fs>
constexpr auto compose(Fs... fs)
{
return fold([](auto f, auto g) {
return [=](auto&&... xs) { return f(g(static_cast<decltype(xs)>(xs)...)); };
})(fs...);
}
template <class... Ts>
constexpr auto pack(Ts... xs)
{
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/tensor_view.hpp>
namespace migraphx {
template <int Axis, class Input, class Indices>
constexpr auto gather_shape(Input input, Indices indices)
{
auto lengths = input.lens;
lengths[Axis] = indices.elements();
return make_shape(lengths, input.strides);
}
template <int Axis, class Input, class Indices, class Output>
__device__ void gather(Input input, Indices indices, Output output)
{
auto ind = make_index();
auto axis_dim_size = input.get_shape().lens[Axis];
constexpr auto out_comp = gather_shape<Axis>(get_shape_c<Input>{}, get_shape_c<Indices>{});
ind.global_stride(output.get_shape().elements(), [&](auto i) {
auto idx = out_comp.multi(i);
auto in_index = indices[idx[Axis]];
auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[Axis] = new_in_index;
output[i] = input[idx];
});
}
} // namespace migraphx
#endif
......@@ -26,7 +26,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
template <class T>
......@@ -53,23 +53,17 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
std::size_t num_slices =
accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const std::size_t num_batches = accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const auto num_slices_per_batch = num_slices / num_batches;
op::product{});
const std::size_t num_batches =
accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
const std::size_t data_batch_stride =
accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
......@@ -83,15 +77,15 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
assert(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
std::multiplies<std::size_t>());
op::product{});
relative_slice_offset += index * size_from_slice_dims;
}
......
......@@ -24,11 +24,18 @@
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
......@@ -28,8 +28,7 @@
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <migraphx/kernels/hip.hpp>
namespace migraphx {
......@@ -132,9 +131,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
// Use float to compute half overload
......@@ -144,16 +148,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
......@@ -166,19 +165,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp)
MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10)
MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
template <class T, class U>
constexpr auto where(bool cond, const T& a, const U& b)
......@@ -218,6 +217,14 @@ constexpr auto min(const T& a, const U& b)
return min<common_type_t<T, U>>(a, b);
}
// Sin for half is broken on hip, so use cos instead
template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
constexpr T sin(T x)
{
constexpr const T shift = HIP_PIO2_F;
return migraphx::cos(shift - x);
}
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
......
......@@ -56,6 +56,16 @@ struct id
}
};
template <class T>
struct convert_to
{
template <class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(U x) const
{
return convert<T>(x);
}
};
struct mean
{
index_int item_num = 1;
......
......@@ -76,14 +76,6 @@ struct shape
constexpr index_int index(index_array x) const { return x.dot(strides); }
constexpr index_int index(std::initializer_list<index_int> x) const
{
index_int idx = 0;
for(index_int i = 0; i < x.size(); i++)
idx += *(x.begin() + i) * strides[i];
return idx;
}
constexpr index_int index(index_int i) const
{
if(this->standard())
......@@ -128,6 +120,7 @@ struct shape
result[0] = tidx;
return result;
}
/// Convert multi-index into a single index
constexpr index_int single(index_array idx) const
{
......
......@@ -28,8 +28,45 @@
namespace migraphx {
using index_int = std::uint32_t;
using diff_int = std::int32_t;
#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC)
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
#elif defined(MIGRAPHX_USE_HIPRTC)
using int8_t = __hip_int8_t;
using uint8_t = __hip_uint8_t;
using int16_t = __hip_int16_t;
using uint16_t = __hip_uint16_t;
using int32_t = __hip_int32_t;
using uint32_t = __hip_uint32_t;
using int64_t = __hip_int64_t;
using uint64_t = __hip_uint64_t;
#else
using int8_t = std::int8_t;
using uint8_t = std::uint8_t;
using int16_t = std::int16_t;
using uint16_t = std::uint16_t;
using int32_t = std::int32_t;
using uint32_t = std::uint32_t;
using int64_t = std::int64_t;
using uint64_t = std::uint64_t;
#endif // MIGRAPHX_USE_HIPRTC
using index_int = uint32_t;
using diff_int = int32_t;
static_assert(sizeof(int8_t) == 1, "int8_t must be 1 bytes");
static_assert(sizeof(uint8_t) == 1, "uint8_t must be 1 bytes");
static_assert(sizeof(int16_t) == 2, "int16_t must be 2 bytes");
static_assert(sizeof(uint16_t) == 2, "uint16_t must be 2 bytes");
static_assert(sizeof(int32_t) == 4, "int32_t must be 4 bytes");
static_assert(sizeof(uint32_t) == 4, "uint32_t must be 4 bytes");
static_assert(sizeof(int64_t) == 8, "int64_t must be 8 bytes");
static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes");
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
......
......@@ -90,7 +90,6 @@ struct miopen_apply
add_extend_op("argmax");
add_extend_op("argmin");
add_extend_op("gather");
add_extend_op("logsoftmax");
add_extend_op("lrn");
add_extend_op("multinomial");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment