Unverified Commit a6fa5e4b authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents b7a7cd3c 7604ecf5
...@@ -48,10 +48,18 @@ else() ...@@ -48,10 +48,18 @@ else()
set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs") set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif() endif()
include(Embed)
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp) configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum ...@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmax(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back(); return args.back();
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum ...@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmin(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back(); return args.back();
} }
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include <deque>
#ifdef MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h> #include <hip/hiprtc.h>
...@@ -92,7 +93,7 @@ struct hiprtc_program ...@@ -92,7 +93,7 @@ struct hiprtc_program
{ {
struct string_array struct string_array
{ {
std::vector<std::string> strings{}; std::deque<std::string> strings{};
std::vector<const char*> c_strs{}; std::vector<const char*> c_strs{};
string_array() {} string_array() {}
...@@ -209,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -209,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options.push_back("-Wno-gnu-line-marker"); options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast"); options.push_back("-Wno-old-style-cast");
} }
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG"); options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) { if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
...@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{ {
if(src.path.extension() != ".cpp") if(src.path.extension() != ".cpp")
continue; continue;
std::cout << std::string(src.content.first, src.len()) << std::endl; std::cout << std::string(src.content) << std::endl;
} }
} }
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc); auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
...@@ -338,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -338,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{ {
if(src.path.extension() != ".cpp") if(src.path.extension() != ".cpp")
continue; continue;
std::cout << std::string(src.content.first, src.len()) << std::endl; std::cout << std::string(src.content) << std::endl;
} }
} }
...@@ -359,9 +359,7 @@ bool hip_has_flags(const std::vector<std::string>& flags) ...@@ -359,9 +359,7 @@ bool hip_has_flags(const std::vector<std::string>& flags)
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only"; join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src; std::string src;
src_file input; src_file input{"main.cpp", src};
input.path = "main.cpp";
input.content = std::make_pair(src.data(), src.data() + src.size());
try try
{ {
......
...@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params( ...@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params(
global = compute_global(local); global = compute_global(local);
} }
// Reports whether the HIP compiler accepts the "-fno-offload-uniform-block" flag.
// The answer comes from hip_has_flags() and is kept in a function-local static,
// so the probe is evaluated only on the first call.
static bool hip_accept_non_uniform_wg()
{
    static const bool flag_supported = hip_has_flags({"-fno-offload-uniform-block"});
    return flag_supported;
}
std::function<std::size_t(std::size_t local)> std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over) compute_global_for(context& ctx, std::size_t n, std::size_t over)
{ {
...@@ -146,13 +152,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -146,13 +152,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t max_global = ctx.get_current_device().get_cu_count() * std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu(); ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) { return [n, over, max_global](std::size_t local) {
// hip require global workitems multiple of local workitems. It may degrade performance. std::size_t num_elements = n;
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available. if(not hip_accept_non_uniform_wg())
// https://reviews.llvm.org/D155213 {
std::size_t num_elements = ((n + local - 1) / local) * local; num_elements = (1 + (n - 1) / local) * local;
std::size_t groups = (num_elements + local - 1) / local; }
std::size_t max_blocks = max_global / local; std::size_t groups = 1 + (num_elements - 1) / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local; std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return std::min(nglobal, num_elements); return std::min(nglobal, num_elements);
}; };
} }
...@@ -172,21 +179,22 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -172,21 +179,22 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert(options.inputs.size() == options.virtual_inputs.size() or assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty()); options.virtual_inputs.empty());
std::vector<src_file> srcs = options.additional_src_files; std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(), static auto kernels{::migraphx_kernels()};
migraphx_kernels().end(), std::transform(
std::back_inserter(srcs), kernels.begin(),
[](auto&& p) { kernels.end(),
auto&& name = p.first; std::back_inserter(srcs),
auto&& c = p.second; [](const std::pair<std::string_view, std::string_view>& elem) { return src_file{elem}; });
auto path = name; srcs.emplace_back("main.cpp", content);
return src_file{path, c};
});
srcs.push_back(src_file{fs::path{"main.cpp"},
std::make_pair(content.data(), content.data() + content.size())});
auto args_hpp = auto args_hpp =
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs); generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.push_back(src_file{fs::path{"args.hpp"}, srcs.emplace_back("args.hpp", args_hpp);
std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
options.params += " -fno-offload-uniform-block";
else
assert(options.global % options.local == 0);
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
......
...@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -37,6 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING);
struct precompile_op struct precompile_op
{ {
...@@ -179,15 +180,29 @@ struct compile_plan ...@@ -179,15 +180,29 @@ struct compile_plan
MIGRAPHX_THROW("Multiple kernels without config"); MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
<< std::endl; << std::endl;
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times; std::vector<double> times;
times.reserve(results.size()); times.reserve(results.size());
std::transform( std::transform(results.begin(),
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) { results.end(),
if(not cr.has_value()) config->solutions.begin(),
return std::numeric_limits<double>::max(); std::back_inserter(times),
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20) [&](const auto& cr, const auto& solution) {
.first; if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
}); std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value())
{
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max();
}
auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{}))
std::cout << t << "ms" << std::endl;
return t;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl; std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i)); pc.insert(preop.name(), config->problem, config->solutions.at(i));
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void argmax(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{ {
arg_op(argmax_op{}, stream, result, arg, axis); if(select_last_index)
arg_op(argmax_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmax_op_first_index{}, stream, result, arg, axis);
} }
} // namespace device } // namespace device
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void argmin(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{ {
arg_op(argmin_op{}, stream, result, arg, axis); if(select_last_index)
arg_op(argmin_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmin_op_first_index{}, stream, result, arg, axis);
} }
} // namespace device } // namespace device
......
...@@ -81,6 +81,14 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) ...@@ -81,6 +81,14 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
using f_type = decltype(f); using f_type = decltype(f);
dim3 nblocks(global / local); dim3 nblocks(global / local);
dim3 nthreads(local); dim3 nthreads(local);
/*
hipGetLastError() returns the error of the first failed HIP call that happened previously.
MIGraphX calls into various backend libraries, and failed HIP calls can also happen there.
Calling hipGetLastError() here resets the error code to hipSuccess, so that a failed call
to hipLaunchKernelGGL() inside MIGraphX can be captured.
*/
hipError_t flush_call = hipGetLastError();
(void)(flush_call);
// cppcheck-suppress UseDeviceLaunch // cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f); hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
hipError_t kernel_launch_status = hipGetLastError(); hipError_t kernel_launch_status = hipGetLastError();
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP #ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP #define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp> #include <migraphx/gpu/device/config.hpp>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -34,9 +34,13 @@ namespace gpu { ...@@ -34,9 +34,13 @@ namespace gpu {
namespace device { namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT #define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets(); const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string(); std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name(); std::string get_device_name();
} // namespace device } // namespace device
......
...@@ -38,10 +38,8 @@ struct compile_op : action<compile_op> ...@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context ctx; context ctx;
auto inputs = p.parse_shapes(v.at("inputs")); auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v); auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms"; std::cout << op << ": " << t << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
std::cout << std::endl; std::cout << std::endl;
} }
}; };
......
...@@ -43,8 +43,8 @@ struct run_op : action<run_op> ...@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms" << std::endl; std::cout << op << ": " << t << "ms" << std::endl;
} }
}; };
......
...@@ -22,10 +22,11 @@ ...@@ -22,10 +22,11 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/fuse_ck.hpp> #include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/gpu/device_name.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -55,7 +56,7 @@ struct ck_gemm ...@@ -55,7 +56,7 @@ struct ck_gemm
{ {
check_shapes{inputs, *this}.same_ndims(); check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0]; auto a = inputs[0];
auto b = inputs[1]; auto b = inputs[1];
for(const auto& input : inputs) for(const auto& input : inputs)
...@@ -65,27 +66,35 @@ struct ck_gemm ...@@ -65,27 +66,35 @@ struct ck_gemm
return r; return r;
return r.with_type(mods.front()->get_output_shapes().front().type()); return r.with_type(mods.front()->get_output_shapes().front().type());
} }
// Returns true when `t` is one of the element types accepted here for CK:
// half, int8 or int32.
static bool is_ck_supported_type(shape::type_t t)
{
    return t == shape::half_type or t == shape::int8_type or t == shape::int32_type;
}
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
namespace { struct ck_gemm_softmax_gemm : gemm_softmax_gemm
bool is_ck_supported_type(shape::type_t t)
{ {
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t); std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
} };
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{ {
if(ins->name() != "dot" and ins->name() != "quant_dot") if(ins->name() != "dot" and ins->name() != "quant_dot")
return false; return false;
if(not is_ck_supported_type(ins->get_shape().type())) if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
auto m = a.lens()[a.lens().size() - 2]; auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back(); auto n = b.lens().back();
auto k = a.lens().back(); auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
// Integer gemms must be divisible by 4 in ck // Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type())) if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{ {
...@@ -96,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -96,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
if(k % 4 != 0) if(k % 4 != 0)
return false; return false;
} }
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy auto device_name = trim(split_string(get_device_name(), ':').front());
// to avoid poor-performing GEMM kernels from CK if(device_name == "gfx940")
// To-do: Investigate a more precise strategy {
if(ins->get_shape().type() == shape::half_type)
{
if(batch_size >= 64)
return m < 2048 or k <= 64 or n <= 384 or n >= 2048;
return true;
}
return true;
}
return k <= 2048; return k <= 2048;
} }
...@@ -127,7 +144,15 @@ struct find_ck_gemm_pointwise ...@@ -127,7 +144,15 @@ struct find_ck_gemm_pointwise
ins->get_shape().type() != gemm_ins->get_shape().type()) ins->get_shape().type() != gemm_ins->get_shape().type())
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) { if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type()); return not ck_gemm::is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
})) }))
return; return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
...@@ -152,7 +177,7 @@ struct find_ck_gemm_pointwise ...@@ -152,7 +177,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm struct find_ck_gemm
{ {
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); } auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
...@@ -161,11 +186,26 @@ struct find_ck_gemm ...@@ -161,11 +186,26 @@ struct find_ck_gemm
} }
}; };
// Rewrites "gpu::pre_gemm_softmax_gemm" instructions into ck_gemm_softmax_gemm.
struct find_ck_gemm_softmax_gemm
{
    // Matches any instruction named "gpu::pre_gemm_softmax_gemm".
    auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }

    // Replaces the matched instruction with a ck_gemm_softmax_gemm built from a
    // "dot" operator and the "scale" field of the matched operator; the inputs
    // of the instruction are kept as they are.
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto matched  = r.result;
        auto op_value = matched->get_operator().to_value();
        // The matched operator is expected to carry a "scale" field.
        assert(op_value.contains("scale"));
        auto scale_factor = op_value.at("scale").to<float>();
        mpm.get_module().replace_instruction(
            matched,
            ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale_factor},
            matched->inputs());
    }
};
} // namespace } // namespace
void fuse_ck::apply(module_pass_manager& mpm) const void fuse_ck::apply(module_pass_manager& mpm) const
{ {
match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{}); match::find_matches(mpm, find_ck_gemm{});
} }
......
...@@ -36,24 +36,14 @@ struct module; ...@@ -36,24 +36,14 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
bool mlir_enabled() bool mlir_enabled()
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{}); const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
if(mlir_enabled) return not mlir_disabled;
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
#else #else
return false; return false;
#endif #endif
...@@ -131,9 +121,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) ...@@ -131,9 +121,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
for(instruction_ref input : gemm_based_op->inputs()) for(instruction_ref input : gemm_based_op->inputs())
{ {
std::vector<operation> op_stream; std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name())) while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{ {
op_stream.push_back(input->get_operator()); operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0); input = input->inputs().at(0);
} }
top_inputs.push_back(input); top_inputs.push_back(input);
...@@ -150,27 +147,72 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) ...@@ -150,27 +147,72 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
return {new_gemm_based_op, top_inputs}; return {new_gemm_based_op, top_inputs};
} }
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) enum class mlir_mode
{ {
if(ins->name() != "convolution" and ins->name() != "quant_convolution") all,
return false; fast,
value v = ins->get_operator().to_value(); int8,
auto group = v.at("group").to<int>(); none
if(group != 1) };
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!" auto is_mlir_dot(mlir_mode mode)
if(ins->get_shape().lens().size() != 4) {
return false; return match::make_basic_pred_matcher([=](instruction_ref ins) {
return true; if(mode == mlir_mode::none)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
});
}
// Builds a predicate matcher that decides, for the given `mode`, whether an
// instruction is a convolution that should be selected for MLIR.
//
// The checks run in this order:
//   1. mode == none                          -> reject
//   2. not convolution / quant_convolution   -> reject
//   3. "group" attribute != 1                -> reject
//   4. output shape is not rank 4            -> reject
//   5. output type is int8                   -> accept
//   6. mode == int8                          -> reject
//   7. mode == all                           -> accept
//   8. otherwise: accept when the weights are not rank 4, when the last two
//      weight dimensions differ, or when the last weight dimension is not a
//      multiple of 3.
auto is_mlir_conv(mlir_mode mode)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        if(mode == mlir_mode::none)
            return false;
        const bool is_conv = ins->name() == "convolution" or ins->name() == "quant_convolution";
        if(not is_conv)
            return false;
        value op_value       = ins->get_operator().to_value();
        const auto group_cnt = op_value.at("group").to<int>();
        if(group_cnt != 1)
            return false;
        // Avoid MLIR assertion: Index < Length && "Invalid index!"
        if(ins->get_shape().lens().size() != 4)
            return false;
        if(ins->get_shape().type() == shape::int8_type)
            return true;
        if(mode == mlir_mode::int8)
            return false;
        if(mode == mlir_mode::all)
            return true;
        // Remaining modes: decide from the shape of the second input (the weights).
        auto weights = ins->inputs().at(1)->get_shape();
        if(weights.lens().size() != 4)
            return true;
        if(weights.lens()[2] != weights.lens()[3])
            return true;
        return (weights.lens()[3] % 3) != 0;
    });
}
struct find_mlir_fused_ops struct find_mlir_fused_ops
{ {
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const auto matcher() const
{ {
auto dot_or_conv = match::skip(match::name("contiguous"))( auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv()) match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
} }
...@@ -302,8 +344,11 @@ struct find_mlir_fused_ops ...@@ -302,8 +344,11 @@ struct find_mlir_fused_ops
} }
}; };
template <auto Matcher>
struct find_mlir_standalone_op struct find_mlir_standalone_op
{ {
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto conv_based_op = r.result; auto conv_based_op = r.result;
...@@ -325,15 +370,8 @@ struct find_mlir_standalone_op ...@@ -325,15 +370,8 @@ struct find_mlir_standalone_op
} }
}; };
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
{ using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
auto matcher() const { return is_mlir_conv; }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
/** /**
* @brief Declares a new MIGraphX environment variable which forces to generate * @brief Declares a new MIGraphX environment variable which forces to generate
...@@ -347,44 +385,15 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op ...@@ -347,44 +385,15 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
* intended to be primarily used by rocMLIR developers. * intended to be primarily used by rocMLIR developers.
*/ */
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
bool is_requested(std::string_view option) bool is_requested(std::string_view option, bool fallback = false)
{ {
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, ""); auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ','); const auto options = split_string(string_value, ',');
return contains(options, option); return contains(options, option);
} }
bool is_enabled(std::string_view op_name, context* ctx)
{
if(is_self_decide())
{
if(op_name == "fused")
{
return true;
}
else if(op_name == "convolution" or op_name == "quant_convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
else
{
return false;
}
}
return is_requested(op_name);
}
} // namespace } // namespace
#endif // MIGRAPHX_MLIR #endif // MIGRAPHX_MLIR
...@@ -392,20 +401,28 @@ bool is_enabled(std::string_view op_name, context* ctx) ...@@ -392,20 +401,28 @@ bool is_enabled(std::string_view op_name, context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
if(is_enabled("fused", this->ctx)) const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
{ const bool is_navi = starts_with(device_name, "gfx110");
match::find_matches(mpm, find_mlir_fused_ops{});
}
if(is_enabled("convolution", this->ctx)) auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
{ if(is_requested(option))
match::find_matches(mpm, find_mlir_standalone_convolution_op{}); return mlir_mode::all;
} if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};
if(is_enabled("dot", this->ctx)) mlir_mode mode =
{ (enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
match::find_matches(mpm, find_mlir_standalone_dot_op{});
} match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else #else
(void)mpm; (void)mpm;
#endif #endif
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp> #include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <array>
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// CK-related environment variable declarations; they are compiled out when
// _WIN32 is defined.
// NOTE(review): the names suggest enabling, logging, debugging and tuning of
// CK - confirm the exact meaning at the places where each variable is read.
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// Text template consumed by ck_disable_warnings(): the ${content} placeholder
// is surrounded by clang diagnostic push / ignored "-Weverything" / pop
// pragmas, so diagnostics are ignored only for the substituted text.
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
/**
 * @brief Wraps a block of text in the disable_warning_pragma template.
 *
 * @tparam P Any type exposing data() and size() that describe a character
 *           range.
 * @param p  The text to wrap; it is copied into a std::string first.
 * @return disable_warning_pragma with its "content" placeholder replaced by
 *         the text of @p p.
 */
template <class P>
std::string ck_disable_warnings(P p)
{
    std::string wrapped_text(p.data(), p.size());
    return interpolate_string(disable_warning_pragma, {{"content", wrapped_text}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(header_strings.begin(),
header_strings.end(),
std::back_inserter(srcs),
[&](auto& p) { return src_file{p}; });
return srcs;
}
/**
 * @brief Returns the list of CK header sources, built by create_ck_headers()
 * on the first call and reused on every later call.
 */
static inline const std::vector<src_file>& ck_headers()
{
    static const std::vector<src_file> cached_headers = create_ck_headers();
    return cached_headers;
}
/**
 * @brief Reports whether the last (innermost) stride of @p s differs from 1.
 *
 * @param s Shape to inspect.
 * @return true when the last stride is not 1; false when it is 1 or when the
 *         shape has no strides at all.
 */
inline bool transposed_matrix(const shape& s)
{
    const auto& strides = s.strides();
    // back() on an empty container is undefined behaviour, so a shape without
    // strides is reported as "not transposed" instead.
    return not strides.empty() and strides.back() != 1;
}
/**
 * @brief Maps the element type of @p s to the matching ck::host::DataType.
 *
 * Handled types: half, float, int8 and int32. Any other type raises
 * MIGRAPHX_THROW("Unsupported ck type").
 */
inline ck::host::DataType get_type(const shape& s)
{
    const auto element_type = s.type();
    if(element_type == shape::half_type)
        return ck::host::DataType::Half;
    if(element_type == shape::float_type)
        return ck::host::DataType::Float;
    if(element_type == shape::int8_type)
        return ck::host::DataType::Int8;
    if(element_type == shape::int32_type)
        return ck::host::DataType::Int32;
    MIGRAPHX_THROW("Unsupported ck type");
}
/**
 * @brief Returns the product of every dimension of @p s except the last two.
 *
 * @param s Shape whose leading (batch) dimensions are multiplied together.
 * @return The product of the leading dimensions, or 1 when the shape has two
 *         or fewer dimensions.
 */
inline std::size_t get_batch_count(const shape& s)
{
    const auto& lens = s.lens();
    // With fewer than two dimensions `rbegin() + 2` would step past `rend()`,
    // which is undefined behaviour; such shapes have no batch dimensions.
    if(lens.size() <= 2)
        return 1;
    return std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
/**
 * @brief Folds all leading (batch) dimensions of @p s into its last two
 * dimensions, in place.
 *
 * Shapes with two or fewer dimensions are left untouched. Otherwise the batch
 * count from get_batch_count() multiplies the last dimension when
 * transposed_matrix(s) is true, and the second-to-last dimension when it is
 * false. The replacement shape is built from the element type and the two new
 * lengths only.
 */
inline void fold_batch_dims(shape& s)
{
    auto dims = s.lens();
    const auto rank = dims.size();
    if(rank <= 2)
        return;
    auto batches = get_batch_count(s);
    auto rows    = dims.at(rank - 2);
    auto cols    = dims.at(rank - 1);
    s = transposed_matrix(s) ? shape{s.type(), {rows, cols * batches}}
                             : shape{s.type(), {rows * batches, cols}};
}
/**
 * @brief Drops all leading (batch) dimensions of @p s in place, keeping only
 * the last two lengths.
 *
 * Shapes with two or fewer dimensions are left untouched. The replacement
 * shape is built from the element type and the two kept lengths only.
 */
inline void remove_batch_dims(shape& s)
{
    auto dims = s.lens();
    const auto rank = dims.size();
    if(rank > 2)
    {
        auto rows = dims.at(rank - 2);
        auto cols = dims.at(rank - 1);
        s         = shape{s.type(), {rows, cols}};
    }
}
/**
 * @brief Checks whether the batch (leading) dimensions of @p s form a standard
 * shape on their own.
 *
 * Shapes with fewer than three dimensions pass trivially. Otherwise a shape is
 * built from all dimensions except the last two, using the matching strides
 * divided by the product of the last two lengths, and the result of its
 * standard() query is returned.
 */
inline bool standard_batch(const shape& s)
{
    const auto& all_lens = s.lens();
    if(all_lens.size() < 3)
        return true;
    const auto& all_strides = s.strides();
    std::vector<std::size_t> batch_lens(all_lens.begin(), all_lens.end() - 2);
    std::vector<std::size_t> batch_strides(all_strides.begin(), all_strides.end() - 2);
    // Number of elements in one matrix (the innermost two dimensions).
    auto matrix_size = all_lens[all_lens.size() - 2] * all_lens[all_lens.size() - 1];
    for(auto& stride : batch_strides)
    {
        stride = stride / matrix_size;
    }
    return shape{s.type(), batch_lens, batch_strides}.standard();
}
/**
 * @brief Decides whether the batch dimensions of a set of input shapes can be
 * folded.
 *
 * Folding is possible when every shape from index 2 up to (but excluding) the
 * last one satisfies standard_batch(), and every stride of inputs[1] except
 * its last two is zero.
 *
 * The iterator arithmetic is unchecked: @p inputs must hold at least three
 * shapes and inputs[1] must have at least two strides.
 */
inline bool can_fold_batch(const std::vector<shape>& inputs)
{
    const auto& b_shape = inputs[1];
    const bool extras_are_standard =
        std::all_of(inputs.begin() + 2, inputs.end() - 1, [](const auto& input) {
            return standard_batch(input);
        });
    if(not extras_are_standard)
        return false;
    const auto& b_strides = b_shape.strides();
    return std::none_of(
        b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride != 0; });
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
...@@ -45,10 +45,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS); ...@@ -45,10 +45,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
struct hiprtc_src_file struct hiprtc_src_file
{ {
hiprtc_src_file() = default; hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s) hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
: path(s.path.string()), content(s.content.first, s.content.second)
{
}
std::string path; std::string path;
std::string content; std::string content;
template <class Self, class F> template <class Self, class F>
......
...@@ -299,23 +299,6 @@ struct context ...@@ -299,23 +299,6 @@ struct context
any_ptr get_queue() { return get_stream().get(); } any_ptr get_queue() { return get_stream().get(); }
/**
 * @brief Switches performance measurement on or off.
 *
 * When @p b is true, new start and stop timing events are created and each is
 * recorded on the stream returned by get_stream(). When @p b is false, both
 * events are reset to nullptr. In either case measure_perf is set to @p b.
 */
void enable_perf_measurement(bool b = true)
{
    if(not b)
    {
        start_event = nullptr;
        stop_event  = nullptr;
    }
    else
    {
        start_event = create_event_for_timing();
        stop_event  = create_event_for_timing();
        get_stream().record(start_event.get());
        get_stream().record(stop_event.get());
    }
    measure_perf = b;
}
std::pair<hipEvent_t, hipEvent_t> get_perf_events() const std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{ {
if(measure_perf) if(measure_perf)
...@@ -323,12 +306,12 @@ struct context ...@@ -323,12 +306,12 @@ struct context
return std::make_pair(nullptr, nullptr); return std::make_pair(nullptr, nullptr);
} }
float get_elapsed_ms() const static float get_elapsed_ms(hipEvent_t start, hipEvent_t stop)
{ {
float result = 0; float result = 0;
if(start_event != nullptr and stop_event != nullptr) if(start != nullptr and stop != nullptr)
{ {
auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get()); auto status = hipEventElapsedTime(&result, start, stop);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status)); MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
} }
......
...@@ -199,9 +199,9 @@ struct miopen_convolution ...@@ -199,9 +199,9 @@ struct miopen_convolution
// MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6 // MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6
preallocate = true; preallocate = true;
#endif #endif
auto x = preallocate ? to_gpu(generate_argument(x_shape)) : inputs[0]; auto x = preallocate ? to_gpu(generate_argument(x_shape)) : argument{inputs[0]};
auto w = preallocate ? to_gpu(generate_argument(w_shape)) : inputs[1]; auto w = preallocate ? to_gpu(generate_argument(w_shape)) : argument{inputs[1]};
auto y = preallocate ? allocate_gpu(output_shape) : inputs[2]; auto y = preallocate ? allocate_gpu(output_shape) : argument{inputs[2]};
auto workspace = auto workspace =
preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape); preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i) ...@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
return {v, i}; return {v, i};
} }
struct argmax_op struct argmax_op_first_index
{ {
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
...@@ -73,7 +73,25 @@ struct argmax_op ...@@ -73,7 +73,25 @@ struct argmax_op
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
}; };
struct argmin_op struct argmax_op_last_index
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
return x;
else if(x.val < y.val)
return y;
else
{
return (x.index > y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
struct argmin_op_first_index
{ {
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
...@@ -91,6 +109,24 @@ struct argmin_op ...@@ -91,6 +109,24 @@ struct argmin_op
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
}; };
/**
 * @brief Reduction functor selecting the value/index pair with the smaller
 * value; when neither value compares less than the other, the pair with the
 * larger index is chosen (y is returned when the indices are not ordered
 * x > y).
 */
struct argmin_op_last_index
{
    template <class T>
    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val < y.val)
            return x;
        if(x.val > y.val)
            return y;
        // Tie on value: prefer the later index.
        return (x.index > y.index) ? x : y;
    }

    /// Starting value of the reduction: highest().
    MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class Op> template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment