Commit f69d828d authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents fe36d210 24148857
......@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
if(ins->inputs().size() == 3)
{
......
......@@ -941,15 +941,6 @@ struct find_splits
{
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
// Insert contiguous for reshapes
auto outputs = i->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}
m.replace_instruction(i, split->get_operator(), c);
}
......@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion
for(auto arg : range(start, last))
{
auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
......@@ -1487,11 +1471,6 @@ struct find_split_reshape
slc_axis_len;
});
// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......
......@@ -24,6 +24,7 @@
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -32,6 +33,10 @@ inline namespace MIGRAPHX_INLINE_NS {
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
* From:
* broadcast_op(argument_with_static_shape, argument_with_static_shape)
* To:
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
*/
struct find_static_2in_broadcasts
{
......@@ -131,10 +136,85 @@ struct find_const_4in_slice
}
};
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
/**
* Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
* argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
* From:
* x = allocate(constant_output_dims) -> reshape(data, x)
* To:
* reshape(data); reshape.dims = constant_output_dims
*/
struct find_const_alloc_reshapes
{
auto matcher() const
{
return match::name("reshape")(match::nargs(2),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto reshape_ins = mr.result;
auto reshape_inputs = reshape_ins->inputs();
auto alloc_ins = reshape_inputs.at(1);
argument output_dims_arg = alloc_ins->inputs().at(0)->eval(false);
std::vector<int64_t> output_dims_vec;
output_dims_arg.visit(
[&](auto output) { output_dims_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
reshape_ins, make_op("reshape", {{"dims", output_dims_vec}}), reshape_inputs.at(0));
// have dead_code_elimination remove the previous allocate
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
match::find_matches(m,
find_static_dimensions_of{},
find_const_alloc_reshapes{},
find_static_2in_broadcasts{},
find_const_3in_slice{},
find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s;
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
struct match_find_quantizable_ops
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
static bool
is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
}
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
});
});
return all_same;
}
return all_zeros;
}
struct match_find_quantizable_ops
{
static auto
scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
while(next_ins != qop)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
match::arg(0)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
return;
// Only support scalar and 1D scales
if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref out_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
m.replace_instruction(qop, dq);
}
};
......
......@@ -103,8 +103,6 @@ struct find_reshaper
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
......@@ -475,9 +473,8 @@ struct find_resize
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp);
}
};
......@@ -626,9 +623,8 @@ struct find_transpose_contiguous_reshaper_unary
auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name();
auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
}
};
......@@ -647,8 +643,8 @@ struct find_broadcast_transpose
{
auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// scalar transformation does not need extra transpose
if(not input->get_shape().scalar())
{
......
......@@ -74,21 +74,27 @@ if(MIGRAPHX_ENABLE_ZENDNN)
target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB})
target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB})
else()
target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl)
target_link_libraries(migraphx_cpu PUBLIC DNNL::dnnl)
endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx)
migraphx_generate_export_header(migraphx_cpu)
find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
# Add library path to rpath to workaround issues with our broken packages
foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
if(LIBRARY MATCHES "libomp")
get_filename_component(LIBRARY_PATH "${LIBRARY}" PATH)
target_link_libraries(migraphx_cpu PUBLIC -Wl,-rpath=${LIBRARY_PATH} -Wl,-rpath-link=${LIBRARY_PATH})
endif()
endforeach()
if(WIN32)
target_link_libraries(migraphx_cpu PUBLIC libomp)
target_include_directories(migraphx_cpu PUBLIC ${OpenMP_CXX_INCLUDE_DIRS})
target_compile_options(migraphx_cpu PUBLIC ${OpenMP_CXX_FLAGS})
else()
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
# Add library path to rpath to workaround issues with our broken packages
foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
if(LIBRARY MATCHES "libomp")
get_filename_component(LIBRARY_PATH "${LIBRARY}" PATH)
target_link_libraries(migraphx_cpu PUBLIC -Wl,-rpath=${LIBRARY_PATH} -Wl,-rpath-link=${LIBRARY_PATH})
endif()
endforeach()
endif()
rocm_install_targets(
TARGETS migraphx_cpu
......
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -22,20 +22,20 @@
# THE SOFTWARE.
# ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
find_package(hip REQUIRED)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
set(fatal_msg "HIP package is broken and has no GPU_TARGETS. Please pass GPU_TARGETS to cmake.")
if(NOT WIN32)
set(fatal_msg "${fatal_msg}\nUse -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to build for your GPU.")
endif()
message(FATAL_ERROR ${fatal_msg})
endif()
find_package(miopen)
find_package(miopen REQUIRED)
message(STATUS "MIGraphX is using MIOpen")
# rocblas
find_package(rocblas REQUIRED PATHS /opt/rocm)
message(STATUS "Build with rocblas")
if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
find_package(rocblas REQUIRED)
message(STATUS "MIGraphX build with rocBLAS")
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
......@@ -49,7 +49,6 @@ endif()
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM KERNEL_FILES
......@@ -66,8 +65,10 @@ file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*
add_library(migraphx_device ${DEVICE_GPU_SRCS})
add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
target_compile_features(compile_for_gpu INTERFACE cxx_std_17)
target_compile_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_options(compile_for_gpu INTERFACE -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
target_link_libraries(compile_for_gpu INTERFACE hip::device)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE)
......@@ -211,8 +212,10 @@ if(MIGRAPHX_ENABLE_MLIR)
endif()
if(MIGRAPHX_USE_HIPRTC)
find_package(hiprtc REQUIRED)
message(STATUS "MIGraphX is using hipRTC")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
target_link_libraries(migraphx_gpu PUBLIC hiprtc::hiprtc)
else()
message(STATUS "MIGraphX is using HIP Clang")
......@@ -221,34 +224,45 @@ else()
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE "--cuda-gpu-arch=[a-z0-9]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE "--offload-arch=[a-z0-9:+-]+ ?" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ")
if(WIN32)
string(REPLACE "\\" "/" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endif()
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
message(STATUS "Hip compiler flags: \"${HIP_COMPILER_FLAGS}\"")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
-DMIGRAPHX_HIP_COMPILER="${CMAKE_CXX_COMPILER}"
-DMIGRAPHX_HIP_COMPILER_FLAGS="${HIP_COMPILER_FLAGS}"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
if(WIN32)
execute_process(COMMAND where ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
else()
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
endif()
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_HIP_COMPILER_LAUNCHER="${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif()
# Check miopen find mode api
include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -271,6 +285,13 @@ else()
message(STATUS "MIOpen does not have find mode api")
endif()
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
......
......@@ -251,10 +251,21 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::cout << std::string(src.content) << std::endl;
}
}
auto fname = fs::path{"migraphx-hiprtc-driver"};
#ifdef _WIN32
fname.replace_extension(".exe");
#endif
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver";
auto driver = p.parent_path() / fname;
bool found = fs::exists(driver);
if(not found)
{
driver = p.parent_path().parent_path() / "bin" / fname;
found = fs::exists(driver);
}
if(fs::exists(driver))
if(found)
{
value v;
v["srcs"] = to_value(hsrcs);
......@@ -284,16 +295,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
bool is_hip_clang_compiler()
{
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
static const auto result = fs::path{MIGRAPHX_HIP_COMPILER}.stem() == "clang++";
return result;
}
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
bool has_compiler_launcher()
{
static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
static const auto result = fs::exists(MIGRAPHX_HIP_COMPILER_LAUNCHER);
return result;
}
#endif
src_compiler assemble(src_compiler compiler)
{
compiler.out_ext = ".S";
......@@ -306,8 +321,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
assert(not srcs.empty());
if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
MIGRAPHX_THROW("Unknown hip compiler: " MIGRAPHX_HIP_COMPILER);
if(params.find("-std=") == std::string::npos)
params += " --std=c++17";
......@@ -323,14 +337,14 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
params += " -DMIGRAPHX_DEBUG";
params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS);
params += MIGRAPHX_HIP_COMPILER_FLAGS;
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.compiler = MIGRAPHX_HIP_COMPILER;
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
compiler.launcher = MIGRAPHX_HIP_COMPILER_LAUNCHER;
#endif
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
......@@ -354,7 +368,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.compiler = MIGRAPHX_HIP_COMPILER;
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
......
......@@ -168,6 +168,7 @@ struct compile_plan
}
const compiled_result& benchmark(problem_cache& pc) const
{
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
......@@ -178,9 +179,10 @@ struct compile_plan
}
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 0)
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl;
if(trace_level > 1)
std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times;
times.reserve(results.size());
......@@ -189,22 +191,23 @@ struct compile_plan
config->solutions.begin(),
std::back_inserter(times),
[&](const auto& cr, const auto& solution) {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value())
{
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max();
}
auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
if(trace_level > 1)
std::cout << t << "ms" << std::endl;
return t;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
if(trace_level > 0)
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation.");
......
......@@ -146,20 +146,20 @@ __device__ __host__ T to_hip_type(T x)
// Hip doens't support __fp16
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16)
MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16)
} // namespace device
} // namespace gpu
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,15 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <rocblas/rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
using microseconds = std::chrono::duration<double, std::micro>;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
switch(type)
......@@ -41,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type:
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
......@@ -81,184 +87,508 @@ shape transpose_batch(const shape& s, unsigned trans_batch)
return shape::from_permutation(s.type(), s.lens(), perm);
}
template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs)
/**
* Returns results of rocblas_status_success, rocblas_status_perf_degraded,
* or rocblas_status_invalid_value. Caller
* is expected to check for invalid index. Any other result causes an exception.
*
*/
template <class F, class Pack, class... Ts>
auto rocblas_invoke(F f, Pack p, Ts... xs)
{
if constexpr(sizeof...(Ts) == sizeof...(Us))
return f(xs...);
else
return f(xs..., nullptr, nullptr);
return p([=](auto... ws) {
auto status = f(ws..., xs...);
if(status != rocblas_status_success and status != rocblas_status_invalid_value)
{
if(status == rocblas_status_perf_degraded)
{
std::cerr << "WARNING: degraded perf. in rocBLAS call" << std::endl;
}
else
MIGRAPHX_THROW("rocblas_invoke: rocBLAS call failed with status " +
std::to_string(status));
}
return status;
});
}
static bool is_transposed(const shape& s)
{
if(not s.transposed())
return false;
return s.strides().back() != 1;
}
static bool is_transposed(const shape& s) { return s.transposed() and s.strides().back() != 1; }
static rocblas_int get_batch_stride(const argument& a)
static rocblas_int get_batch_stride(const shape& s)
{
return a.get_shape().strides()[a.get_shape().strides().size() - 3];
// This value is not needed for non-strided inputs
if(s.strides().size() < 3)
return 0;
else
return s.strides()[s.strides().size() - 3];
}
template <class T>
void gemm_impl(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
T alpha,
T beta,
bool compute_fp32)
/**
* Wrapper for multiple rocBLAS calls. The constructor creates parameters for
* these calls based on data shapes and other values contained in the associated
* instruction and operation.
*
* The template parameter T is not the type of the matrix data but of the weighting
* coefficients alpha and beta (these are float in rocBLAS internals)
*/
template <typename T>
struct gemm_impl
{
const bool is_3inputs = (args.size() == 4);
if(not is_3inputs)
{
beta = 0;
}
bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
auto compute_type = output_type;
if(compute_fp32)
gemm_impl(const shape& output_shape,
const std::vector<shape>& input_shapes,
T alpha_param,
T beta_param,
bool compute_fp32_flag)
: alpha(alpha_param),
beta(beta_param),
is_3inputs(input_shapes.size() == 4),
compute_fp32(compute_fp32_flag)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
if(not is_3inputs)
{
beta = 0;
}
rocblas_gemm_flags flag = rocblas_gemm_flags_none;
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
// Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
if(compute_fp32)
{
get_alpha = [=] { return &alpha; };
get_beta = [=] { return &beta; };
}
else
{
get_alpha = [=] { return &alpha_r; };
get_beta = [=] { return &beta_r; };
}
});
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
transa = is_transposed(input_shapes[0]);
transb = is_transposed(input_shapes[1]);
auto n_dim = output_shape.lens().size();
auto dim_0 = n_dim - 2;
auto dim_1 = n_dim - 1;
// Leading dimensions of matrices
lda = input_shapes[0].strides()[transa ? dim_1 : dim_0];
ldb = input_shapes[1].strides()[transb ? dim_1 : dim_0];
ldc = input_shapes[2].strides()[dim_0];
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type());
output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
auto num_matrices = std::accumulate(
auto out_lens = output_shape.lens();
m = out_lens[dim_0];
n = out_lens[dim_1];
k = input_shapes[0].lens()[dim_1];
a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]);
c_stride = get_batch_stride(input_shapes[2]);
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
strided_batched = num_matrices > 1;
if(strided_batched and b_stride == 0 and input_shapes[0].standard())
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
strided_batched = false;
}
}
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
{
// Create dummy arguments for the shapes, and call the overloaded method
std::vector<argument> input_args;
std::transform(input_shapes.begin(),
input_shapes.end(),
std::back_inserter(input_args),
[](const shape& x) { return to_gpu(generate_argument(x)); });
return validate(ctx, input_args, solution_idx);
}
/**
* Checks a particular solution for validity by running it with the flag
* rocblas_gemm_flags_check_solution_index (could be invalid if this model was
* tuned with a different rocBLAS version)
*
* @return Returns either solution_idx if valid, or else the default value 0
* if not. The default does not mean list index 0, but tells the picker
* to choose a solution.
*/
int32_t
validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) const
{
rocblas_status_ check_valid(rocblas_status_success);
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
}
else
{
auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(args.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
}
});
if(check_valid == rocblas_status_invalid_value)
{
std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
return 0;
}
return solution_idx;
}
#endif
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "...strided_batched..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
*/
auto create_strided_batched_args_common(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
b_stride,
args[0].data(),
arg_type,
lda,
a_stride,
get_beta(),
args[2].data(),
output_type,
ldc,
c_stride,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
d_stride,
num_matrices,
compute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
* */
auto create_gemm_ex_args_common(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
args[0].data(),
arg_type,
lda,
get_beta(),
args[2].data(),
output_type,
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
compute_type);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* of the fastest one.
*/
int tune(context& ctx, const std::vector<shape>& input_shapes) const
{
// tuning meta parameters
const int hot_calls = 40;
std::vector<argument> input_args;
std::transform(input_shapes.begin(),
input_shapes.end(),
std::back_inserter(input_args),
[](const shape& x) { return to_gpu(generate_argument(x)); });
// Get the solutions list in 2 rocBLAS steps:
// 1. Find out how many solutions there are and allocate the array
// 2. Get the solutions
//
rocblas_int list_size = 0;
std::vector<rocblas_int> solution_indices;
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
&list_size);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
&list_size);
}
double best_time = std::numeric_limits<double>::max();
double first_time = -1;
// Initialize to default solution index
rocblas_int best_sol = 0;
for(auto sol : solution_indices)
{
// Warmup: the first call to an op. may not be representative since there is
// more time taken initializing caches, etc. so we won't time it.
run(ctx, input_args, sol);
double host_time = time<milliseconds>([&] {
for([[maybe_unused]] int hc : range(hot_calls))
run(ctx, input_args, sol);
ctx.finish();
});
host_time /= hot_calls;
// dev/evaluation only: track time for first solution.
if(first_time < 0)
first_time = host_time;
// track current best
if(host_time < best_time)
{
best_sol = sol;
best_time = host_time;
}
}
std::cout << "Winning GEMM solution: " << best_sol << " in " << best_time << " ms, beats "
<< first_time << "ms" << std::endl;
return best_sol;
}
#endif
private:
size_t num_matrices = 0;
rocblas_int m = 0;
rocblas_int n = 0;
rocblas_int k = 0;
bool transa = false;
bool transb = false;
T alpha = 0;
T beta = 0;
std::function<const void*()> get_alpha{};
std::function<const void*()> get_beta{};
rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
rocblas_int lda = 0;
rocblas_int ldb = 0;
rocblas_int ldc = 0;
rocblas_int ldd = 0;
rocblas_int a_stride = 0;
rocblas_int b_stride = 0;
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
bool compute_fp32 = true;
}; // gemm_impl
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx)
{
std::vector<shape> input_shapes;
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool compute_fp32)
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx)
{
gemm_impl(ctx, output_shape, args, alpha, beta, compute_fp32);
std::vector<shape> input_shapes;
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
// This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool compute_fp32)
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx)
{
gemm_impl(ctx, output_shape, args, alpha, beta, compute_fp32);
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
}
} // namespace gpu
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch);
void blas_shape(const shape& s);
template <class Op>
struct rocblas_gemm
......@@ -52,6 +51,7 @@ struct rocblas_gemm
float beta = 0;
bool compute_fp32 = false;
unsigned trans_batch = 0;
int32_t solution_idx = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -60,7 +60,8 @@ struct rocblas_gemm
pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.compute_fp32, "compute_fp32"),
f(self.trans_batch, "trans_batch")));
f(self.trans_batch, "trans_batch"),
f(self.solution_idx, "solution_idx")));
}
std::string name() const
......@@ -76,6 +77,8 @@ struct rocblas_gemm
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
// When input shapes are A, B, C the GEMM equation is C  =  α AB+ β C where α, β are
// scalars
check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]);
blas_shape(inputs[1]);
......@@ -111,11 +114,12 @@ struct rocblas_gemm
{
if(this->name() == "gpu::gemm")
{
gemm(ctx, output_shape, args, alpha, beta, compute_fp32);
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
else
{
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32);
gemm_compute(
ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32, solution_idx);
}
return args.back();
}
......@@ -124,6 +128,33 @@ struct rocblas_gemm
{
return shapes.size() - 1;
}
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if(enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag())
{
if(this->name() == "gpu::gemm")
{
solution_idx = gemm_finalize(
ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
}
else
{
solution_idx = gemm_finalize(ctx,
output_shape,
input_shapes,
int32_t(alpha),
int32_t(beta),
compute_fp32,
solution_idx);
}
}
#else
// suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes;
#endif
}
};
} // namespace gpu
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -24,26 +24,64 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#include <iterator>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/context.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
using milliseconds = std::chrono::duration<double, std::milli>;
using microseconds = std::chrono::duration<double, std::micro>;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool compute_fp32);
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool compute_fp32);
/**
* @brief Templated implementations of the compute() and finalize() methods of the Gemm operator.
* For each function there are overloads using either float or int32_t for the arguments
* alpha and beta.
*
* @param ctx .
* @param output_shape .
* @param args .
* @param alpha .
* @param beta .
* @param compute_fp32 .
*/
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx);
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx);
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
float alpha,
float beta,
bool compute_fp32);
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......
......@@ -101,7 +101,9 @@ MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(isinf, ::isinf)
MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(nearbyint, ::nearbyint)
MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH(round, ::round)
......@@ -135,6 +137,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isinf, ::__hisinf)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
......@@ -150,6 +153,7 @@ MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(nearbyint, ::nearbyint)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
......@@ -229,10 +233,12 @@ MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(fmod)
MIGRAPHX_DEVICE_MATH_VEC(isinf)
MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(nearbyint)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder)
MIGRAPHX_DEVICE_MATH_VEC(round)
......
......@@ -28,7 +28,9 @@
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -122,6 +124,8 @@ struct find_add_layernorm
}
};
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
......@@ -175,6 +179,8 @@ struct find_gemm_softmax_gemm
}
};
#endif
} // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const
......@@ -182,8 +188,10 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
#ifdef MIHRAPHX_USE_COMPOSABLEKERNEL
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
#endif
}
} // namespace gpu
......
......@@ -33,8 +33,9 @@ rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h)
rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
target_link_libraries(migraphx_ref PUBLIC migraphx)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_include_directories(migraphx_ref SYSTEM PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
migraphx_generate_export_header(migraphx_ref)
......
......@@ -38,7 +38,11 @@ protobuf_generate_cpp(
)
add_library(tf-proto STATIC ${PROTO_SRCS})
target_include_directories(tf-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(tf-proto PRIVATE -w)
if(MSVC)
target_compile_options(tf-proto PRIVATE /w)
else()
target_compile_options(tf-proto PRIVATE -w)
endif()
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
......@@ -49,7 +53,10 @@ target_include_directories(migraphx_tf PRIVATE include)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_tf)
target_link_libraries(migraphx_tf PRIVATE tf-proto "-Wl,--exclude-libs,ALL")
target_link_libraries(migraphx_tf PRIVATE tf-proto)
if(NOT WIN32)
target_link_libraries(migraphx_tf PRIVATE "-Wl,--exclude-libs,ALL")
endif()
target_link_libraries(migraphx_tf PUBLIC migraphx)
rocm_install_targets(
......
......@@ -31,8 +31,18 @@
#include <sstream>
#include <iostream>
#include <string>
#include <sys/types.h>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#undef getpid
// cppcheck-suppress [definePrefix, defineUpperCase]
#define getpid _getpid
#else
#include <unistd.h>
#include <sys/types.h>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -88,7 +88,6 @@ bool verify_args(const std::string& name,
if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl;
std::cout << "MIGraphX verification passed successfully." << std::endl;
}
});
return passed;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment