Commit 4ea39116 authored by Khalique Ahmed

manual merge

parents 20128cae d8011adf
@@ -70,6 +70,10 @@ void quantize_int8(program& prog,
         MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
     }
 
+    // Run optimize_module() before converting to int8 to const eval and fold in FP32 to
+    // avoid loss of precision.
+    run_passes(prog, {optimize_module{}});
+
     std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
         std::make_shared<std::vector<std::pair<float, float>>>();
     std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
@@ -143,11 +147,8 @@ void quantize_int8(program& prog,
     run_passes(prog,
                {quantize_int8_pass{ins_names, *int8_quant_params},
-                eliminate_common_subexpression{},
-                dead_code_elimination{},
-                simplify_reshapes{},
-                dead_code_elimination{},
                 simplify_qdq{},
+                optimize_module{},
                 dead_code_elimination{}});
 }
...
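Note: running optimize_module{} first means constant subgraphs are folded while the graph is still FP32, so the (scale, shift) pairs collected in int8_quant_params are derived from the folded values. For context, a minimal sketch of how a symmetric int8 pair can be derived from a calibrated max-abs value; this is an illustration only, not MIGraphX's exact calibration code, and the helper names are made up:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <utility>

    // Hypothetical helper: map a calibrated max-abs value to a (scale, shift) pair
    // and quantize one value with it.
    std::pair<float, float> int8_params_from_max_abs(float max_abs)
    {
        float scale = (max_abs == 0.0f) ? 1.0f : 127.0f / max_abs;
        return {scale, 0.0f}; // symmetric quantization: zero shift
    }

    int8_t quantize_value(float x, const std::pair<float, float>& params)
    {
        float q = std::nearbyint(x * params.first + params.second);
        return static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
    }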
@@ -33,6 +33,8 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
+
 void apply_quantizelinear(module& m, instruction_ref ins)
 {
     assert(ins->name() == "quantizelinear");
@@ -45,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
             ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
     }
     auto div            = m.insert_instruction(ins, make_op("div"), x, y_scale);
-    auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
+    auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
     if(ins->inputs().size() == 3)
     {
@@ -62,9 +64,22 @@ void apply_quantizelinear(module& m, instruction_ref ins)
         max_quant = qt.max();
         min_quant = qt.min();
     });
     auto s = add_zero_point->get_shape();
-    auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
-    auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
+    instruction_ref min_arg;
+    instruction_ref max_arg;
+    if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
+    {
+        std::vector<int> min_data(s.elements(), min_quant);
+        std::vector<int> max_data(s.elements(), max_quant);
+        min_arg = m.add_literal(literal(s, min_data));
+        max_arg = m.add_literal(literal(s, max_data));
+    }
+    else
+    {
+        min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
+        max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
+    }
     auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
     m.replace_instruction(
         ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
...
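Note: the switch from "round" to "nearbyint" only changes how ties are broken. std::round rounds halfway cases away from zero, while std::nearbyint follows the current rounding mode (round-to-nearest-even by default), which matches the round-half-to-even convention QuantizeLinear expects. A small standalone check:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        // With the default FE_TONEAREST rounding mode (ties-to-even):
        std::printf("%.0f %.0f\n", std::round(2.5), std::nearbyint(2.5));   // prints: 3 2
        std::printf("%.0f %.0f\n", std::round(-0.5), std::nearbyint(-0.5)); // prints: -1 -0
        return 0;
    }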
@@ -24,6 +24,7 @@
 #include <migraphx/simplify_dyn_ops.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/literal.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -131,10 +132,53 @@ struct find_const_4in_slice
     }
 };
 
+/**
+ * Simplify dimensions_of to a literal when the input argument has a static shape
+ * or the dynamic dimensions from `start` to `end` are fixed.
+ */
+struct find_static_dimensions_of
+{
+    auto matcher() const { return match::name("dimensions_of")(); }
+
+    void apply(module& m, const match::matcher_result& mr) const
+    {
+        auto ins                 = mr.result;
+        auto input               = ins->inputs().at(0);
+        auto dimensions_of_value = ins->get_operator().to_value();
+        auto start               = dimensions_of_value.at("start").to<std::size_t>();
+        auto end                 = dimensions_of_value.at("end").to<std::size_t>();
+        if(input->get_shape().dynamic())
+        {
+            // check if dynamic dimensions from start to end are fixed
+            auto dds = input->get_shape().dyn_dims();
+            if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
+                   return not dd.is_fixed();
+               }))
+            {
+                return;
+            }
+        }
+        std::size_t output_ndim = end - start;
+        std::vector<int64_t> vec_shape(output_ndim);
+        migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
+        std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
+        std::transform(input_lens.begin() + start,
+                       input_lens.begin() + end,
+                       vec_shape.begin(),
+                       [](auto i) { return int64_t(i); });
+        migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
+        auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
+        m.replace_instruction(ins, lit_ins);
+    }
+};
+
 void simplify_dyn_ops::apply(module& m) const
 {
-    match::find_matches(
-        m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
+    match::find_matches(m,
+                        find_static_2in_broadcasts{},
+                        find_static_dimensions_of{},
+                        find_const_3in_slice{},
+                        find_const_4in_slice{});
 }
 
 } // namespace MIGRAPHX_INLINE_NS
...
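Note: find_static_dimensions_of constant-folds dimensions_of whenever the queried dimensions are known at compile time. A standalone sketch of the same computation, stripped of the matcher machinery:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Given static input lengths, dimensions_of(start, end) is just a slice of them,
    // which the pass above materializes as an int64 literal.
    std::vector<int64_t> fold_dimensions_of(const std::vector<std::size_t>& lens,
                                            std::size_t start,
                                            std::size_t end)
    {
        std::vector<int64_t> out(lens.begin() + start, lens.begin() + end);
        return out;
    }

    // fold_dimensions_of({2, 3, 4, 5}, 1, 3) yields {3, 4}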
@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
     }
 };
 
+template <class F>
+struct execute_wrapper
+{
+    F f;
+    argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
+};
+
+template <class F>
+execute_wrapper<F> make_execute_wrapper(F f)
+{
+    return {std::move(f)};
+}
+
 template <class Derived, class Primitive>
 struct dnnl_op : auto_register_op<Derived>
 {
@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
 #ifndef NDEBUG
         auto prim_attr = get_primitive_attr(md);
 #endif
-        execute = [=](context&, const std::vector<argument>& args) {
+        execute = make_execute_wrapper([=](const std::vector<argument>& args) {
 #ifndef NDEBUG
             // Check that the memory descriptors have not changed
             auto debug_args = args;
@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
                 m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
             prim.execute(get_dnnl_context().stream, m);
             return args.back();
-        };
+        });
     }
     std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
     {
...
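Note: the dnnl change stores a small named function object in `execute` instead of a raw lambda; the wrapper adapts an args-only callable to the (context&, args) signature that `execute` expects. A self-contained sketch of the same adapter pattern with stand-in types (the real context and argument types live in MIGraphX):

    #include <functional>
    #include <utility>
    #include <vector>

    struct context {};  // stand-in for the MIGraphX cpu context
    struct argument {}; // stand-in for migraphx::argument

    template <class F>
    struct execute_wrapper
    {
        F f;
        argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
    };

    template <class F>
    execute_wrapper<F> make_execute_wrapper(F f)
    {
        return {std::move(f)};
    }

    int main()
    {
        std::function<argument(context&, const std::vector<argument>&)> execute =
            make_execute_wrapper([](const std::vector<argument>& args) { return args.back(); });
        context ctx;
        std::vector<argument> args(2);
        execute(ctx, args); // the wrapper drops the context and forwards args to the lambda
    }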
@@ -24,7 +24,7 @@
 #ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
 #define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
 
-#include <migraphx/config.hpp>
+#include <migraphx/cpu/context.hpp>
 #include <string>
 
 namespace migraphx {
@@ -34,9 +34,7 @@ struct module;
 namespace cpu {
 
-struct context;
-
-struct fuse_ops
+struct MIGRAPHX_CPU_EXPORT fuse_ops
 {
     context* ctx = nullptr;
     std::string name() const { return "cpu::fuse_ops"; }
...
@@ -24,6 +24,7 @@
 #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
 #define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
 
+#include <array>
 #include <migraphx/config.hpp>
 #include <migraphx/context.hpp>
 #include <migraphx/check_shapes.hpp>
...
 # ####################################################################################
 # The MIT License (MIT)
 #
-# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -37,8 +37,7 @@ if(NOT TARGET MIOpen)
     message(SEND_ERROR "Cant find miopen")
 endif()
 
-if(NOT WIN32)
-    # TODO: re-enable when CK is ported to Windows
+if(MIGRAPHX_USE_COMPOSABLEKERNEL)
     find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
 endif()
@@ -48,10 +47,18 @@ else()
     set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
 endif()
 
-include(Embed)
 file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
     ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
 message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
+if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
+    list(REMOVE_ITEM KERNEL_FILES
+        ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
+endif()
+include(Embed)
 add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
 configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
@@ -95,9 +102,10 @@ rocm_clang_tidy_check(kernel_file_check)
 file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
-if(WIN32)
-    # TODO: re-enable when CK is ported to Windows
-    list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
+if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
+    list(REMOVE_ITEM JIT_GPU_SRCS
+        ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
 endif()
 
 add_library(migraphx_gpu
@@ -120,8 +128,6 @@ add_library(migraphx_gpu
     gather.cpp
     gemm_impl.cpp
     hip.cpp
-    int8_conv_pack.cpp
-    int8_gemm_pack.cpp
    kernel.cpp
     lowering.cpp
     logsoftmax.cpp
@@ -132,7 +138,6 @@ add_library(migraphx_gpu
     no_device.cpp
     nonzero.cpp
     pack_args.cpp
-    pack_int8_args.cpp
     prefuse_ops.cpp
     pad.cpp
     perfdb.cpp
@@ -176,7 +181,6 @@ register_migraphx_gpu_ops(hip_
 register_migraphx_gpu_ops(miopen_
     abs
     contiguous
-    int8_conv_pack
     lrn
     pooling
 )
@@ -184,10 +188,6 @@ register_op(migraphx_gpu
     HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
     OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
     INCLUDES migraphx/gpu/context.hpp)
-register_op(migraphx_gpu
-    HEADER migraphx/gpu/int8_gemm_pack.hpp
-    OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
-    INCLUDES migraphx/gpu/context.hpp)
 register_op(migraphx_gpu
     HEADER migraphx/gpu/gemm.hpp
     OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
@@ -231,24 +231,28 @@ else()
         string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
     endforeach()
 
-    message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
+    message(STATUS "Hip compiler flags: \"${HIP_COMPILER_FLAGS}\"")
     target_compile_definitions(migraphx_gpu PRIVATE
-        "-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
-        "-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
+        -DMIGRAPHX_HIP_COMPILER="${CMAKE_CXX_COMPILER}"
+        -DMIGRAPHX_HIP_COMPILER_FLAGS="${HIP_COMPILER_FLAGS}"
     )
 
     if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
         execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
         string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
-        target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
+        target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_HIP_COMPILER_LAUNCHER="${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
     endif()
 endif()
 
 # Check miopen find mode api
 include(CheckLibraryExists)
 get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
+get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
 check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
 check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
+# Beta API for automated GEMM tuning
+check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
 
 set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
@@ -271,10 +275,16 @@ else()
     message(STATUS "MIOpen does not have find mode api")
 endif()
 
+if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
+    target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
+    message(STATUS "MIGraphx is using Beta API of rocBLAS")
+else()
+    message(STATUS "rocBLAS does not have User Tuning Beta API")
+endif()
+
 target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
 target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
-if(NOT WIN32)
-    # TODO: re-enable when CK is ported to Windows
+if(MIGRAPHX_USE_COMPOSABLEKERNEL)
     target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
 endif()
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
 {
     auto n_dim         = args.front().get_shape().lens().size();
     int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
-    device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
+    device::argmax(
+        ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
     return args.back();
 }
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
 {
     auto n_dim         = args.front().get_shape().lens().size();
     int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
-    device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
+    device::argmin(
+        ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
    return args.back();
 }
...
@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
         {
             if(src.path.extension() != ".cpp")
                 continue;
-            std::cout << std::string(src.content.first, src.len()) << std::endl;
+            std::cout << std::string(src.content) << std::endl;
         }
     }
     auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
@@ -284,16 +284,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
 bool is_hip_clang_compiler()
 {
-    static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
+    static const auto result = fs::path{MIGRAPHX_HIP_COMPILER}.stem() == "clang++";
     return result;
 }
 
+#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
 bool has_compiler_launcher()
 {
-    static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
+    static const auto result = fs::exists(MIGRAPHX_HIP_COMPILER_LAUNCHER);
     return result;
 }
+#endif
 
 src_compiler assemble(src_compiler compiler)
 {
     compiler.out_ext = ".S";
@@ -306,8 +310,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
 {
     assert(not srcs.empty());
     if(not is_hip_clang_compiler())
-        MIGRAPHX_THROW("Unknown hip compiler: " +
-                       std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
+        MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);
 
     if(params.find("-std=") == std::string::npos)
         params += " --std=c++17";
@@ -323,14 +326,14 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
         params += " -DMIGRAPHX_DEBUG";
 
     params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
-    params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS);
+    params += MIGRAPHX_HIP_COMPILER_FLAGS;
 
     src_compiler compiler;
     compiler.flags    = params;
-    compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
+    compiler.compiler = MIGRAPHX_HIP_COMPILER;
 #ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
     if(has_compiler_launcher())
-        compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
+        compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
 #endif
 
     if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
     {
@@ -338,7 +341,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
         {
             if(src.path.extension() != ".cpp")
                 continue;
-            std::cout << std::string(src.content.first, src.len()) << std::endl;
+            std::cout << std::string(src.content) << std::endl;
         }
     }
@@ -354,14 +357,12 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
 bool hip_has_flags(const std::vector<std::string>& flags)
 {
     src_compiler compiler;
-    compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
+    compiler.compiler = MIGRAPHX_HIP_COMPILER;
     compiler.flags =
         join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
 
     std::string src;
-    src_file input;
-    input.path    = "main.cpp";
-    input.content = std::make_pair(src.data(), src.data() + src.size());
+    src_file input{"main.cpp", src};
 
     try
     {
...
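Note: the MIGRAPHX_STRINGIZE calls disappear because the CMake changes above now pass the values as already-quoted strings (e.g. -DMIGRAPHX_HIP_COMPILER="${CMAKE_CXX_COMPILER}"), so each macro expands directly to a string literal. A small illustration of the two styles; the macro names here are made up for the example:

    #include <cstdio>

    // Old style: the macro holds raw, unquoted tokens and has to be stringized.
    #define OLD_COMPILER /usr/bin/clang++
    #define STRINGIZE_IMPL(x) #x
    #define STRINGIZE(x) STRINGIZE_IMPL(x)

    // New style: the build system already passes a quoted value, emulated here.
    #define NEW_COMPILER "/usr/bin/clang++"

    int main()
    {
        std::printf("%s\n", STRINGIZE(OLD_COMPILER)); // /usr/bin/clang++
        std::printf("%s\n", NEW_COMPILER);            // used directly, no stringize needed
    }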
@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params(
     global = compute_global(local);
 }
 
+static bool hip_accept_non_uniform_wg()
+{
+    static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
+    return non_uniform_wg;
+}
+
 std::function<std::size_t(std::size_t local)>
 compute_global_for(context& ctx, std::size_t n, std::size_t over)
 {
@@ -146,13 +152,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
     std::size_t max_global = ctx.get_current_device().get_cu_count() *
                              ctx.get_current_device().get_max_workitems_per_cu();
     return [n, over, max_global](std::size_t local) {
-        // hip require global workitems multiple of local workitems. It may degrade performance.
-        // [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
-        // https://reviews.llvm.org/D155213
-        std::size_t num_elements = ((n + local - 1) / local) * local;
-        std::size_t groups       = (num_elements + local - 1) / local;
-        std::size_t max_blocks   = max_global / local;
-        std::size_t nglobal      = std::min(max_blocks * over, groups) * local;
+        std::size_t num_elements = n;
+        if(not hip_accept_non_uniform_wg())
+        {
+            num_elements = (1 + (n - 1) / local) * local;
+        }
+        std::size_t groups     = 1 + (num_elements - 1) / local;
+        std::size_t max_blocks = max_global / local;
+        std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
         return std::min(nglobal, num_elements);
     };
 }
@@ -172,21 +179,22 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
     assert(options.inputs.size() == options.virtual_inputs.size() or
            options.virtual_inputs.empty());
     std::vector<src_file> srcs = options.additional_src_files;
-    std::transform(migraphx_kernels().begin(),
-                   migraphx_kernels().end(),
-                   std::back_inserter(srcs),
-                   [](auto&& p) {
-                       auto&& name = p.first;
-                       auto&& c    = p.second;
-                       auto path   = name;
-                       return src_file{path, c};
-                   });
-    srcs.push_back(src_file{fs::path{"main.cpp"},
-                            std::make_pair(content.data(), content.data() + content.size())});
+    static auto kernels{::migraphx_kernels()};
+    std::transform(
+        kernels.begin(),
+        kernels.end(),
+        std::back_inserter(srcs),
+        [](const std::pair<std::string_view, std::string_view>& elem) { return src_file{elem}; });
+    srcs.emplace_back("main.cpp", content);
 
     auto args_hpp =
         generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
-    srcs.push_back(src_file{fs::path{"args.hpp"},
-                            std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
+    srcs.emplace_back("args.hpp", args_hpp);
+
+    if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
+        options.params += " -fno-offload-uniform-block";
+    else
+        assert(options.global % options.local == 0);
 
     options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
     options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
     options.params += " " + join_strings(compiler_warnings(), " ");
...
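Note: once the hip_accept_non_uniform_wg() probe succeeds, the global work size no longer has to be padded up to a multiple of the block size, and -fno-offload-uniform-block is appended when global % local != 0. A worked sketch of the launch-size arithmetic with made-up numbers:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    std::size_t pick_nglobal(std::size_t n, std::size_t local, std::size_t max_global,
                             std::size_t over, bool non_uniform_wg)
    {
        std::size_t num_elements = n;
        if(not non_uniform_wg)
            num_elements = (1 + (n - 1) / local) * local; // pad to a multiple of local
        std::size_t groups     = 1 + (num_elements - 1) / local;
        std::size_t max_blocks = max_global / local;
        std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
        return std::min(nglobal, num_elements);
    }

    int main()
    {
        // n = 1000 elements, block size 256: the padded launch uses 1024 work-items,
        // the non-uniform variant launches exactly 1000.
        std::printf("%zu\n", pick_nglobal(1000, 256, 1u << 20, 1, false)); // 1024
        std::printf("%zu\n", pick_nglobal(1000, 256, 1u << 20, 1, true));  // 1000
    }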
@@ -60,9 +60,8 @@ struct miopen_op
 };
 MIGRAPHX_REGISTER_OP(miopen_op);
 
-std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
+std::size_t compile_miopen::compile(operation& op, instruction_ref ins) const
 {
-    op.from_value({{"int8_x4_format", format}});
     auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
     return v.get<std::size_t>("workspace", 0);
 }
@@ -70,25 +69,15 @@ std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool for
 void compile_miopen::apply(module& m) const
 {
     assert(ctx);
-    const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
     for(auto ins : iterator_for(m))
     {
         if(ins->name() != "gpu::miopen_op")
             continue;
         auto op        = any_cast<miopen_op>(ins->get_operator()).op;
         std::size_t ws = 0;
-        try
-        {
-            // for the regular convolution and convolution_backwards, this try would always succeed
-            ws = compile(op, ins, int8_x4_format);
-        }
-        catch(migraphx::exception&)
-        {
-            // In case no solver supports the default format, retry using the other format.
-            ws = compile(op, ins, not int8_x4_format);
-        }
-        auto inputs = ins->inputs();
-        auto alloc  = m.insert_instruction(
+        ws          = compile(op, ins);
+        auto inputs = ins->inputs();
+        auto alloc  = m.insert_instruction(
             ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
         inputs.insert(std::prev(inputs.end()), alloc);
...
@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING);
 
 struct precompile_op
 {
@@ -167,6 +168,7 @@ struct compile_plan
     }
     const compiled_result& benchmark(problem_cache& pc) const
     {
+        const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
         if(results.empty())
             MIGRAPHX_THROW("No configs to tune");
         if(results.size() == 1)
@@ -177,19 +179,35 @@ struct compile_plan
         }
         if(not config)
             MIGRAPHX_THROW("Multiple kernels without config");
-        std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
-                  << std::endl;
+        if(trace_level > 0)
+            std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
+                      << std::endl;
+        if(trace_level > 1)
+            std::cout << "Problem: " << config->problem << std::endl;
         std::vector<double> times;
         times.reserve(results.size());
-        std::transform(
-            results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
-                if(not cr.has_value())
-                    return std::numeric_limits<double>::max();
-                return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
-                    .first;
-            });
+        std::transform(results.begin(),
+                       results.end(),
+                       config->solutions.begin(),
+                       std::back_inserter(times),
+                       [&](const auto& cr, const auto& solution) {
+                           if(trace_level > 1)
+                               std::cout << "Benchmarking solution: " << solution << std::endl;
+                           if(not cr.has_value())
+                           {
+                               if(trace_level > 1)
+                                   std::cout << "No binary" << std::endl;
+                               return std::numeric_limits<double>::max();
+                           }
+                           auto t = time_op(
+                               *ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
+                           if(trace_level > 1)
+                               std::cout << t << "ms" << std::endl;
+                           return t;
+                       });
         auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
-        std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
+        if(trace_level > 0)
+            std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
         pc.insert(preop.name(), config->problem, config->solutions.at(i));
         if(not results[i].has_value())
             MIGRAPHX_THROW("No valid tuned compilation.");
...
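Note: MIGRAPHX_TRACE_BENCHMARKING gates the tuning output; level 1 prints the per-op summary and the fastest solution, level 2 additionally prints the problem key and each solution's time. A minimal stand-in for the pattern using raw getenv (the real code goes through MIGraphX's value_of env-var helper, and the printed strings here are only placeholders):

    #include <cstdlib>
    #include <iostream>

    static int trace_level_from_env(const char* name)
    {
        const char* v = std::getenv(name);
        return v == nullptr ? 0 : std::atoi(v);
    }

    int main()
    {
        int trace_level = trace_level_from_env("MIGRAPHX_TRACE_BENCHMARKING");
        if(trace_level > 0)
            std::cout << "Benchmarking convolution: 12 configs\n";
        if(trace_level > 1)
            std::cout << "Benchmarking solution: 3\n";
    }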
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
+void argmax(hipStream_t stream,
+            const argument& result,
+            const argument& arg,
+            int64_t axis,
+            bool select_last_index)
 {
-    arg_op(argmax_op{}, stream, result, arg, axis);
+    if(select_last_index)
+        arg_op(argmax_op_last_index{}, stream, result, arg, axis);
+    else
+        arg_op(argmax_op_first_index{}, stream, result, arg, axis);
 }
 
 } // namespace device
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
+void argmin(hipStream_t stream,
+            const argument& result,
+            const argument& arg,
+            int64_t axis,
+            bool select_last_index)
 {
-    arg_op(argmin_op{}, stream, result, arg, axis);
+    if(select_last_index)
+        arg_op(argmin_op_last_index{}, stream, result, arg, axis);
+    else
+        arg_op(argmin_op_first_index{}, stream, result, arg, axis);
 }
 
 } // namespace device
...
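Note: select_last_index only changes how ties are resolved; the first-index variant keeps the earliest occurrence of the extreme value, the last-index variant keeps the latest. A host-side illustration of the two tie-breaking rules (the real argmax_op_first_index/argmax_op_last_index functors run inside the device reduction):

    #include <cstdint>
    #include <vector>

    int64_t argmax_first(const std::vector<float>& v)
    {
        int64_t best = 0;
        for(int64_t i = 1; i < static_cast<int64_t>(v.size()); i++)
            if(v[i] > v[best]) // strict '>' keeps the first occurrence on ties
                best = i;
        return best;
    }

    int64_t argmax_last(const std::vector<float>& v)
    {
        int64_t best = 0;
        for(int64_t i = 1; i < static_cast<int64_t>(v.size()); i++)
            if(v[i] >= v[best]) // '>=' moves to the last occurrence on ties
                best = i;
        return best;
    }

    // argmax_first({1, 3, 3, 2}) == 1, argmax_last({1, 3, 3, 2}) == 2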
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/tensor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg)
{
auto comp_shape = arg.get_shape();
auto out_lens = comp_shape.lens();
auto dim_0 = out_lens.size() - 2;
auto dim_1 = out_lens.size() - 1;
std::size_t lda = comp_shape.strides()[dim_0];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = comp_shape.elements();
auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements, 256)([=](auto ii) __device__ {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] =
in_ptr[i_m + i_k * lda + offset];
});
});
});
}
void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
{
auto trans_shape = arg.get_shape();
auto out_lens = trans_shape.lens();
auto dim_0 = trans_shape.lens().size() - 2;
auto dim_1 = trans_shape.lens().size() - 1;
std::size_t ldb = trans_shape.strides()[dim_1];
auto wrap_lens = out_lens;
std::swap(wrap_lens[dim_0], wrap_lens[dim_1]);
shape comp_shape{trans_shape.type(), wrap_lens};
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = comp_shape.elements();
auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements, 256)([=](auto ii) __device__ {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_n = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
});
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -24,7 +24,7 @@
 #ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
 #define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
 
-#include <migraphx/config.hpp>
+#include <migraphx/gpu/device/config.hpp>
 #include <string>
 #include <vector>
@@ -34,9 +34,13 @@ namespace gpu {
 namespace device {
 
 #define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
 
+MIGRAPHX_DEVICE_EXPORT
 const std::vector<std::string>& get_targets();
 
+MIGRAPHX_DEVICE_EXPORT
 std::string get_targets_as_string();
 
+MIGRAPHX_DEVICE_EXPORT
 std::string get_device_name();
 
 } // namespace device
...
@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
         context ctx;
         auto inputs = p.parse_shapes(v.at("inputs"));
         auto op     = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
-        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << host_time << "ms";
-        if(device_time > 0)
-            std::cout << ", " << device_time << "ms";
+        auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << t << "ms";
         std::cout << std::endl;
     }
 };
...
@@ -43,8 +43,8 @@ struct run_op : action<run_op>
         auto op = make_op(name);
         if(v.contains("fields"))
             op.from_value(v.at("fields"));
-        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << host_time << "ms" << std::endl;
+        auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << t << "ms" << std::endl;
     }
 };
...
@@ -22,10 +22,11 @@
  * THE SOFTWARE.
  */
 #include <migraphx/gpu/fuse_ck.hpp>
+#include <migraphx/gpu/gemm_softmax_gemm.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/pass_manager.hpp>
-#include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
+#include <migraphx/gpu/device_name.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -55,7 +56,7 @@ struct ck_gemm
     {
         check_shapes{inputs, *this}.same_ndims();
         if(inputs.size() < 2)
-            MIGRAPHX_THROW("should have at least two inputs.");
+            MIGRAPHX_THROW(name() + ": should have at least two inputs.");
         auto a = inputs[0];
         auto b = inputs[1];
         for(const auto& input : inputs)
@@ -65,27 +66,35 @@ struct ck_gemm
             return r;
         return r.with_type(mods.front()->get_output_shapes().front().type());
     }
+
+    static bool is_ck_supported_type(shape::type_t t)
+    {
+        return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
+    }
 };
 MIGRAPHX_REGISTER_OP(ck_gemm);
 
-namespace {
-
-bool is_ck_supported_type(shape::type_t t)
-{
-    return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
-}
+struct ck_gemm_softmax_gemm : gemm_softmax_gemm
+{
+    std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
+};
+MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
+
+namespace {
 
 MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
 {
     if(ins->name() != "dot" and ins->name() != "quant_dot")
         return false;
-    if(not is_ck_supported_type(ins->get_shape().type()))
+    if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
         return false;
     auto a = ins->inputs().front()->get_shape();
     auto b = ins->inputs().back()->get_shape();
     auto m = a.lens()[a.lens().size() - 2];
     auto n = b.lens().back();
     auto k = a.lens().back();
+    auto batch_size = std::accumulate(
+        a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
     // Integer gemms must be divisible by 4 in ck
     if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
     {
@@ -96,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
         if(k % 4 != 0)
             return false;
     }
-    // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
-    // to avoid poor-performing GEMM kernels from CK
-    // To-do: Investigate a more precise strategy
+    auto device_name = trim(split_string(get_device_name(), ':').front());
+    if(device_name == "gfx940")
+    {
+        if(ins->get_shape().type() == shape::half_type)
+        {
+            if(batch_size >= 64)
+                return m < 2048 or k <= 64 or n <= 384 or n >= 2048;
+            return true;
+        }
+        return true;
+    }
     return k <= 2048;
 }
@@ -127,7 +144,15 @@ struct find_ck_gemm_pointwise
            ins->get_shape().type() != gemm_ins->get_shape().type())
             return;
         if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
-               return not is_ck_supported_type(input->get_shape().type());
+               return not ck_gemm::is_ck_supported_type(input->get_shape().type());
+           }))
+            return;
+        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
+               return not input->inputs().empty() and input->inputs().front()->name() == "capture";
+           }))
+            return;
+        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
+               return not input->inputs().empty() and input->inputs().front()->name() == "capture";
           }))
            return;
         assert(gemm_it != inputs.end());
@@ -152,7 +177,7 @@ struct find_ck_gemm_pointwise
 struct find_ck_gemm
 {
-    auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
+    auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
 
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
@@ -161,11 +186,26 @@ struct find_ck_gemm
     }
 };
 
+struct find_ck_gemm_softmax_gemm
+{
+    auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto v   = ins->get_operator().to_value();
+        assert(v.contains("scale"));
+        auto scale = v.at("scale").to<float>();
+        mpm.get_module().replace_instruction(
+            ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
+    }
+};
+
 } // namespace
 
 void fuse_ck::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(mpm, find_ck_gemm_pointwise{});
+    match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
     match::find_matches(mpm, find_ck_gemm{});
 }
...
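Note: find_ck_gemm_softmax_gemm rewrites gpu::pre_gemm_softmax_gemm into the CK-backed ck_gemm_softmax_gemm, carrying the scale value over. Assuming the op fuses the attention-style pattern softmax(scale * (A*B)) * C, as its name suggests, a naive single-batch reference looks like the sketch below; this is an illustration only, not the Composable Kernel implementation the op actually dispatches to.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Row-major A is MxK, B is KxN, C is NxP.
    std::vector<float> gemm_softmax_gemm_ref(const std::vector<float>& a,
                                             const std::vector<float>& b,
                                             const std::vector<float>& c,
                                             std::size_t m, std::size_t k, std::size_t n,
                                             std::size_t p, float scale)
    {
        std::vector<float> s(m * n, 0.0f);
        for(std::size_t i = 0; i < m; i++)
            for(std::size_t j = 0; j < n; j++)
            {
                for(std::size_t x = 0; x < k; x++)
                    s[i * n + j] += a[i * k + x] * b[x * n + j];
                s[i * n + j] *= scale;
            }
        for(std::size_t i = 0; i < m; i++) // numerically-stable row-wise softmax
        {
            float mx = s[i * n];
            for(std::size_t j = 1; j < n; j++)
                mx = std::max(mx, s[i * n + j]);
            float sum = 0.0f;
            for(std::size_t j = 0; j < n; j++)
            {
                s[i * n + j] = std::exp(s[i * n + j] - mx);
                sum += s[i * n + j];
            }
            for(std::size_t j = 0; j < n; j++)
                s[i * n + j] /= sum;
        }
        std::vector<float> out(m * p, 0.0f);
        for(std::size_t i = 0; i < m; i++)
            for(std::size_t j = 0; j < p; j++)
                for(std::size_t x = 0; x < n; x++)
                    out[i * p + j] += s[i * n + x] * c[x * p + j];
        return out;
    }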