Unverified Commit a24ed87e authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into optimize_jenkinsfile

parents 6481cd69 a09dc502
......@@ -34,23 +34,32 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_v2_forward, op::pooling>
{
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
dnnl::pooling_v2_forward::desc
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
// Note: It is not documented, but the default dilation seems to be 0 instead of 1.
// We need to offset dilations with -1.
std::vector<size_t> dilations;
std::transform(op.dilations.cbegin(),
op.dilations.cend(),
std::back_inserter(dilations),
[](size_t d) { return d - 1; });
return {dnnl::prop_kind::forward_inference,
algo,
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(op.lengths),
to_dnnl_dims(dilations),
to_dnnl_dims(padding_l),
to_dnnl_dims(padding_r)};
}
......
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -22,23 +22,22 @@
# THE SOFTWARE.
# ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
find_package(hip REQUIRED)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
set(fatal_msg "HIP package is broken and has no GPU_TARGETS. Please pass GPU_TARGETS to cmake.")
if(NOT WIN32)
set(fatal_msg "${fatal_msg}\nUse -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to build for your GPU.")
endif()
message(FATAL_ERROR ${fatal_msg})
endif()
find_package(miopen)
find_package(miopen REQUIRED)
message(STATUS "MIGraphX is using MIOpen")
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
find_package(rocblas REQUIRED)
message(STATUS "MIGraphX build with rocBLAS")
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif()
......@@ -50,12 +49,11 @@ endif()
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(WIN32)
# TODO: re-enable when CK is ported to Windows
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif()
......@@ -67,8 +65,10 @@ file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*
add_library(migraphx_device ${DEVICE_GPU_SRCS})
add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
target_compile_features(compile_for_gpu INTERFACE cxx_std_17)
target_compile_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
target_link_libraries(compile_for_gpu INTERFACE hip::device)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE)
......@@ -103,9 +103,10 @@ rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM JIT_GPU_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
endif()
add_library(migraphx_gpu
......@@ -125,11 +126,8 @@ add_library(migraphx_gpu
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
gemm_impl.cpp
hip.cpp
int8_conv_pack.cpp
int8_gemm_pack.cpp
kernel.cpp
lowering.cpp
logsoftmax.cpp
......@@ -140,9 +138,7 @@ add_library(migraphx_gpu
no_device.cpp
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp
perfdb.cpp
pooling.cpp
reverse.cpp
......@@ -170,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_
argmax
argmin
gather
logsoftmax
loop
multinomial
nonzero
pad
prefix_scan_sum
reverse
scatter
......@@ -184,7 +178,6 @@ register_migraphx_gpu_ops(hip_
register_migraphx_gpu_ops(miopen_
abs
contiguous
int8_conv_pack
lrn
pooling
)
......@@ -192,10 +185,6 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
......@@ -219,8 +208,10 @@ if(MIGRAPHX_ENABLE_MLIR)
endif()
if(MIGRAPHX_USE_HIPRTC)
find_package(hiprtc REQUIRED)
message(STATUS "MIGraphX is using hipRTC")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
target_link_libraries(migraphx_gpu PUBLIC hiprtc::hiprtc)
else()
message(STATUS "MIGraphX is using HIP Clang")
......@@ -229,34 +220,47 @@ else()
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE "--cuda-gpu-arch=[a-z0-9]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE "--offload-arch=[a-z0-9:+-]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ")
if(WIN32)
string(REPLACE "\\" "/" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endif()
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
message(STATUS "Hip compiler flags: \"${HIP_COMPILER_FLAGS}\"")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
-DMIGRAPHX_HIP_COMPILER="${CMAKE_CXX_COMPILER}"
-DMIGRAPHX_HIP_COMPILER_FLAGS="${HIP_COMPILER_FLAGS}"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
if(WIN32)
execute_process(COMMAND where ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
else()
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
endif()
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_HIP_COMPILER_LAUNCHER="${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif()
# Check miopen find mode api
include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -279,11 +283,25 @@ else()
message(STATUS "MIOpen does not have find mode api")
endif()
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()
add_subdirectory(driver)
......
......@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
......@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis,
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
......
......@@ -194,7 +194,7 @@ struct hiprtc_program
};
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs,
std::string params,
const std::string& params,
const std::string& arch)
{
hiprtc_program prog(std::move(srcs));
......@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags)
}
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
const std::string& params,
const std::string& arch)
{
std::vector<hiprtc_src_file> hsrcs{srcs.begin(), srcs.end()};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
......@@ -251,10 +252,21 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::cout << std::string(src.content) << std::endl;
}
}
auto fname = fs::path{"migraphx-hiprtc-driver"};
#ifdef _WIN32
fname.replace_extension(".exe");
#endif
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
auto driver = p.parent_path() / fname;
if(fs::exists(driver))
bool found = fs::exists(driver);
if(not found)
{
driver = p.parent_path().parent_path() / "bin" / fname;
found = fs::exists(driver);
}
if(found)
{
value v;
v["srcs"] = to_value(hsrcs);
......@@ -270,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(fs::exists(out))
return {read_buffer(out.string())};
}
return compile_hip_src_with_hiprtc(std::move(hsrcs), std::move(params), arch);
return compile_hip_src_with_hiprtc(std::move(hsrcs), params, arch);
}
#else // MIGRAPHX_USE_HIPRTC
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file>, // NOLINT
std::string, // NOLINT
const std::string&, // NOLINT
const std::string&)
{
MIGRAPHX_THROW("Not using hiprtc");
......@@ -284,16 +296,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool is_hip_clang_compiler()
{
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
static const auto result = fs::path{MIGRAPHX_HIP_COMPILER}.stem() == "clang++";
return result;
}
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
bool has_compiler_launcher()
{
static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
static const auto result = fs::exists(MIGRAPHX_HIP_COMPILER_LAUNCHER);
return result;
}
#endif
src_compiler assemble(src_compiler compiler)
{
compiler.out_ext = ".S";
......@@ -301,37 +317,39 @@ src_compiler assemble(src_compiler compiler)
return compiler;
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
std::vector<std::vector<char>> compile_hip_src(const std::vector<src_file>& srcs,
const std::string& params,
const std::string& arch)
{
assert(not srcs.empty());
if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_HIP_COMPILER;
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
#endif
if(params.find("-std=") == std::string::npos)
params += " --std=c++17";
params += " -fno-gpu-rdc";
compiler.flags += " --std=c++17";
compiler.flags += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g";
params += " -c";
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
compiler.flags += " -g";
compiler.flags += " -c";
compiler.flags += " --offload-arch=" + arch;
compiler.flags += " --cuda-device-only";
compiler.flags += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG";
compiler.flags += " -DMIGRAPHX_DEBUG";
params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS);
compiler.flags += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
compiler.flags += MIGRAPHX_HIP_COMPILER_FLAGS;
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
......@@ -354,7 +372,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.compiler = MIGRAPHX_HIP_COMPILER;
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
......
......@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
auto cos = compile_hip_src(srcs, options.params, get_device_name());
if(cos.size() != 1)
MIGRAPHX_THROW("No code object");
return code_object_op{value::binary{cos.front()},
......
......@@ -60,9 +60,8 @@ struct miopen_op
};
MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
std::size_t compile_miopen::compile(operation& op, instruction_ref ins) const
{
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get<std::size_t>("workspace", 0);
}
......@@ -70,25 +69,15 @@ std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool for
void compile_miopen::apply(module& m) const
{
assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
std::size_t ws = 0;
try
{
// for the regular convolution and convolution_backwards, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ws = compile(op, ins);
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
inputs.insert(std::prev(inputs.end()), alloc);
......
......@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING);
struct precompile_op
{
......@@ -167,6 +168,7 @@ struct compile_plan
}
const compiled_result& benchmark(problem_cache& pc) const
{
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
......@@ -177,18 +179,35 @@ struct compile_plan
}
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(trace_level > 0)
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(trace_level > 1)
std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times;
times.reserve(results.size());
std::transform(
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
if(not cr.has_value())
return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
});
std::transform(results.begin(),
results.end(),
config->solutions.begin(),
std::back_inserter(times),
[&](const auto& cr, const auto& solution) {
if(trace_level > 1)
std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value())
{
if(trace_level > 1)
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max();
}
auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(trace_level > 1)
std::cout << t << "ms" << std::endl;
return t;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
if(trace_level > 0)
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation.");
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis)
{
const auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
auto axis_dim_size = lens[axis];
lens[axis] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input_v) {
hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements, 256)([=](auto i) __device__ {
auto idx = out_comp.multi(i);
auto in_index = indices_ptr[idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis] = in_index;
output_ptr[i] = input[idx];
});
});
});
});
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -43,24 +43,32 @@ template <index_int N,
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{
using type = decltype(input(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N];
MIGRAPHX_DEVICE_SHARED type buffer[2][N];
type x = init;
fs([&](auto i) {
index_int iout = 0;
index_int iin = 1;
if(idx.local == 0)
buffer[idx.local] = op(input(i), x);
buffer[iout][idx.local] = op(input(i), x);
else
buffer[idx.local] = input(i);
buffer[iout][idx.local] = input(i);
__syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2)
{
if(idx.local + s < idx.nlocal())
iout = 1 - iout;
iin = 1 - iin;
if(idx.local >= s)
{
buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]);
buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
}
else
{
buffer[iout][idx.local] = buffer[iin][idx.local];
}
__syncthreads();
}
x = buffer[idx.nlocal() - 1];
output(i, buffer[idx.local]);
x = buffer[iout][idx.nlocal() - 1];
output(i, buffer[iout][idx.local]);
});
}
......
......@@ -146,20 +146,20 @@ __device__ __host__ T to_hip_type(T x)
// Hip doens't support __fp16
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
} // namespace device
} // namespace gpu
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/tensor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg)
{
auto comp_shape = arg.get_shape();
auto out_lens = comp_shape.lens();
auto dim_0 = out_lens.size() - 2;
auto dim_1 = out_lens.size() - 1;
std::size_t lda = comp_shape.strides()[dim_0];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = comp_shape.elements();
auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements, 256)([=](auto ii) __device__ {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] =
in_ptr[i_m + i_k * lda + offset];
});
});
});
}
void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
{
auto trans_shape = arg.get_shape();
auto out_lens = trans_shape.lens();
auto dim_0 = trans_shape.lens().size() - 2;
auto dim_1 = trans_shape.lens().size() - 1;
std::size_t ldb = trans_shape.strides()[dim_1];
auto wrap_lens = out_lens;
std::swap(wrap_lens[dim_0], wrap_lens[dim_1]);
shape comp_shape{trans_shape.type(), wrap_lens};
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = comp_shape.elements();
auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements, 256)([=](auto ii) __device__ {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_n = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
});
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,58 +21,64 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {
shape pack_int8_shape(const shape& s)
struct precompile_op : action<precompile_op>
{
if(s.type() != shape::int8_type)
static program create_preop_program(const operation& preop, std::vector<shape> inputs)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
program p;
auto* mm = p.get_main_module();
std::vector<instruction_ref> args;
inputs.pop_back();
transform(inputs, range(inputs.size()), std::back_inserter(args), [&](auto input, auto i) {
return mm->add_parameter("x" + std::to_string(i), input);
});
mm->add_instruction(preop, args);
return p;
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return pack_int8_shape(inputs.at(0));
}
argument
miopen_int8_conv_pack::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto arg_desc = make_tensor(args[0].get_shape());
auto arg_desc_vec4 = make_tensor(args[0].get_shape(), true);
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
arg_desc.get(),
args[0].implicit(),
&beta,
arg_desc_vec4.get(),
args[1].implicit());
if(status != miopenStatusSuccess)
static operation get_code_object(const program& p)
{
MIGRAPHX_THROW("INT8_CONV_PACK: transform input tensor failed");
MIGRAPHX_TIDY_CONST auto* mm = p.get_main_module();
auto it = std::find_if(mm->begin(), mm->end(), [](const auto& ins) {
return (ins.name() == "gpu::code_object");
});
if(it == mm->end())
MIGRAPHX_THROW("Failed to create code object");
return it->get_operator();
}
static void apply(const parser& p, const value& v)
{
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto name = v.at("name").to<std::string>();
auto preop = make_op(name);
if(v.contains("fields"))
preop.from_value(v.at("fields"));
bool exhaustive = v.get("exhaustive", false);
auto prog = create_preop_program(preop, inputs);
run_passes(prog, {lowering{}, compile_ops{&ctx, exhaustive}});
auto op = get_code_object(prog);
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << preop << ": " << t << "ms" << std::endl;
}
};
return args[1];
}
} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -38,6 +38,18 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool mlir_enabled()
{
......@@ -49,6 +61,26 @@ bool mlir_enabled()
#endif
}
static bool is_requested(std::string_view option, bool fallback = false)
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
bool mlir_attention_enabled()
{
#ifdef MIGRAPHX_MLIR
if(not mlir_enabled())
return false;
return is_requested("attention");
#else
return false;
#endif
}
#ifdef MIGRAPHX_MLIR
struct mlir_op
......@@ -62,41 +94,27 @@ struct mlir_op
return pack(f(self.op, "op"));
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
{
module_ref mod = mods[0];
check_shapes{inputs, *this}.packed_or_broadcasted();
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
module_ref mod = mods[0];
auto type = mod->get_output_shapes().front().type();
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes;
size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end());
for(const std::string& param_name : names)
{
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
}
for(auto ins : iterator_for(*mod))
{
if(ins->name() == "@param")
{
continue;
}
if(ins->name() == "@literal")
if(ins->name() == "@literal" or ins->name() == "@param")
{
ins_shapes[ins] = ins->get_shape();
continue;
}
if(ins->name() == "@return")
{
auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
return ins_shapes[ins->inputs().at(0)].with_type(type);
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
......@@ -112,38 +130,55 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op);
namespace {
std::tuple<instruction_ref, std::vector<operation>>
get_fusable_input_op_stream(instruction_ref lower_input)
{
instruction_ref upper_input = lower_input;
std::vector<operation> op_stream;
while(contains({"slice",
"transpose",
"multibroadcast",
"broadcast",
"contiguous",
"reshape",
"squeeze",
"flatten",
"unsqueeze"},
upper_input->name()))
{
operation op = upper_input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name()))
{
op = migraphx::make_op("reshape", {{"dims", upper_input->get_shape().lens()}});
}
op_stream.push_back(op);
upper_input = upper_input->inputs().at(0);
}
return {upper_input, op_stream};
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
fuse_input_ops_and_gemm_based_op(module_ref mm,
const std::vector<instruction_ref>& gemm_based_op_inputs,
const operation& gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
for(instruction_ref input : gemm_based_op_inputs)
{
std::vector<operation> op_stream;
while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
}
top_inputs.push_back(input);
auto [upper_input, op_stream] = get_fusable_input_op_stream(input);
top_inputs.push_back(upper_input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
mm->add_parameter("y" + std::to_string(input_cnt++), upper_input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
instruction_ref new_gemm_based_op = mm->add_instruction(gemm_based_op, imm_inputs);
return {new_gemm_based_op, top_inputs};
}
......@@ -205,102 +240,135 @@ auto is_mlir_conv(mlir_mode mode)
});
}
struct find_mlir_fused_ops
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
if(ins->name() != "@literal")
{
if(ins->name() != "@literal")
{
continue;
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
continue;
}
return ins_map;
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast =
mm->add_instruction(make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
}
return ins_map;
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
std::vector<instruction_ref>
fold_pointwise_mod(instruction_ref pm_ins,
module_ref parent_mod,
const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
{
auto* pm = pm_ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
std::transform(names.begin(),
names.end(),
pm_ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&](auto name, auto input) {
if(ins_map.count(input))
return std::make_pair(pm->get_parameter(name), ins_map.at(input));
return std::make_pair(pm->get_parameter(name),
parent_mod->add_parameter(name, input->get_shape()));
});
return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i)
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
return false;
}
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {
"convolution",
"quant_convolution",
"dot",
"quant_dot",
"add",
"clip",
"relu",
"sub",
"mul",
"div",
"pow",
"where",
"quantizelinear",
"dequantizelinear",
"abs",
"neg",
};
const std::initializer_list<std::string> fp_only_ops = {
"ceil",
"erf",
"exp",
"floor",
"log",
"recip",
"rsqrt",
"sigmoid",
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true;
if(is_float and contains(fp_only_ops, name))
return true;
// Only conversions between floating types are known to be unambigiously
// supported.
if(is_float and name == "convert")
{
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
}
return false;
}
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {
"convolution",
"quant_convolution",
"dot",
"quant_dot",
"add",
"clip",
"relu",
"sub",
"mul",
"div",
"pow",
"where",
"quantizelinear",
"dequantizelinear",
"abs",
"neg",
};
const std::initializer_list<std::string> fp_only_ops = {
"ceil",
"erf",
"exp",
"floor",
"log",
"recip",
"rsqrt",
"sigmoid",
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true;
if(is_float and contains(fp_only_ops, name))
return true;
// Only conversions between floating types are known to be unambigiously
// supported.
if(is_float and name == "convert")
{
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
}
return false;
}
MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
{
if(ins->name() != "pointwise")
return false;
auto* pm = ins->module_inputs().front();
return std::all_of(pm->begin(), pm->end(), [&](const auto& i) {
return is_pointwise_op_supported_by_mlir(i);
});
}
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
......@@ -309,29 +377,12 @@ struct find_mlir_fused_ops
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
// Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i);
}))
return;
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});
mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}}));
std::vector<instruction_ref> inputs;
std::copy_if(ins->inputs().begin(),
......@@ -349,51 +400,103 @@ struct find_mlir_standalone_op
{
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
auto gemm_based_op = r.result;
//
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
module_ref mm =
mpm.create_module("mlir_" + gemm_based_op->name() + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
gemm_based_op, mlir_op{gemm_based_op->get_operator()}, top_inputs, {mm});
}
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
struct find_mlir_standalone_attention_op
{
auto matcher() const
{
return match::name("gpu::pre_gemm_softmax_gemm").bind("gemm_softmax_gemm");
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
auto gemm_softmax_gemm = r.instructions["gemm_softmax_gemm"];
std::vector<instruction_ref> inputs;
mm->set_bypass();
bool is_requested(std::string_view option, bool fallback = false)
std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto gemm0_inputs = gemm_softmax_gemm->inputs();
gemm0_inputs.pop_back();
auto [gemm0, top_gemm0_inputs] =
fuse_input_ops_and_gemm_based_op(mm, gemm0_inputs, make_op("dot"));
inputs.insert(inputs.begin(), top_gemm0_inputs.begin(), top_gemm0_inputs.end());
// handle scale
auto v = gemm_softmax_gemm->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
auto scale_lit = mm->add_literal(literal{shape{gemm0->get_shape().type()}, {scale}});
instruction_ref scale_lit_mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", gemm0->get_shape().lens()}}), scale_lit);
auto scaled_gemm0 = mm->add_instruction(make_op("mul"), gemm0, scale_lit_mbcast);
auto softmax = mm->add_instruction(
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), scaled_gemm0);
auto [old_upper_v, upper_v_op_stream] =
get_fusable_input_op_stream(gemm_softmax_gemm->inputs()[2]);
instruction_ref new_upper_v = mm->add_parameter("z", old_upper_v->get_shape());
for(const auto& op : reverse(upper_v_op_stream))
{
new_upper_v = mm->add_instruction(op, {new_upper_v});
}
inputs.push_back(old_upper_v);
auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v});
ins_map[gemm_softmax_gemm] = gemm1;
auto ins_to_replace = gemm1;
auto ins_to_be_replaced = gemm_softmax_gemm;
if(r.instructions.find("trailing_pm") != r.instructions.end())
{
ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
r.instructions["trailing_pm"]->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != gemm_softmax_gemm; });
ins_to_be_replaced = r.instructions["trailing_pm"];
}
mm->add_return({ins_to_replace});
mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{gemm1->get_operator()}, inputs, {mm});
}
};
struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
auto matcher() const
{
auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
return mlir_pointwise()(
match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));
;
}
};
} // namespace
#endif // MIGRAPHX_MLIR
......@@ -415,13 +518,20 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
mlir_mode mode =
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
// Attention offloads; default disabled
if(mlir_attention_enabled())
{
match::find_matches(mpm, find_mlir_attention_fused_ops{});
match::find_matches(mpm, find_mlir_standalone_attention_op{});
}
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else
(void)mpm;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,15 +21,37 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <type_traits>
using microseconds = std::chrono::duration<double, std::micro>;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct rb_compute_type
{
int type = 0;
rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
switch(type)
......@@ -41,6 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
......@@ -81,196 +104,542 @@ shape transpose_batch(const shape& s, unsigned trans_batch)
return shape::from_permutation(s.type(), s.lens(), perm);
}
template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs)
/**
* Returns results of rocblas_status_success, rocblas_status_perf_degraded,
* or rocblas_status_invalid_value. Caller
* is expected to check for invalid index. Any other result causes an exception.
*
*/
template <class F, class Pack, class... Ts>
auto rocblas_invoke(F f, Pack p, Ts... xs)
{
if constexpr(sizeof...(Ts) == sizeof...(Us))
return f(xs...);
else
return f(xs..., nullptr, nullptr);
return p([=](auto... ws) {
auto status = f(ws..., xs...);
if(status != rocblas_status_success and status != rocblas_status_invalid_value)
{
if(status == rocblas_status_perf_degraded)
{
std::cerr << "WARNING: degraded perf. in rocBLAS call" << std::endl;
}
else
MIGRAPHX_THROW("rocblas_invoke: rocBLAS call failed with status " +
std::to_string(status));
}
return status;
});
}
static bool is_transposed(const shape& s)
{
if(not s.transposed())
return false;
return s.strides().back() != 1;
}
static bool is_transposed(const shape& s) { return s.transposed() and s.strides().back() != 1; }
static rocblas_int get_batch_stride(const argument& a)
static rocblas_int get_batch_stride(const shape& s)
{
return a.get_shape().strides()[a.get_shape().strides().size() - 3];
// This value is not needed for non-strided inputs
if(s.strides().size() < 3)
return 0;
else
return s.strides()[s.strides().size() - 3];
}
template <class T>
void gemm_impl(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
T alpha,
T beta,
bool int8_x4_format,
bool compute_fp32)
/**
* Wrapper for multiple rocBLAS calls. The constructor creates parameters for
* these calls based on data shapes and other values contained in the associated
* instruction and operation.
*
* The template parameter T is not the type of the matrix data but of the weighting
* coefficients alpha and beta (these are float in rocBLAS internals)
*/
template <typename T>
struct gemm_impl
{
const bool is_3inputs = (args.size() == 4);
if(not is_3inputs)
gemm_impl(const shape& output_shape,
const std::vector<shape>& input_shapes,
T alpha_param,
T beta_param,
bool compute_fp32_flag)
: alpha(alpha_param),
beta(beta_param),
is_3inputs(input_shapes.size() == 4),
compute_fp32(compute_fp32_flag)
{
beta = 0;
}
bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
auto compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
rocblas_gemm_flags flag = rocblas_gemm_flags_none;
#if ROCBLAS_VERSION_MAJOR < 3
if(int8_x4_format)
flag = rocblas_gemm_flags_pack_int8x4;
#endif
if(not is_3inputs)
{
beta = 0;
}
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
// Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
if(compute_fp32)
{
get_alpha = [=] { return &alpha; };
get_beta = [=] { return &beta; };
}
else
{
get_alpha = [=] { return &alpha_r; };
get_beta = [=] { return &beta_r; };
}
});
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
transa = is_transposed(input_shapes[0]);
transb = is_transposed(input_shapes[1]);
auto n_dim = output_shape.lens().size();
auto dim_0 = n_dim - 2;
auto dim_1 = n_dim - 1;
// Leading dimensions of matrices
lda = input_shapes[0].strides()[transa ? dim_1 : dim_0];
ldb = input_shapes[1].strides()[transb ? dim_1 : dim_0];
ldc = input_shapes[2].strides()[dim_0];
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type());
output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
if(arg_type == rocblas_datatype_f8_r)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
compute_type = rocblas_compute_type_f32;
}
auto num_matrices = std::accumulate(
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
auto out_lens = output_shape.lens();
m = out_lens[dim_0];
n = out_lens[dim_1];
k = input_shapes[0].lens()[dim_1];
a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]);
c_stride = get_batch_stride(input_shapes[2]);
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
strided_batched = num_matrices > 1;
if(strided_batched and b_stride == 0 and input_shapes[0].standard())
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
strided_batched = false;
}
}
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
{
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
if(rocblas_fp8_available() and
std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
}
else
#endif
{
auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(args.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
}
}
});
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
{
// Create dummy arguments for the shapes, and call the overloaded method
std::vector<argument> input_args;
std::transform(input_shapes.begin(),
input_shapes.end(),
std::back_inserter(input_args),
[](const shape& x) { return to_gpu(generate_argument(x)); });
return validate(ctx, input_args, solution_idx);
}
/**
* Checks a particular solution for validity by running it with the flag
* rocblas_gemm_flags_check_solution_index (could be invalid if this model was
* tuned with a different rocBLAS version)
*
* @return Returns either solution_idx if valid, or else the default value 0
* if not. The default does not mean list index 0, but tells the picker
* to choose a solution.
*/
int32_t
validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) const
{
rocblas_status_ check_valid(rocblas_status_success);
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
}
if(check_valid == rocblas_status_invalid_value)
{
std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
return 0;
}
return solution_idx;
}
#endif
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "...strided_batched..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
*/
auto create_strided_batched_args_common(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
b_stride,
args[0].data(),
arg_type,
lda,
a_stride,
get_beta(),
args[2].data(),
output_type,
ldc,
c_stride,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
d_stride,
num_matrices,
compute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
* */
auto create_gemm_ex_args_common(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
args[0].data(),
arg_type,
lda,
get_beta(),
args[2].data(),
output_type,
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
compute_type);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* of the fastest one.
*/
int tune(context& ctx, const std::vector<shape>& input_shapes) const
{
// tuning meta parameters
const int hot_calls = 40;
std::vector<argument> input_args;
std::transform(input_shapes.begin(),
input_shapes.end(),
std::back_inserter(input_args),
[](const shape& x) { return to_gpu(generate_argument(x)); });
// Get the solutions list in 2 rocBLAS steps:
// 1. Find out how many solutions there are and allocate the array
// 2. Get the solutions
//
rocblas_int list_size = 0;
std::vector<rocblas_int> solution_indices;
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
&list_size);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
&list_size);
}
double best_time = std::numeric_limits<double>::max();
double first_time = -1;
// Initialize to default solution index
rocblas_int best_sol = 0;
for(auto sol : solution_indices)
{
// Warmup: the first call to an op. may not be representative since there is
// more time taken initializing caches, etc. so we won't time it.
run(ctx, input_args, sol);
double host_time = time<milliseconds>([&] {
for([[maybe_unused]] int hc : range(hot_calls))
run(ctx, input_args, sol);
ctx.finish();
});
host_time /= hot_calls;
// dev/evaluation only: track time for first solution.
if(first_time < 0)
first_time = host_time;
// track current best
if(host_time < best_time)
{
best_sol = sol;
best_time = host_time;
}
}
std::cout << "Winning GEMM solution: " << best_sol << " in " << best_time << " ms, beats "
<< first_time << "ms" << std::endl;
return best_sol;
}
#endif
private:
size_t num_matrices = 0;
rocblas_int m = 0;
rocblas_int n = 0;
rocblas_int k = 0;
bool transa = false;
bool transb = false;
T alpha = 0;
T beta = 0;
std::function<const void*()> get_alpha{};
std::function<const void*()> get_beta{};
rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
rocblas_int lda = 0;
rocblas_int ldb = 0;
rocblas_int ldc = 0;
rocblas_int ldd = 0;
rocblas_int a_stride = 0;
rocblas_int b_stride = 0;
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
bool compute_fp32 = true;
}; // gemm_impl
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx)
{
std::vector<shape> input_shapes;
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool int8_x4_format,
bool compute_fp32)
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx)
{
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
std::vector<shape> input_shapes;
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool int8_x4_format,
bool compute_fp32)
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx)
{
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
// This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
}
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
}
} // namespace gpu
......
......@@ -58,10 +58,10 @@ struct hiprtc_src_file
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
std::vector<hiprtc_src_file> srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src(
const std::vector<src_file>& srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param);
......
......@@ -42,7 +42,7 @@ struct compile_miopen
context* ctx = nullptr;
std::string name() const { return "gpu::compile_miopen"; }
void apply(module& m) const;
std::size_t compile(operation& op, instruction_ref ins, bool format) const;
std::size_t compile(operation& op, instruction_ref ins) const;
};
} // namespace gpu
......
......@@ -57,7 +57,6 @@ template <class Op>
struct miopen_convolution
{
Op op;
bool int8_x4_format = false;
shared<convolution_descriptor> cd = nullptr;
miopenConvFwdAlgorithm_t algo{};
#ifdef MIGRAPHX_HAS_FIND_2_API
......@@ -74,7 +73,6 @@ struct miopen_convolution
f(self.solution_object, "solution_object"),
#endif
f(self.algo, "algo"),
f(self.int8_x4_format, "int8_x4_format"),
f(self.solution_id, "solution_id"));
}
......@@ -94,9 +92,9 @@ struct miopen_convolution
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape()), int8_x4_format);
auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape()), int8_x4_format);
auto y_desc = make_tensor(reshape_if_1d(output_shape));
auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape()));
auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape()));
auto y_desc = make_tensor(reshape_if_1d(output_shape));
auto* miopen_stream_handle = ctx.get_stream().get_miopen();
auto workspace_size = args[2].get_shape().bytes();
......@@ -162,8 +160,8 @@ struct miopen_convolution
shape find(context& ctx, const shape& output_shape, const std::vector<shape>& inputs)
{
shape workspace_shape{};
auto x_desc = make_tensor(reshape_if_1d(inputs[0]), int8_x4_format);
auto w_desc = make_tensor(reshape_if_1d(inputs[1]), int8_x4_format);
auto x_desc = make_tensor(reshape_if_1d(inputs[0]));
auto w_desc = make_tensor(reshape_if_1d(inputs[1]));
auto y_desc = make_tensor(reshape_if_1d(output_shape));
auto* miopen_stream_handle = ctx.get_stream().get_miopen();
......@@ -179,13 +177,8 @@ struct miopen_convolution
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x_shape = inputs[0];
auto w_shape = inputs[1];
if(int8_x4_format)
{
x_shape = pack_int8_shape(x_shape);
w_shape = pack_int8_shape(w_shape);
}
const auto& x_shape = inputs[0];
const auto& w_shape = inputs[1];
#ifdef MIGRAPHX_HAS_FIND_2_API
{
......@@ -327,8 +320,8 @@ struct miopen_convolution
": workspace has changed during finalization.");
}
auto x_desc = make_tensor(reshape_if_1d(inputs[0]), int8_x4_format);
auto w_desc = make_tensor(reshape_if_1d(inputs[1]), int8_x4_format);
auto x_desc = make_tensor(reshape_if_1d(inputs[0]));
auto w_desc = make_tensor(reshape_if_1d(inputs[1]));
auto y_desc = make_tensor(reshape_if_1d(output_shape));
auto status = miopenConvolutionForwardCompileSolution(ctx.get_stream().get_miopen(),
......@@ -347,21 +340,6 @@ struct miopen_convolution
{
return shapes.size() - 1;
}
inline shape pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
return s;
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
};
} // namespace gpu
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void MIGRAPHX_DEVICE_EXPORT int8_gemm_pack_a(hipStream_t stream,
const argument& result,
const argument& arg);
void MIGRAPHX_DEVICE_EXPORT int8_gemm_pack_b(hipStream_t stream,
const argument& result,
const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -34,10 +34,11 @@ struct module_pass_manager;
namespace gpu {
MIGRAPHX_GPU_EXPORT bool mlir_enabled();
MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir
{
context* ctx = nullptr;
context* ctx = nullptr;
bool enable_extra = false;
std::string name() const { return "gpu::fuse_mlir"; }
void apply(module_pass_manager& mpm) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch);
void blas_shape(const shape& s);
template <class Op>
struct rocblas_gemm
......@@ -50,9 +49,9 @@ struct rocblas_gemm
Op op;
float alpha = 1;
float beta = 0;
bool int8_x4_format = true;
bool compute_fp32 = false;
unsigned trans_batch = 0;
int32_t solution_idx = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -60,9 +59,9 @@ struct rocblas_gemm
return pack_join(migraphx::reflect(self.op, f),
pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.int8_x4_format, "int8_x4_format"),
f(self.compute_fp32, "compute_fp32"),
f(self.trans_batch, "trans_batch")));
f(self.trans_batch, "trans_batch"),
f(self.solution_idx, "solution_idx")));
}
std::string name() const
......@@ -78,6 +77,8 @@ struct rocblas_gemm
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
// When input shapes are A, B, C the GEMM equation is C  =  α AB+ β C where α, β are
// scalars
check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]);
blas_shape(inputs[1]);
......@@ -113,17 +114,12 @@ struct rocblas_gemm
{
if(this->name() == "gpu::gemm")
{
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
else
{
gemm(ctx,
output_shape,
args,
int32_t(alpha),
int32_t(beta),
int8_x4_format,
compute_fp32);
gemm_compute(
ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32, solution_idx);
}
return args.back();
}
......@@ -132,6 +128,33 @@ struct rocblas_gemm
{
return shapes.size() - 1;
}
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if(enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag())
{
if(this->name() == "gpu::gemm")
{
solution_idx = gemm_finalize(
ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
}
else
{
solution_idx = gemm_finalize(ctx,
output_shape,
input_shapes,
int32_t(alpha),
int32_t(beta),
compute_fp32,
solution_idx);
}
}
#else
// suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes;
#endif
}
};
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment