Commit 4ea39116 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 20128cae d8011adf
......@@ -70,6 +70,10 @@ void quantize_int8(program& prog,
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// avoid loss of precision.
run_passes(prog, {optimize_module{}});
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
......@@ -143,11 +147,8 @@ void quantize_int8(program& prog,
run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
optimize_module{},
dead_code_elimination{}});
}
......
......@@ -33,6 +33,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
void apply_quantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "quantizelinear");
......@@ -45,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
if(ins->inputs().size() == 3)
{
......@@ -62,9 +64,22 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant = qt.max();
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto s = add_zero_point->get_shape();
instruction_ref min_arg;
instruction_ref max_arg;
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data));
}
else
{
min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
}
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
......
......@@ -24,6 +24,7 @@
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -131,10 +132,53 @@ struct find_const_4in_slice
}
};
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
match::find_matches(m,
find_static_2in_broadcasts{},
find_static_dimensions_of{},
find_const_3in_slice{},
find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
}
};
template <class F>
struct execute_wrapper
{
F f;
argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};
template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
return {std::move(f)};
}
template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived>
{
......@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
#ifndef NDEBUG
auto prim_attr = get_primitive_attr(md);
#endif
execute = [=](context&, const std::vector<argument>& args) {
execute = make_execute_wrapper([=](const std::vector<argument>& args) {
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto debug_args = args;
......@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
return args.back();
};
});
}
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/cpu/context.hpp>
#include <string>
namespace migraphx {
......@@ -34,9 +34,7 @@ struct module;
namespace cpu {
struct context;
struct fuse_ops
struct MIGRAPHX_CPU_EXPORT fuse_ops
{
context* ctx = nullptr;
std::string name() const { return "cpu::fuse_ops"; }
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp>
......
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -37,8 +37,7 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif()
......@@ -48,10 +47,18 @@ else()
set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif()
include(Embed)
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
......@@ -95,9 +102,10 @@ rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM JIT_GPU_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
endif()
add_library(migraphx_gpu
......@@ -120,8 +128,6 @@ add_library(migraphx_gpu
gather.cpp
gemm_impl.cpp
hip.cpp
int8_conv_pack.cpp
int8_gemm_pack.cpp
kernel.cpp
lowering.cpp
logsoftmax.cpp
......@@ -132,7 +138,6 @@ add_library(migraphx_gpu
no_device.cpp
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp
perfdb.cpp
......@@ -176,7 +181,6 @@ register_migraphx_gpu_ops(hip_
register_migraphx_gpu_ops(miopen_
abs
contiguous
int8_conv_pack
lrn
pooling
)
......@@ -184,10 +188,6 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
......@@ -231,24 +231,28 @@ else()
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
message(STATUS "Hip compiler flags: \"${HIP_COMPILER_FLAGS}\"")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
-DMIGRAPHX_HIP_COMPILER="${CMAKE_CXX_COMPILER}"
-DMIGRAPHX_HIP_COMPILER_FLAGS="${HIP_COMPILER_FLAGS}"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_HIP_COMPILER_LAUNCHER="${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif()
# Check miopen find mode api
include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -271,10 +275,16 @@ else()
message(STATUS "MIOpen does not have find mode api")
endif()
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
endif()
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{
auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
device::argmax(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back();
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{
auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
device::argmin(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back();
}
......
......@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
std::cout << std::string(src.content) << std::endl;
}
}
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
......@@ -284,16 +284,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool is_hip_clang_compiler()
{
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
static const auto result = fs::path{MIGRAPHX_HIP_COMPILER}.stem() == "clang++";
return result;
}
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
bool has_compiler_launcher()
{
static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
static const auto result = fs::exists(MIGRAPHX_HIP_COMPILER_LAUNCHER);
return result;
}
#endif
src_compiler assemble(src_compiler compiler)
{
compiler.out_ext = ".S";
......@@ -306,8 +310,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
assert(not srcs.empty());
if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);
if(params.find("-std=") == std::string::npos)
params += " --std=c++17";
......@@ -323,14 +326,14 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
params += " -DMIGRAPHX_DEBUG";
params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS);
params += MIGRAPHX_HIP_COMPILER_FLAGS;
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.compiler = MIGRAPHX_HIP_COMPILER;
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
#endif
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
......@@ -338,7 +341,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
std::cout << std::string(src.content) << std::endl;
}
}
......@@ -354,14 +357,12 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.compiler = MIGRAPHX_HIP_COMPILER;
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src;
src_file input;
input.path = "main.cpp";
input.content = std::make_pair(src.data(), src.data() + src.size());
src_file input{"main.cpp", src};
try
{
......
......@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params(
global = compute_global(local);
}
static bool hip_accept_non_uniform_wg()
{
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
return non_uniform_wg;
}
std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over)
{
......@@ -146,13 +152,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) {
// hip require global workitems multiple of local workitems. It may degrade performance.
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
// https://reviews.llvm.org/D155213
std::size_t num_elements = ((n + local - 1) / local) * local;
std::size_t groups = (num_elements + local - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
std::size_t num_elements = n;
if(not hip_accept_non_uniform_wg())
{
num_elements = (1 + (n - 1) / local) * local;
}
std::size_t groups = 1 + (num_elements - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return std::min(nglobal, num_elements);
};
}
......@@ -172,21 +179,22 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty());
std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(),
std::back_inserter(srcs),
[](auto&& p) {
auto&& name = p.first;
auto&& c = p.second;
auto path = name;
return src_file{path, c};
});
srcs.push_back(src_file{fs::path{"main.cpp"},
std::make_pair(content.data(), content.data() + content.size())});
static auto kernels{::migraphx_kernels()};
std::transform(
kernels.begin(),
kernels.end(),
std::back_inserter(srcs),
[](const std::pair<std::string_view, std::string_view>& elem) { return src_file{elem}; });
srcs.emplace_back("main.cpp", content);
auto args_hpp =
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.push_back(src_file{fs::path{"args.hpp"},
std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
srcs.emplace_back("args.hpp", args_hpp);
if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
options.params += " -fno-offload-uniform-block";
else
assert(options.global % options.local == 0);
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " ");
......
......@@ -60,9 +60,8 @@ struct miopen_op
};
MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
std::size_t compile_miopen::compile(operation& op, instruction_ref ins) const
{
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get<std::size_t>("workspace", 0);
}
......@@ -70,25 +69,15 @@ std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool for
void compile_miopen::apply(module& m) const
{
assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
std::size_t ws = 0;
try
{
// for the regular convolution and convolution_backwards, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ws = compile(op, ins);
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
inputs.insert(std::prev(inputs.end()), alloc);
......
......@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING);
struct precompile_op
{
......@@ -167,6 +168,7 @@ struct compile_plan
}
const compiled_result& benchmark(problem_cache& pc) const
{
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
......@@ -177,19 +179,35 @@ struct compile_plan
}
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(trace_level > 0)
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(trace_level > 1)
std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times;
times.reserve(results.size());
std::transform(
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
if(not cr.has_value())
return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
.first;
});
std::transform(results.begin(),
results.end(),
config->solutions.begin(),
std::back_inserter(times),
[&](const auto& cr, const auto& solution) {
if(trace_level > 1)
std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value())
{
if(trace_level > 1)
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max();
}
auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(trace_level > 1)
std::cout << t << "ms" << std::endl;
return t;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
if(trace_level > 0)
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation.");
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
void argmax(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{
arg_op(argmax_op{}, stream, result, arg, axis);
if(select_last_index)
arg_op(argmax_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmax_op_first_index{}, stream, result, arg, axis);
}
} // namespace device
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
void argmin(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{
arg_op(argmin_op{}, stream, result, arg, axis);
if(select_last_index)
arg_op(argmin_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmin_op_first_index{}, stream, result, arg, axis);
}
} // namespace device
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/tensor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg)
{
auto comp_shape = arg.get_shape();
auto out_lens = comp_shape.lens();
auto dim_0 = out_lens.size() - 2;
auto dim_1 = out_lens.size() - 1;
std::size_t lda = comp_shape.strides()[dim_0];
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = comp_shape.elements();
auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements, 256)([=](auto ii) __device__ {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_m = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_m + (i_k / nb) * lda) * nb + offset] =
in_ptr[i_m + i_k * lda + offset];
});
});
});
}
void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
{
auto trans_shape = arg.get_shape();
auto out_lens = trans_shape.lens();
auto dim_0 = trans_shape.lens().size() - 2;
auto dim_1 = trans_shape.lens().size() - 1;
std::size_t ldb = trans_shape.strides()[dim_1];
auto wrap_lens = out_lens;
std::swap(wrap_lens[dim_0], wrap_lens[dim_1]);
shape comp_shape{trans_shape.type(), wrap_lens};
std::size_t m_size = out_lens[dim_0] * out_lens[dim_1];
visit_all(result, arg)([&](auto output, auto input) {
std::size_t nelements = comp_shape.elements();
auto* out_ptr = device_cast(output.data());
auto* in_ptr = device_cast(input.data());
visit_tensor_size(out_lens.size(), [&](auto out_dim) {
hip_tensor_descriptor<out_dim> desc(comp_shape);
gs_launch(stream, nelements, 256)([=](auto ii) __device__ {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_n = idx[dim_1];
std::size_t i_k = idx[dim_0];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
});
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <string>
#include <vector>
......@@ -34,9 +34,13 @@ namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name();
} // namespace device
......
......@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms";
std::cout << std::endl;
}
};
......
......@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name);
if(v.contains("fields"))
op.from_value(v.at("fields"));
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms" << std::endl;
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl;
}
};
......
......@@ -22,10 +22,11 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/gpu/device_name.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,7 +56,7 @@ struct ck_gemm
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
......@@ -65,27 +66,35 @@ struct ck_gemm
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
namespace {
bool is_ck_supported_type(shape::type_t t)
struct ck_gemm_softmax_gemm : gemm_softmax_gemm
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not is_ck_supported_type(ins->get_shape().type()))
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back();
auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
// Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{
......@@ -96,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
if(k % 4 != 0)
return false;
}
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
auto device_name = trim(split_string(get_device_name(), ':').front());
if(device_name == "gfx940")
{
if(ins->get_shape().type() == shape::half_type)
{
if(batch_size >= 64)
return m < 2048 or k <= 64 or n <= 384 or n >= 2048;
return true;
}
return true;
}
return k <= 2048;
}
......@@ -127,7 +144,15 @@ struct find_ck_gemm_pointwise
ins->get_shape().type() != gemm_ins->get_shape().type())
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type());
return not ck_gemm::is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end());
......@@ -152,7 +177,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
......@@ -161,11 +186,26 @@ struct find_ck_gemm
}
};
struct find_ck_gemm_softmax_gemm
{
auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto v = ins->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment