Commit 8d32c6b8 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into blas_tuning

parents 23cb7917 f25606f9
......@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max};
}
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
......@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod);
}
// redirect to select_module operator and return
......
......@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
}
};
template <class F>
struct execute_wrapper
{
F f;
argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};
template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
return {std::move(f)};
}
template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived>
{
......@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
#ifndef NDEBUG
auto prim_attr = get_primitive_attr(md);
#endif
execute = [=](context&, const std::vector<argument>& args) {
execute = make_execute_wrapper([=](const std::vector<argument>& args) {
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto debug_args = args;
......@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
return args.back();
};
});
}
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/cpu/context.hpp>
#include <string>
namespace migraphx {
......@@ -34,9 +34,7 @@ struct module;
namespace cpu {
struct context;
struct fuse_ops
struct MIGRAPHX_CPU_EXPORT fuse_ops
{
context* ctx = nullptr;
std::string name() const { return "cpu::fuse_ops"; }
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......
......@@ -23,6 +23,10 @@
# ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
endif()
find_package(miopen)
# rocblas
......@@ -44,12 +48,21 @@ else()
set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif()
include(Embed)
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS})
......@@ -69,6 +82,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_BINAR_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
......@@ -123,6 +137,7 @@ add_library(migraphx_gpu
lrn.cpp
mlir.cpp
multinomial.cpp
no_device.cpp
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
......@@ -191,7 +206,7 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu)
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
set(MIGRAPHX_ENABLE_MLIR ON CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR)
# Find package rocMLIR
......
......@@ -28,6 +28,7 @@
#include <migraphx/env.hpp>
#include <cassert>
#include <iostream>
#include <deque>
#ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
......@@ -92,7 +93,7 @@ struct hiprtc_program
{
struct string_array
{
std::vector<std::string> strings{};
std::deque<std::string> strings{};
std::vector<const char*> c_strs{};
string_array() {}
......@@ -115,6 +116,12 @@ struct hiprtc_program
std::string cpp_src = "";
std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<hiprtc_src_file> srcs)
{
for(auto&& src : srcs)
......@@ -130,6 +137,14 @@ struct hiprtc_program
include_names.push_back(std::move(src.path));
}
}
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(),
headers.size(),
......@@ -137,7 +152,7 @@ struct hiprtc_program
include_names.data());
}
void compile(const std::vector<std::string>& options) const
void compile(const std::vector<std::string>& options, bool quiet = false) const
{
if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
......@@ -148,7 +163,7 @@ struct hiprtc_program
[](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log();
if(not prog_log.empty())
if(not prog_log.empty() and not quiet)
{
std::cerr << prog_log << std::endl;
}
......@@ -195,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}
if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
......@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return {prog.get_code_obj()};
}
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
......@@ -220,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
std::cout << std::string(src.content) << std::endl;
}
}
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
......@@ -310,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
std::cout << std::string(src.content) << std::endl;
}
}
......@@ -323,6 +351,27 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)};
}
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src;
src_file input{"main.cpp", src};
try
{
compiler.compile({input});
return true;
}
catch(...)
{
return false;
}
}
#endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param)
......
......@@ -91,28 +91,39 @@ __content__
return replace_string(args_hpp, "__content__", inner);
}
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = {"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions"};
static std::vector<std::string> warnings = get_compiler_warnings();
return warnings;
}
......@@ -161,21 +172,17 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty());
std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(),
std::back_inserter(srcs),
[](auto&& p) {
auto&& name = p.first;
auto&& c = p.second;
auto path = name;
return src_file{path, c};
});
srcs.push_back(src_file{fs::path{"main.cpp"},
std::make_pair(content.data(), content.data() + content.size())});
static auto kernels{::migraphx_kernels()};
std::transform(
kernels.begin(),
kernels.end(),
std::back_inserter(srcs),
[](const std::pair<std::string_view, std::string_view>& elem) { return src_file{elem}; });
srcs.emplace_back("main.cpp", content);
auto args_hpp =
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.push_back(src_file{fs::path{"args.hpp"},
std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
srcs.emplace_back("args.hpp", args_hpp);
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " ");
......
......@@ -185,8 +185,7 @@ struct compile_plan
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
if(not cr.has_value())
return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
.first;
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
......
......@@ -26,7 +26,9 @@
#include <hip/hip_runtime.h>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/targets.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -79,13 +81,28 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
using f_type = decltype(f);
dim3 nblocks(global / local);
dim3 nthreads(local);
/*
hipGetLastError() returns error for the first failed HIP call that happened previously.
MIGraphX calls into various backend libraries and failed HIP calls can also happen there.
Calling hipGetLastError() would reset error code to hipSuccess, so that inside MIGraphX
failed call to hipLaunchKernelGGL() can be captured.
*/
hipError_t flush_call = hipGetLastError();
(void)(flush_call);
// cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
hipError_t kernel_launch_status = hipGetLastError();
if(kernel_launch_status != hipSuccess)
{
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " +
std::string(hipGetErrorString(kernel_launch_status)));
std::string message = hipGetErrorString(kernel_launch_status);
if(not contains(get_targets(), get_device_name()))
{
message += ". Trying to run a kernel for " + get_device_name() +
" but MIGraphX was built for targets " + get_targets_as_string() +
". Please rebuild MIGraphX with -DGPU_TARGETS='" + get_device_name() +
"'.";
}
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " + message);
}
};
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/device/targets.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
static std::vector<std::string> parse_targets() { return split_string(MIGRAPHX_GPU_TARGETS, ';'); }
const std::vector<std::string>& get_targets()
{
static auto result = parse_targets();
return result;
}
std::string get_targets_as_string() { return join_strings(get_targets(), ", "); }
static int get_device_id()
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
MIGRAPHX_THROW("No device");
return device;
}
std::string get_device_name()
{
hipDeviceProp_t props{};
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
return props.gcnArchName;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/gpu/device/config.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name();
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
......@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms";
std::cout << std::endl;
}
};
......
......@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name);
if(v.contains("fields"))
op.from_value(v.at("fields"));
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms" << std::endl;
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl;
}
};
......
......@@ -22,9 +22,9 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
......@@ -55,7 +55,7 @@ struct ck_gemm
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
......@@ -65,21 +65,27 @@ struct ck_gemm
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
namespace {
bool is_ck_supported_type(shape::type_t t)
struct ck_gemm_softmax_gemm : gemm_softmax_gemm
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not is_ck_supported_type(ins->get_shape().type()))
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
......@@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise
ins->get_shape().type() != gemm_ins->get_shape().type())
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type());
return not ck_gemm::is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end());
......@@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
......@@ -161,11 +171,26 @@ struct find_ck_gemm
}
};
struct find_ck_gemm_softmax_gemm
{
auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto v = ins->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
......
......@@ -36,24 +36,14 @@ struct module;
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
return not mlir_disabled;
#else
return false;
#endif
......@@ -103,7 +93,10 @@ struct mlir_op
}
if(ins->name() == "@return")
{
return ins_shapes[ins->inputs().at(0)].with_type(type);
auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
}
std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size());
......@@ -119,28 +112,107 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op);
namespace {
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
enum class mlir_mode
{
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
return true;
all,
fast,
int8,
none
};
auto is_mlir_dot(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
});
}
struct find_mlir_op
auto is_mlir_conv(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
return true;
if(w.lens()[2] != w.lens()[3])
return true;
return (w.lens()[3] % 3) != 0;
});
}
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
.bind("gemm_based_op"));
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
......@@ -163,34 +235,6 @@ struct find_mlir_op
return ins_map;
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
......@@ -236,8 +280,7 @@ struct find_mlir_op
"log",
"recip",
"rsqrt",
// There are bugs in MLIR right now for models using sigmoid so disable it for now
// "sigmoid",
"sigmoid",
"softmax",
"tanh",
};
......@@ -282,9 +325,9 @@ struct find_mlir_op
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&, &anchor_op = anchor_op](auto name, auto input) {
[&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor_op);
return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});
......@@ -301,20 +344,90 @@ struct find_mlir_op
}
};
template <auto Matcher>
struct find_mlir_standalone_op
{
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
}
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_requested(std::string_view option, bool fallback = false)
{
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
} // namespace
#endif
#endif // MIGRAPHX_MLIR
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_op{});
const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
const bool is_navi = starts_with(device_name, "gfx110");
auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
if(is_requested(option))
return mlir_mode::all;
if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};
mlir_mode mode =
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else
(void)mpm;
#endif
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -790,22 +790,26 @@ struct find_layernorm_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(match::arg(0)(
return precompile_name("pointwise")(match::any_of[match::inputs()](
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto pw_ins = r.result;
auto layernorm = r.instructions["layernorm"];
if(not layernorm->module_inputs().empty())
return;
auto* pm = ins->module_inputs().front();
auto* pm = pw_ins->module_inputs().front();
auto pw_inputs = pw_ins->inputs();
auto ln_pos = std::find(pw_inputs.begin(), pw_inputs.end(), layernorm);
assert(ln_pos != pw_inputs.end());
pw_inputs.erase(ln_pos);
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
inputs.insert(inputs.end(), pw_inputs.begin(), pw_inputs.end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
m.replace_instruction(pw_ins, layernorm->get_operator(), inputs, {pm});
}
};
......
......@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr)
auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess)
return false;
return attr.memoryType == hipMemoryTypeDevice;
return attr.type == hipMemoryTypeDevice;
}
std::size_t get_available_gpu_memory()
......
......@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp>
#include <array>
#include <iostream>
#include <cstring>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(header_strings.begin(),
header_strings.end(),
std::back_inserter(srcs),
[&](auto& p) { return src_file{p}; });
return srcs;
}
static inline const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
inline bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
inline ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}
inline std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
inline void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}
inline void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}
inline bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
inline bool can_fold_batch(const std::vector<shape>& inputs)
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment