Commit 4ea39116 authored by Khalique Ahmed

manual merge

parents 20128cae d8011adf
@@ -36,24 +36,14 @@ struct module;
 namespace gpu {
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
 bool mlir_enabled()
 {
 #ifdef MIGRAPHX_MLIR
-    const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
-    if(mlir_enabled)
-    {
-        return true;
-    }
-    else
-    {
-        std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
-                     "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
-                  << std::endl;
-        return false;
-    }
+    const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
+    return not mlir_disabled;
 #else
     return false;
 #endif
@@ -131,9 +121,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
     for(instruction_ref input : gemm_based_op->inputs())
     {
         std::vector<operation> op_stream;
-        while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
+        while(contains(
+            {"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
+            input->name()))
         {
-            op_stream.push_back(input->get_operator());
+            operation op = input->get_operator();
+            if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
+            {
+                op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
+            }
+            op_stream.push_back(op);
             input = input->inputs().at(0);
         }
         top_inputs.push_back(input);
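Note: the widened while-loop above folds squeeze, flatten, and unsqueeze into the fused input chain, but replays each of them as an explicit reshape, since reshape is the shape-changing primitive the MLIR backend already accepts. A minimal sketch of the equivalence (the module mm and the {3, 4} input are assumed for illustration):

    // An unsqueeze of a {3, 4} tensor to {1, 3, 4} moves no data, so it can be
    // rewritten as a reshape whose dims equal the unsqueeze's own output lens:
    auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4}});
    auto u = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), x);
    // ...is equivalent to:
    auto r = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 3, 4}}}), x);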
@@ -150,27 +147,72 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
     return {new_gemm_based_op, top_inputs};
 }
-MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
-{
-    if(ins->name() != "convolution" and ins->name() != "quant_convolution")
-        return false;
-    value v    = ins->get_operator().to_value();
-    auto group = v.at("group").to<int>();
-    if(group != 1)
-        return false;
-    // Avoid MLIR assertion: Index < Length && "Invalid index!"
-    if(ins->get_shape().lens().size() != 4)
-        return false;
-    return true;
-}
+enum class mlir_mode
+{
+    all,
+    fast,
+    int8,
+    none
+};
+
+auto is_mlir_dot(mlir_mode mode)
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(mode == mlir_mode::none)
+            return false;
+        if(ins->name() != "dot" and ins->name() != "quant_dot")
+            return false;
+        if(mode != mlir_mode::fast)
+            return true;
+        auto a = ins->inputs().front()->get_shape();
+        auto b = ins->inputs().back()->get_shape();
+        // auto m = a.lens()[a.lens().size() - 2];
+        // auto n = b.lens().back();
+        auto k = a.lens().back();
+        // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
+        // to avoid poor-performing GEMM kernels from MLIR
+        // To-do: Investigate a more precise strategy
+        return k <= 2048;
+    });
+}
+
+auto is_mlir_conv(mlir_mode mode)
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(mode == mlir_mode::none)
+            return false;
+        if(ins->name() != "convolution" and ins->name() != "quant_convolution")
+            return false;
+        value v    = ins->get_operator().to_value();
+        auto group = v.at("group").to<int>();
+        if(group != 1)
+            return false;
+        // Avoid MLIR assertion: Index < Length && "Invalid index!"
+        if(ins->get_shape().lens().size() != 4)
+            return false;
+        if(ins->get_shape().type() == shape::int8_type)
+            return true;
+        if(mode == mlir_mode::int8)
+            return false;
+        if(mode == mlir_mode::all)
+            return true;
+        auto w = ins->inputs().at(1)->get_shape();
+        if(w.lens().size() != 4)
+            return true;
+        if(w.lens()[2] != w.lens()[3])
+            return true;
+        return (w.lens()[3] % 3) != 0;
+    });
+}
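A note on the enum ordering, which the pass below relies on: because the enumerators are declared in the order all < fast < int8 < none, std::max of two modes always yields the more restrictive one (the one that offloads less work to MLIR). A standalone sketch, illustrative only and not part of the diff:

    #include <algorithm>

    enum class mlir_mode { all, fast, int8, none };

    // std::max on a scoped enum compares the underlying values, so combining
    // two modes picks the more conservative behaviour.
    static_assert(std::max(mlir_mode::all, mlir_mode::fast) == mlir_mode::fast);
    static_assert(std::max(mlir_mode::fast, mlir_mode::none) == mlir_mode::none);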
 struct find_mlir_fused_ops
 {
+    mlir_mode conv_mode = mlir_mode::none;
+    mlir_mode dot_mode  = mlir_mode::none;
     auto matcher() const
     {
         auto dot_or_conv = match::skip(match::name("contiguous"))(
-            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
-                .bind("gemm_based_op"));
+            match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
         return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }
@@ -302,8 +344,11 @@ struct find_mlir_fused_ops
     }
 };
+template <auto Matcher>
 struct find_mlir_standalone_op
 {
+    mlir_mode mode = mlir_mode::none;
+    auto matcher() const { return Matcher(mode); }
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto conv_based_op = r.result;
@@ -316,7 +361,8 @@ struct find_mlir_standalone_op
             return;
         static size_t counter = 0;
-        module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
+        module_ref mm =
+            mpm.create_module("mlir_" + conv_based_op->name() + std::to_string(counter++));
         mm->set_bypass();
         auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
         mm->add_return({anchor_op});
@@ -325,15 +371,8 @@ struct find_mlir_standalone_op
     }
 };
-struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
-{
-    auto matcher() const { return is_mlir_conv; }
-};
-
-struct find_mlir_standalone_dot_op : find_mlir_standalone_op
-{
-    auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
-};
+using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
+using find_mlir_standalone_dot_op         = find_mlir_standalone_op<&is_mlir_dot>;
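The aliases above use C++17's template<auto>, which lets the matcher factory function itself be a non-type template parameter: each alias stamps out a distinct finder type while the struct stays a plain aggregate. A reduced sketch with hypothetical names:

    // finder<&make_matcher> binds a factory at compile time; finder<...>{mode}
    // still works because the struct remains an aggregate.
    template <auto Matcher>
    struct finder
    {
        int mode = 0;
        auto matcher() const { return Matcher(mode); }
    };

    auto make_matcher(int mode) { return mode * 2; } // stand-in for is_mlir_conv/is_mlir_dot

    using find_conv = finder<&make_matcher>;
    // find_conv{3}.matcher() == 6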
 /**
  * @brief Declares a new MIGraphX environment variable which forces generation of
@@ -347,44 +386,15 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
  * intended to be primarily used by rocMLIR developers.
  */
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
-bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
-
-bool is_requested(std::string_view option)
+bool is_requested(std::string_view option, bool fallback = false)
 {
-    assert(not is_self_decide());
-    auto string_value  = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
+    auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
+    if(string_value.empty())
+        return fallback;
     const auto options = split_string(string_value, ',');
     return contains(options, option);
 }
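is_requested now folds the old is_self_decide check into a fallback parameter: an unset MIGRAPHX_MLIR_USE_SPECIFIC_OPS yields the caller-supplied default instead of asserting. A standalone approximation of the behaviour (split_string and contains are MIGraphX utilities; this sketch re-implements just enough to show the semantics):

    #include <algorithm>
    #include <sstream>
    #include <string>
    #include <vector>

    bool is_requested_sketch(const std::string& env_value,
                             const std::string& option,
                             bool fallback = false)
    {
        if(env_value.empty())
            return fallback; // unset variable: defer to the caller's default
        std::vector<std::string> options;
        std::stringstream ss(env_value);
        for(std::string item; std::getline(ss, item, ',');)
            options.push_back(item);
        return std::find(options.begin(), options.end(), option) != options.end();
    }

    // is_requested_sketch("fused,convolution", "fused") -> true
    // is_requested_sketch("fused,convolution", "dot")   -> false
    // is_requested_sketch("", "dot")                    -> false (fallback)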
-bool is_enabled(std::string_view op_name, context* ctx)
-{
-    if(is_self_decide())
-    {
-        if(op_name == "fused")
-        {
-            return true;
-        }
-        else if(op_name == "convolution" or op_name == "quant_convolution")
-        {
-            if(ctx == nullptr)
-            {
-                return false;
-            }
-            else
-            {
-                const auto& device = ctx->get_current_device();
-                const std::string navi_family{"gfx110"};
-                return starts_with(device.get_gfx_name(), navi_family);
-            }
-        }
-        else
-        {
-            return false;
-        }
-    }
-    return is_requested(op_name);
-}
 } // namespace
 #endif // MIGRAPHX_MLIR
@@ -392,20 +402,28 @@ bool is_enabled(std::string_view op_name, context* ctx)
 void fuse_mlir::apply(module_pass_manager& mpm) const
 {
 #ifdef MIGRAPHX_MLIR
-    if(is_enabled("fused", this->ctx))
-    {
-        match::find_matches(mpm, find_mlir_fused_ops{});
-    }
-
-    if(is_enabled("convolution", this->ctx))
-    {
-        match::find_matches(mpm, find_mlir_standalone_convolution_op{});
-    }
-
-    if(is_enabled("dot", this->ctx))
-    {
-        match::find_matches(mpm, find_mlir_standalone_dot_op{});
-    }
+    const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
+    const bool is_navi      = starts_with(device_name, "gfx110");
+
+    auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
+        if(is_requested(option))
+            return mlir_mode::all;
+        if(is_navi)
+            return mlir_mode::all;
+        return std::max(m1, m2);
+    };
+
+    mlir_mode mode =
+        (enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
+
+    match::find_matches(mpm,
+                        find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
+                                            .dot_mode  = get_mode("fused", mode)});
+
+    match::find_matches(
+        mpm,
+        find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
+        find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
 #else
     (void)mpm;
 #endif
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -21,15 +21,20 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
 #include <rocblas/rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
-#include <migraphx/permutation.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/time.hpp>
+
+using microseconds = std::chrono::duration<double, std::micro>;
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
+
+// Convert MIGraphX data types to equivalent rocBLAS data types
 rocblas_datatype get_type(shape::type_t type)
 {
     switch(type)
@@ -81,196 +86,508 @@ shape transpose_batch(const shape& s, unsigned trans_batch)
     return shape::from_permutation(s.type(), s.lens(), perm);
 }
-template <class R, class... Ts, class... Us>
-R rocblas_invoke(R (*f)(Ts...), Us... xs)
-{
-    if constexpr(sizeof...(Ts) == sizeof...(Us))
-        return f(xs...);
-    else
-        return f(xs..., nullptr, nullptr);
-}
+/**
+ * Returns rocblas_status_success, rocblas_status_perf_degraded, or
+ * rocblas_status_invalid_value. The caller is expected to check for an invalid
+ * index. Any other result causes an exception.
+ */
+template <class F, class Pack, class... Ts>
+auto rocblas_invoke(F f, Pack p, Ts... xs)
+{
+    return p([=](auto... ws) {
+        auto status = f(ws..., xs...);
+        if(status != rocblas_status_success and status != rocblas_status_invalid_value)
+        {
+            if(status == rocblas_status_perf_degraded)
+            {
+                std::cerr << "WARNING: degraded perf. in rocBLAS call" << std::endl;
+            }
+            else
+                MIGRAPHX_THROW("rocblas_invoke: rocBLAS call failed with status " +
+                               std::to_string(status));
+        }
+        return status;
+    });
+}
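The new rocblas_invoke receives the long, shared rocBLAS argument list as a "pack": a lambda that stores the arguments and later forwards them into a callable, so per-call extras such as the solution index and flags can be appended at each call site. A minimal standalone sketch of the pattern (an approximation; the real pack lives in MIGraphX's functional utilities):

    #include <iostream>

    // pack(xs...) captures the arguments; calling the result with a lambda
    // forwards the stored arguments first.
    template <class... Ts>
    auto pack(Ts... xs)
    {
        return [=](auto f) { return f(xs...); };
    }

    // Mirrors rocblas_invoke: forwards the packed arguments, then the extras.
    template <class F, class Pack, class... Ts>
    auto invoke_packed(F f, Pack p, Ts... extra)
    {
        return p([=](auto... ws) { return f(ws..., extra...); });
    }

    int main()
    {
        auto common = pack(2, 3); // shared argument list
        // forwards (2, 3) then appends 4: prints 10
        std::cout << invoke_packed([](int a, int b, int c) { return a * b + c; }, common, 4)
                  << "\n";
    }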
-static bool is_transposed(const shape& s)
-{
-    if(not s.transposed())
-        return false;
-    return s.strides().back() != 1;
-}
+static bool is_transposed(const shape& s) { return s.transposed() and s.strides().back() != 1; }
-static rocblas_int get_batch_stride(const argument& a)
+static rocblas_int get_batch_stride(const shape& s)
 {
-    return a.get_shape().strides()[a.get_shape().strides().size() - 3];
+    // This value is not needed for non-strided inputs
+    if(s.strides().size() < 3)
+        return 0;
+    else
+        return s.strides()[s.strides().size() - 3];
 }
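A worked example of the stride computation, with assumed dimensions:

    // A standard shape with lens {4, 8, 16} has strides {128, 16, 1}: consecutive
    // 8 x 16 matrices sit 128 elements apart, so get_batch_stride returns
    // strides[size - 3] = 128. Rank-2 shapes now return 0 rather than indexing
    // the strides vector out of range.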
-template <class T>
-void gemm_impl(context& ctx,
-               const shape& output_shape,
-               const std::vector<argument>& args,
-               T alpha,
-               T beta,
-               bool int8_x4_format,
-               bool compute_fp32)
-{
-    const bool is_3inputs = (args.size() == 4);
-    if(not is_3inputs)
-    {
-        beta = 0;
-    }
-
-    bool transa     = is_transposed(args[0].get_shape());
-    bool transb     = is_transposed(args[1].get_shape());
-    auto n_dim      = output_shape.lens().size();
-    auto dim_1      = n_dim - 1;
-    auto dim_0      = n_dim - 2;
-    rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
-    rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
-    rocblas_int ldc = args[2].get_shape().strides()[dim_0];
-    rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
-
-    rocblas_datatype arg_type = get_type(args[0].get_shape().type());
-    auto output_type          = arg_type;
-    if(output_type == rocblas_datatype_i8_r)
-    {
-        output_type = rocblas_datatype_i32_r;
-    }
-    auto compute_type = output_type;
-    if(compute_fp32)
-    {
-        if(arg_type == rocblas_datatype_f16_r)
-            compute_type = rocblas_datatype_f32_r;
-    }
-
-    rocblas_gemm_flags flag = rocblas_gemm_flags_none;
-#if ROCBLAS_VERSION_MAJOR < 3
-    if(int8_x4_format)
-        flag = rocblas_gemm_flags_pack_int8x4;
-#endif
-
-    auto a_lens = args[0].get_shape().lens();
-    auto b_lens = args[1].get_shape().lens();
-    output_shape.visit_type([&](auto as) {
-        auto alpha_r = as(alpha);
-        auto beta_r  = as(beta);
-
-        // use void pointer to select different data type if using fp32 mode
-        void* alpha_v = &alpha_r;
-        void* beta_v  = &beta_r;
-
-        if(compute_fp32)
-        {
-            alpha_v = &alpha;
-            beta_v  = &beta;
-        }
-
-        auto out_lens   = output_shape.lens();
-        rocblas_int m   = out_lens[dim_0];
-        rocblas_int n   = out_lens[dim_1];
-        rocblas_int k   = args[0].get_shape().lens()[dim_1];
-        auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
-        if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
-        {
-            MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
-        }
-        auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
-                                            out_lens.rend(),
-                                            std::size_t{1},
-                                            std::multiplies<std::size_t>());
-        if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
-        {
-            // If the batch dimension of B is broadcasted, then we can
-            // multiply m by the batch_size and use rocblas_gemm_ex
-            // instead of rocblas_gemm_strided_batched_ex.
-            m *= num_matrices;
-
-            // the rocblas_gemm API handles inputs and output matrices as
-            // column-major format. When doing a C = A * B, we actually do
-            // C^T = (B^T) * (A^T). That is the reason we input args[1] as
-            // A and args[0] as B in calling the rocblas_gemm.
-            rocblas_invoke(&rocblas_gemm_ex,
-                           ctx.get_stream().get_rocblas(),
-                           transb ? rocblas_operation_transpose : rocblas_operation_none,
-                           transa ? rocblas_operation_transpose : rocblas_operation_none,
-                           n,
-                           m,
-                           k,
-                           alpha_v,
-                           to_pointer(args.at(1)),
-                           arg_type,
-                           ldb,
-                           to_pointer(args.at(0)),
-                           arg_type,
-                           lda,
-                           beta_v,
-                           to_pointer(args[2]),
-                           output_type,
-                           ldc,
-                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
-                           output_type,
-                           ldd,
-                           compute_type,
-                           rocblas_gemm_algo_standard,
-                           0,
-                           flag);
-        }
-        else
-        {
-            auto a_stride = get_batch_stride(args[0]);
-            auto b_stride = get_batch_stride(args[1]);
-            auto c_stride = get_batch_stride(args[2]);
-            auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
-            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
-                           ctx.get_stream().get_rocblas(),
-                           transb ? rocblas_operation_transpose : rocblas_operation_none,
-                           transa ? rocblas_operation_transpose : rocblas_operation_none,
-                           n,
-                           m,
-                           k,
-                           alpha_v,
-                           to_pointer(args.at(1)),
-                           arg_type,
-                           ldb,
-                           b_stride,
-                           to_pointer(args.at(0)),
-                           arg_type,
-                           lda,
-                           a_stride,
-                           beta_v,
-                           to_pointer(args[2]),
-                           output_type,
-                           ldc,
-                           c_stride,
-                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
-                           output_type,
-                           ldd,
-                           d_stride,
-                           num_matrices,
-                           compute_type,
-                           rocblas_gemm_algo_standard,
-                           0,
-                           flag);
-        }
-    });
-}
+/**
+ * Wrapper for multiple rocBLAS calls. The constructor creates parameters for
+ * these calls based on data shapes and other values contained in the associated
+ * instruction and operation.
+ *
+ * The template parameter T is not the type of the matrix data but of the weighting
+ * coefficients alpha and beta (these are float in rocBLAS internals)
+ */
+template <typename T>
+struct gemm_impl
+{
+    gemm_impl(const shape& output_shape,
+              const std::vector<shape>& input_shapes,
+              T alpha_param,
+              T beta_param,
+              bool compute_fp32_flag)
+        : alpha(alpha_param),
+          beta(beta_param),
+          is_3inputs(input_shapes.size() == 4),
+          compute_fp32(compute_fp32_flag)
+    {
+        if(not is_3inputs)
+        {
+            beta = 0;
+        }
+
+        // Create lambdas that will cast alpha, beta to the output shape's type
+        // and retain the values being pointed to
+        output_shape.visit_type([&](auto as) {
+            auto alpha_r = as(alpha);
+            auto beta_r  = as(beta);
+            if(compute_fp32)
+            {
+                get_alpha = [=] { return &alpha; };
+                get_beta  = [=] { return &beta; };
+            }
+            else
+            {
+                get_alpha = [=] { return &alpha_r; };
+                get_beta  = [=] { return &beta_r; };
+            }
+        });
+
+        transa     = is_transposed(input_shapes[0]);
+        transb     = is_transposed(input_shapes[1]);
+        auto n_dim = output_shape.lens().size();
+        auto dim_0 = n_dim - 2;
+        auto dim_1 = n_dim - 1;
+        // Leading dimensions of matrices
+        lda = input_shapes[0].strides()[transa ? dim_1 : dim_0];
+        ldb = input_shapes[1].strides()[transb ? dim_1 : dim_0];
+        ldc = input_shapes[2].strides()[dim_0];
+        ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
+
+        arg_type    = get_type(input_shapes[0].type());
+        output_type = arg_type;
+        if(output_type == rocblas_datatype_i8_r)
+        {
+            output_type = rocblas_datatype_i32_r;
+        }
+        compute_type = output_type;
+        if(compute_fp32)
+        {
+            if(arg_type == rocblas_datatype_f16_r)
+                compute_type = rocblas_datatype_f32_r;
+        }
+
+        auto a_lens   = input_shapes[0].lens();
+        auto b_lens   = input_shapes[1].lens();
+        auto out_lens = output_shape.lens();
+        m             = out_lens[dim_0];
+        n             = out_lens[dim_1];
+        k             = input_shapes[0].lens()[dim_1];
+
+        a_stride     = get_batch_stride(input_shapes[0]);
+        b_stride     = get_batch_stride(input_shapes[1]);
+        c_stride     = get_batch_stride(input_shapes[2]);
+        d_stride     = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
+        num_matrices = std::accumulate(out_lens.rbegin() + 2,
+                                       out_lens.rend(),
+                                       std::size_t{1},
+                                       std::multiplies<std::size_t>());
+
+        strided_batched = num_matrices > 1;
+        if(strided_batched and b_stride == 0 and input_shapes[0].standard())
+        {
+            // If the batch dimension of B is broadcasted, then we can
+            // multiply m by the batch_size and use rocblas_gemm_ex
+            // instead of rocblas_gemm_strided_batched_ex.
+            m *= num_matrices;
+            strided_batched = false;
+        }
+    }
+
+    void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
+    {
+        if(strided_batched)
+        {
+            auto common_args = create_strided_batched_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
+                           common_args,
+                           rocblas_gemm_algo_solution_index,
+                           solution_idx,
+                           gemm_flags);
+        }
+        else
+        {
+            auto common_args = create_gemm_ex_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_ex,
+                           common_args,
+                           rocblas_gemm_algo_solution_index,
+                           solution_idx,
+                           gemm_flags);
+        }
+    }
+
+#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
+    auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
+    {
+        // Create dummy arguments for the shapes, and call the overloaded method
+        std::vector<argument> input_args;
+        std::transform(input_shapes.begin(),
+                       input_shapes.end(),
+                       std::back_inserter(input_args),
+                       [](const shape& x) { return to_gpu(generate_argument(x)); });
+
+        return validate(ctx, input_args, solution_idx);
+    }
+
+    /**
+     * Checks a particular solution for validity by running it with the flag
+     * rocblas_gemm_flags_check_solution_index (could be invalid if this model was
+     * tuned with a different rocBLAS version)
+     *
+     * @return Returns either solution_idx if valid, or else the default value 0
+     * if not. The default does not mean list index 0, but tells the picker
+     * to choose a solution.
+     */
+    int32_t
+    validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) const
+    {
+        rocblas_status check_valid(rocblas_status_success);
+        if(strided_batched)
+        {
+            auto common_args = create_strided_batched_args_common(ctx, input_args);
+            check_valid      = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
+                                              common_args,
+                                              rocblas_gemm_algo_solution_index,
+                                              solution_idx,
+                                              rocblas_gemm_flags_check_solution_index);
+        }
+        else
+        {
+            auto common_args = create_gemm_ex_args_common(ctx, input_args);
+            check_valid      = rocblas_invoke(&rocblas_gemm_ex,
+                                              common_args,
+                                              rocblas_gemm_algo_solution_index,
+                                              solution_idx,
+                                              rocblas_gemm_flags_check_solution_index);
+        }
+
+        if(check_valid == rocblas_status_invalid_value)
+        {
+            std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
+            return 0;
+        }
+        return solution_idx;
+    }
+#endif
+
+    /**
+     * Helper method to create that subset of a long rocBLAS argument list that is common
+     * to multiple "...strided_batched..." calls.
+     *
+     * The rocblas_gemm API handles inputs and output matrices as
+     * column-major format. When doing a C = A * B, we actually do
+     * C^T = (B^T) * (A^T). That is the reason we input args[1] as
+     * A and args[0] as B in calling the rocblas_gemm.
+     */
+    auto create_strided_batched_args_common(context& ctx, const std::vector<argument>& args) const
+    {
+        return pack(ctx.get_stream().get_rocblas(),
+                    transb ? rocblas_operation_transpose : rocblas_operation_none,
+                    transa ? rocblas_operation_transpose : rocblas_operation_none,
+                    n,
+                    m,
+                    k,
+                    get_alpha(),
+                    args[1].data(),
+                    arg_type,
+                    ldb,
+                    b_stride,
+                    args[0].data(),
+                    arg_type,
+                    lda,
+                    a_stride,
+                    get_beta(),
+                    args[2].data(),
+                    output_type,
+                    ldc,
+                    c_stride,
+                    is_3inputs ? args[3].data() : args[2].data(),
+                    output_type,
+                    ldd,
+                    d_stride,
+                    num_matrices,
+                    compute_type);
+    }
+
+    /**
+     * Helper method to create that subset of a long rocBLAS argument list that is common
+     * to multiple "gemm_ex..." calls.
+     *
+     * The rocblas_gemm API handles inputs and output matrices as
+     * column-major format. When doing a C = A * B, we actually do
+     * C^T = (B^T) * (A^T). That is the reason we input args[1] as
+     * A and args[0] as B in calling the rocblas_gemm.
+     */
+    auto create_gemm_ex_args_common(context& ctx, const std::vector<argument>& args) const
+    {
+        return pack(ctx.get_stream().get_rocblas(),
+                    transb ? rocblas_operation_transpose : rocblas_operation_none,
+                    transa ? rocblas_operation_transpose : rocblas_operation_none,
+                    n,
+                    m,
+                    k,
+                    get_alpha(),
+                    args[1].data(),
+                    arg_type,
+                    ldb,
+                    args[0].data(),
+                    arg_type,
+                    lda,
+                    get_beta(),
+                    args[2].data(),
+                    output_type,
+                    ldc,
+                    is_3inputs ? args[3].data() : args[2].data(),
+                    output_type,
+                    ldd,
+                    compute_type);
+    }
+
+#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
+    /**
+     * Find the best rocBLAS solution: get the list of solutions and try them all,
+     * returning the index of the fastest one.
+     */
+    int tune(context& ctx, const std::vector<shape>& input_shapes) const
+    {
+        // tuning meta parameters
+        const int hot_calls = 40;
+
+        std::vector<argument> input_args;
+        std::transform(input_shapes.begin(),
+                       input_shapes.end(),
+                       std::back_inserter(input_args),
+                       [](const shape& x) { return to_gpu(generate_argument(x)); });
+
+        // Get the solutions list in 2 rocBLAS steps:
+        // 1. Find out how many solutions there are and allocate the array
+        // 2. Get the solutions
+        rocblas_int list_size = 0;
+        std::vector<rocblas_int> solution_indices;
+        if(strided_batched)
+        {
+            auto common_args = create_strided_batched_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
+                           common_args,
+                           rocblas_gemm_algo_solution_index,
+                           gemm_flags,
+                           nullptr,
+                           &list_size);
+            solution_indices.resize(list_size);
+
+            auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
+                           common_sol_args,
+                           rocblas_gemm_algo_solution_index,
+                           gemm_flags,
+                           solution_indices.data(),
+                           &list_size);
+        }
+        else
+        {
+            auto common_args = create_gemm_ex_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_ex_get_solutions,
+                           common_args,
+                           rocblas_gemm_algo_solution_index,
+                           gemm_flags,
+                           nullptr,
+                           &list_size);
+            solution_indices.resize(list_size);
+
+            auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
+            rocblas_invoke(&rocblas_gemm_ex_get_solutions,
+                           common_sol_args,
+                           rocblas_gemm_algo_solution_index,
+                           gemm_flags,
+                           solution_indices.data(),
+                           &list_size);
+        }
+
+        double best_time  = std::numeric_limits<double>::max();
+        double first_time = -1;
+        // Initialize to the default solution index
+        rocblas_int best_sol = 0;
+        for(auto sol : solution_indices)
+        {
+            // Warmup: the first call to an op. may not be representative since there is
+            // more time taken initializing caches, etc. so we won't time it.
+            run(ctx, input_args, sol);
+            double host_time = time<milliseconds>([&] {
+                for([[maybe_unused]] int hc : range(hot_calls))
+                    run(ctx, input_args, sol);
+                ctx.finish();
+            });
+            host_time /= hot_calls;
+
+            // dev/evaluation only: track time for the first solution.
+            if(first_time < 0)
+                first_time = host_time;
+
+            // track current best
+            if(host_time < best_time)
+            {
+                best_sol  = sol;
+                best_time = host_time;
+            }
+        }
+        std::cout << "Winning GEMM solution: " << best_sol << " in " << best_time << " ms, beats "
+                  << first_time << " ms" << std::endl;
+        return best_sol;
+    }
+#endif
+
+    private:
+    size_t num_matrices = 0;
+    rocblas_int m       = 0;
+    rocblas_int n       = 0;
+    rocblas_int k       = 0;
+    bool transa         = false;
+    bool transb         = false;
+    T alpha             = 0;
+    T beta              = 0;
+    std::function<const void*()> get_alpha{};
+    std::function<const void*()> get_beta{};
+    rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
+    rocblas_int lda               = 0;
+    rocblas_int ldb               = 0;
+    rocblas_int ldc               = 0;
+    rocblas_int ldd               = 0;
+    rocblas_int a_stride          = 0;
+    rocblas_int b_stride          = 0;
+    rocblas_int c_stride          = 0;
+    rocblas_int d_stride          = 0;
+    rocblas_datatype compute_type = rocblas_datatype_f32_r;
+    rocblas_datatype arg_type     = rocblas_datatype_f32_r;
+    rocblas_datatype output_type  = rocblas_datatype_f32_r;
+    bool strided_batched          = true;
+    bool is_3inputs               = true;
+    bool compute_fp32             = true;
+}; // gemm_impl
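One constructor detail worth a worked example (numbers assumed for illustration): when B's batch stride is 0 (B is broadcast across the batch) and A is standard, the batched GEMM collapses into one call.

    // A: lens {8, 32, 64}, standard, so all 8 * 32 = 256 rows are contiguous.
    // B: lens {8, 64, 16} with b_stride == 0, i.e. every batch reuses one
    //    64 x 16 matrix.
    // The constructor then sets m = 8 * 32 = 256 and strided_batched = false,
    // so run() issues a single rocblas_gemm_ex call instead of 8
    // strided-batched ones.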
+void gemm_compute(context& ctx,
+                  const shape& output_shape,
+                  const std::vector<argument>& args,
+                  float alpha,
+                  float beta,
+                  bool compute_fp32,
+                  int32_t solution_idx)
+{
+    std::vector<shape> input_shapes;
+    std::transform(args.begin(),
+                   args.end(),
+                   std::back_inserter(input_shapes),
+                   [](const argument& x) { return x.get_shape(); });
+    auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
+    gemm_item.run(ctx, args, solution_idx);
+}
+
+void gemm_compute(context& ctx,
+                  const shape& output_shape,
+                  const std::vector<argument>& args,
+                  int32_t alpha,
+                  int32_t beta,
+                  bool compute_fp32,
+                  int32_t solution_idx)
+{
+    std::vector<shape> input_shapes;
+    std::transform(args.begin(),
+                   args.end(),
+                   std::back_inserter(input_shapes),
+                   [](const argument& x) { return x.get_shape(); });
+    auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
+    gemm_item.run(ctx, args, solution_idx);
+}
+
-void gemm(context& ctx,
-          const shape& output_shape,
-          const std::vector<argument>& args,
-          float alpha,
-          float beta,
-          bool int8_x4_format,
-          bool compute_fp32)
-{
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
-}
+/**
+ * Decides if the tune() or validate() method is appropriate and calls it.
+ * Return value is the chosen solution index, or 0 to let the picker choose it.
+ */
+int32_t gemm_finalize(context& ctx,
+                      const shape& output_shape,
+                      const std::vector<shape>& input_shapes,
+                      float alpha,
+                      float beta,
+                      bool compute_fp32,
+                      int32_t solution_idx)
+{
+#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
+    // This code should be called only if either the environment var.
+    // MIGRAPHX_ENABLE_GEMM_TUNING, or the option --exhaustive-tune, is set
+    if(solution_idx == 0)
+    {
+        auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
+        solution_idx   = gemm_item.tune(ctx, input_shapes);
+    }
+    else
+    {
+        // If a tuned solution index is already given, don't tune again but validate
+        // in case the data was tuned with a different rocBLAS version
+        auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
+        solution_idx   = gemm_item.validate(ctx, input_shapes, solution_idx);
+    }
+#else
+    (void)ctx, (void)output_shape, (void)input_shapes;
+    (void)alpha, (void)beta, (void)compute_fp32;
+#endif
+    return solution_idx;
+}
+
-void gemm(context& ctx,
-          const shape& output_shape,
-          const std::vector<argument>& args,
-          int32_t alpha,
-          int32_t beta,
-          bool int8_x4_format,
-          bool compute_fp32)
-{
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
-}
+int32_t gemm_finalize(context& ctx,
+                      const shape& output_shape,
+                      const std::vector<shape>& input_shapes,
+                      int32_t alpha,
+                      int32_t beta,
+                      bool compute_fp32,
+                      int32_t solution_idx)
+{
+#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
+    if(solution_idx == 0)
+    {
+        auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
+        solution_idx   = gemm_item.tune(ctx, input_shapes);
+    }
+    else
+    {
+        // If a tuned solution index is already given, don't tune again but validate
+        // in case the data was tuned with a different rocBLAS version
+        auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
+        solution_idx   = gemm_item.validate(ctx, input_shapes, solution_idx);
+    }
+#else
+    (void)ctx, (void)output_shape, (void)input_shapes;
+    (void)alpha, (void)beta, (void)compute_fp32;
+#endif
+    return solution_idx;
+}
 } // namespace gpu
...
@@ -27,6 +27,7 @@
 #include <migraphx/msgpack.hpp>
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/ranges.hpp>
+#include <array>
 #include <iostream>
 #include <cstring>
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(header_strings.begin(),
header_strings.end(),
std::back_inserter(srcs),
[&](auto& p) { return src_file{p}; });
return srcs;
}
static inline const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
inline bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
inline ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}
inline std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
inline void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}
inline void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}
inline bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
inline bool can_fold_batch(const std::vector<shape>& inputs)
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
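The two batch-folding helpers above flatten leading batch dimensions into the matrix dimensions so a batched GEMM can be handed to CK as a single 2-D problem. A standalone sketch of the folding rule under assumed dimensions (the real helpers operate on migraphx::shape):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Mirrors fold_batch_dims: a non-transposed matrix folds the batch into the
    // row count, a transposed one folds it into the trailing dimension.
    std::vector<std::size_t> fold_batch(const std::vector<std::size_t>& lens, bool transposed)
    {
        std::size_t batch = 1;
        for(std::size_t i = 0; i + 2 < lens.size(); ++i)
            batch *= lens[i];
        std::size_t m1 = lens[lens.size() - 2];
        std::size_t m2 = lens[lens.size() - 1];
        if(transposed)
            return {m1, m2 * batch};
        return {m1 * batch, m2};
    }

    int main()
    {
        assert((fold_batch({2, 64, 32}, false) == std::vector<std::size_t>{128, 32}));
        assert((fold_batch({2, 64, 32}, true) == std::vector<std::size_t>{64, 64}));
    }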
@@ -45,10 +45,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
 struct hiprtc_src_file
 {
     hiprtc_src_file() = default;
-    hiprtc_src_file(const src_file& s)
-        : path(s.path.string()), content(s.content.first, s.content.second)
-    {
-    }
+    hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
     std::string path;
     std::string content;
     template <class Self, class F>
...
@@ -42,7 +42,7 @@ struct compile_miopen
     context* ctx = nullptr;
     std::string name() const { return "gpu::compile_miopen"; }
     void apply(module& m) const;
-    std::size_t compile(operation& op, instruction_ref ins, bool format) const;
+    std::size_t compile(operation& op, instruction_ref ins) const;
 };
 } // namespace gpu
...
@@ -299,23 +299,6 @@ struct context
     any_ptr get_queue() { return get_stream().get(); }

-    void enable_perf_measurement(bool b = true)
-    {
-        if(b)
-        {
-            start_event = create_event_for_timing();
-            stop_event  = create_event_for_timing();
-            get_stream().record(start_event.get());
-            get_stream().record(stop_event.get());
-        }
-        else
-        {
-            start_event = nullptr;
-            stop_event  = nullptr;
-        }
-        measure_perf = b;
-    }
-
     std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
     {
         if(measure_perf)
@@ -323,12 +306,12 @@ struct context
         return std::make_pair(nullptr, nullptr);
     }

-    float get_elapsed_ms() const
+    static float get_elapsed_ms(hipEvent_t start, hipEvent_t stop)
     {
         float result = 0;
-        if(start_event != nullptr and stop_event != nullptr)
+        if(start != nullptr and stop != nullptr)
         {
-            auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
+            auto status = hipEventElapsedTime(&result, start, stop);
             if(status != hipSuccess)
                 MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
         }
...
@@ -57,7 +57,6 @@ template <class Op>
 struct miopen_convolution
 {
     Op op;
-    bool int8_x4_format = false;
     shared<convolution_descriptor> cd = nullptr;
     miopenConvFwdAlgorithm_t algo{};
 #ifdef MIGRAPHX_HAS_FIND_2_API
@@ -74,7 +73,6 @@ struct miopen_convolution
                          f(self.solution_object, "solution_object"),
 #endif
                          f(self.algo, "algo"),
-                         f(self.int8_x4_format, "int8_x4_format"),
                          f(self.solution_id, "solution_id"));
     }
@@ -94,9 +92,9 @@ struct miopen_convolution
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
     {
-        auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape()), int8_x4_format);
-        auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape()), int8_x4_format);
+        auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape()));
+        auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape()));
         auto y_desc = make_tensor(reshape_if_1d(output_shape));
         auto* miopen_stream_handle = ctx.get_stream().get_miopen();
         auto workspace_size        = args[2].get_shape().bytes();
@@ -162,8 +160,8 @@ struct miopen_convolution
     shape find(context& ctx, const shape& output_shape, const std::vector<shape>& inputs)
     {
         shape workspace_shape{};
-        auto x_desc = make_tensor(reshape_if_1d(inputs[0]), int8_x4_format);
-        auto w_desc = make_tensor(reshape_if_1d(inputs[1]), int8_x4_format);
+        auto x_desc = make_tensor(reshape_if_1d(inputs[0]));
+        auto w_desc = make_tensor(reshape_if_1d(inputs[1]));
         auto y_desc = make_tensor(reshape_if_1d(output_shape));
         auto* miopen_stream_handle = ctx.get_stream().get_miopen();
@@ -179,13 +177,8 @@ struct miopen_convolution
         workspace_shape = shape{shape::int8_type, {workspace_size}};
-        auto x_shape = inputs[0];
-        auto w_shape = inputs[1];
-        if(int8_x4_format)
-        {
-            x_shape = pack_int8_shape(x_shape);
-            w_shape = pack_int8_shape(w_shape);
-        }
+        const auto& x_shape = inputs[0];
+        const auto& w_shape = inputs[1];
 #ifdef MIGRAPHX_HAS_FIND_2_API
         {
@@ -199,9 +192,9 @@ struct miopen_convolution
             // MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6
             preallocate = true;
 #endif
-            auto x = preallocate ? to_gpu(generate_argument(x_shape)) : inputs[0];
-            auto w = preallocate ? to_gpu(generate_argument(w_shape)) : inputs[1];
-            auto y = preallocate ? allocate_gpu(output_shape) : inputs[2];
+            auto x = preallocate ? to_gpu(generate_argument(x_shape)) : argument{inputs[0]};
+            auto w = preallocate ? to_gpu(generate_argument(w_shape)) : argument{inputs[1]};
+            auto y = preallocate ? allocate_gpu(output_shape) : argument{inputs[2]};
             auto workspace =
                 preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape);
@@ -327,8 +320,8 @@ struct miopen_convolution
                            ": workspace has changed during finalization.");
         }
-        auto x_desc = make_tensor(reshape_if_1d(inputs[0]), int8_x4_format);
-        auto w_desc = make_tensor(reshape_if_1d(inputs[1]), int8_x4_format);
+        auto x_desc = make_tensor(reshape_if_1d(inputs[0]));
+        auto w_desc = make_tensor(reshape_if_1d(inputs[1]));
         auto y_desc = make_tensor(reshape_if_1d(output_shape));

         auto status = miopenConvolutionForwardCompileSolution(ctx.get_stream().get_miopen(),
@@ -347,21 +340,6 @@ struct miopen_convolution
     {
         return shapes.size() - 1;
     }
-
-    inline shape pack_int8_shape(const shape& s) const
-    {
-        if(s.type() != shape::int8_type)
-        {
-            return s;
-        }
-
-        auto lens    = s.lens();
-        auto strides = s.strides();
-        lens[1]      = (lens[1] + 3) / 4 * 4;
-        strides[0]   = strides[1] * lens[1];
-
-        return {s.type(), lens, strides};
-    }
 };
 } // namespace gpu
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
     return {v, i};
 }

-struct argmax_op
+struct argmax_op_first_index
 {
     template <class T>
     MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
@@ -73,7 +73,25 @@ struct argmax_op
     MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
 };

-struct argmin_op
+struct argmax_op_last_index
+{
+    template <class T>
+    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
+    {
+        if(x.val > y.val)
+            return x;
+        else if(x.val < y.val)
+            return y;
+        else
+        {
+            return (x.index > y.index) ? x : y;
+        }
+    }
+
+    MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
+};
+
+struct argmin_op_first_index
 {
     template <class T>
     MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
@@ -91,6 +109,24 @@ struct argmin_op
     MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
 };

+struct argmin_op_last_index
+{
+    template <class T>
+    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
+    {
+        if(x.val < y.val)
+            return x;
+        else if(x.val > y.val)
+            return y;
+        else
+        {
+            return (x.index > y.index) ? x : y;
+        }
+    }
+
+    MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
+};
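The only difference between the _first_index and _last_index variants is the tie-breaking branch. A host-side sketch of the resulting semantics (the real operators run as a GPU reduction over val_index pairs):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Standalone illustration of the two tie-breaking rules.
    int64_t argmax(const std::vector<int>& v, bool select_last_index)
    {
        int64_t best = 0;
        for(int64_t i = 1; i < static_cast<int64_t>(v.size()); ++i)
        {
            if(v[i] > v[best] or (select_last_index and v[i] == v[best]))
                best = i;
        }
        return best;
    }

    int main()
    {
        std::vector<int> v{1, 3, 3};
        assert(argmax(v, false) == 1); // argmax_op_first_index
        assert(argmax(v, true) == 2);  // argmax_op_last_index
    }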
 template <class Op>
 void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
 {
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -36,7 +36,8 @@ namespace device {
 void MIGRAPHX_DEVICE_EXPORT argmax(hipStream_t stream,
                                    const argument& result,
                                    const argument& arg,
-                                   int64_t axis);
+                                   int64_t axis,
+                                   bool select_last_index);
 } // namespace device
 } // namespace gpu
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -36,7 +36,8 @@ namespace device {
 void MIGRAPHX_DEVICE_EXPORT argmin(hipStream_t stream,
                                    const argument& result,
                                    const argument& arg,
-                                   int64_t axis);
+                                   int64_t axis,
+                                   bool select_last_index);
 } // namespace device
 } // namespace gpu
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void MIGRAPHX_DEVICE_EXPORT int8_gemm_pack_a(hipStream_t stream,
const argument& result,
const argument& arg);
void MIGRAPHX_DEVICE_EXPORT int8_gemm_pack_b(hipStream_t stream,
const argument& result,
const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -38,6 +38,7 @@ MIGRAPHX_GPU_EXPORT bool mlir_enabled();
 struct MIGRAPHX_GPU_EXPORT fuse_mlir
 {
     context* ctx = nullptr;
+    bool enable_extra = false;
     std::string name() const { return "gpu::fuse_mlir"; }
     void apply(module_pass_manager& mpm) const;
 };
...
@@ -24,7 +24,6 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP
 #define MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP

-#include <migraphx/config.hpp>
 #include <migraphx/gpu/context.hpp>

 namespace migraphx {
@@ -34,7 +33,7 @@ struct module;
 namespace gpu {

-struct fuse_ops
+struct MIGRAPHX_GPU_EXPORT fuse_ops
 {
     context* ctx   = nullptr;
     bool fast_math = true;
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

 struct context;

-void blas_shape(const shape& s);
 shape transpose_batch(const shape& s, unsigned trans_batch);
+void blas_shape(const shape& s);

 template <class Op>
 struct rocblas_gemm
@@ -50,9 +49,9 @@ struct rocblas_gemm
     Op op;
     float alpha          = 1;
     float beta           = 0;
-    bool int8_x4_format  = true;
     bool compute_fp32    = false;
     unsigned trans_batch = 0;
+    int32_t solution_idx = 0;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -60,9 +59,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format"),
                               f(self.compute_fp32, "compute_fp32"),
-                              f(self.trans_batch, "trans_batch")));
+                              f(self.trans_batch, "trans_batch"),
+                              f(self.solution_idx, "solution_idx")));
     }

     std::string name() const
@@ -78,6 +77,8 @@ struct rocblas_gemm
     {
         std::vector<shape> in_shapes(inputs);
         in_shapes.pop_back();
+        // When the input shapes are A, B, C the GEMM equation is C = α * AB + β * C,
+        // where α and β are scalars
         check_shapes{in_shapes, *this}.has(2, 3);
         blas_shape(inputs[0]);
         blas_shape(inputs[1]);
@@ -113,17 +114,12 @@ struct rocblas_gemm
     {
         if(this->name() == "gpu::gemm")
         {
-            gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
+            gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
         }
         else
         {
-            gemm(ctx,
-                 output_shape,
-                 args,
-                 int32_t(alpha),
-                 int32_t(beta),
-                 int8_x4_format,
-                 compute_fp32);
+            gemm_compute(
+                ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32, solution_idx);
         }
         return args.back();
     }
@@ -132,6 +128,33 @@ struct rocblas_gemm
     {
         return shapes.size() - 1;
     }
+
+    void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes)
+    {
+#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
+        if(enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag())
+        {
+            if(this->name() == "gpu::gemm")
+            {
+                solution_idx = gemm_finalize(
+                    ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
+            }
+            else
+            {
+                solution_idx = gemm_finalize(ctx,
+                                             output_shape,
+                                             input_shapes,
+                                             int32_t(alpha),
+                                             int32_t(beta),
+                                             compute_fp32,
+                                             solution_idx);
+            }
+        }
+#else
+        // suppress compiler warnings
+        (void)ctx, (void)output_shape, (void)input_shapes;
+#endif
+    }
 };
 } // namespace gpu
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -24,28 +24,64 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
 #define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP

+#include <iterator>
 #include <migraphx/shape.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/gpu/context.hpp>

+// Set this environment variable to "true" to perform GEMM tuning even when the
+// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
+
+using milliseconds = std::chrono::duration<double, std::milli>;
+using microseconds = std::chrono::duration<double, std::micro>;
+
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-void gemm(context& ctx,
-          const shape& output_shape,
-          const std::vector<argument>& args,
-          float alpha,
-          float beta,
-          bool int8_x4_format,
-          bool compute_fp32);
-void gemm(context& ctx,
-          const shape& output_shape,
-          const std::vector<argument>& args,
-          int32_t alpha,
-          int32_t beta,
-          bool int8_x4_format,
-          bool compute_fp32);
+/**
+ * @brief Templated implementations of the compute() and finalize() methods of the
+ * Gemm operator. Each function has overloads using either float or int32_t for
+ * the arguments alpha and beta.
+ */
+void gemm_compute(context& ctx,
+                  const shape& output_shape,
+                  const std::vector<argument>& args,
+                  float alpha,
+                  float beta,
+                  bool compute_fp32,
+                  int32_t solution_idx);
+void gemm_compute(context& ctx,
+                  const shape& output_shape,
+                  const std::vector<argument>& args,
+                  int32_t alpha,
+                  int32_t beta,
+                  bool compute_fp32,
+                  int32_t solution_idx);
+
+int32_t gemm_finalize(context& ctx,
+                      const shape& output_shape,
+                      const std::vector<shape>& input_shapes,
+                      float alpha,
+                      float beta,
+                      bool compute_fp32,
+                      int32_t solution_idx);
+int32_t gemm_finalize(context& ctx,
+                      const shape& output_shape,
+                      const std::vector<shape>& input_shapes,
+                      int32_t alpha,
+                      int32_t beta,
+                      bool compute_fp32,
+                      int32_t solution_idx);
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -21,32 +21,55 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
-#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
+#ifndef MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
+#define MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP

-#include <migraphx/argument.hpp>
-#include <migraphx/config.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/check_shapes.hpp>
+#include <utility>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-struct context;
-
-struct miopen_int8_conv_pack
+struct gemm_softmax_gemm
 {
-    std::string name() const { return "gpu::int8_conv_pack"; }
-    shape compute_shape(const std::vector<shape>& inputs) const;
-    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
-    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
-    {
-        return shapes.size() - 1;
-    }
+    operation op = make_op("dot");
+    float scale  = 1.0;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.op, "op"), f(self.scale, "scale"));
+    }
+
+    std::string name() const { return "gpu::gemm_softmax_gemm"; }
+
+    void check_gemm_shape(const shape& s) const
+    {
+        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
+            MIGRAPHX_THROW("Invalid shape for " + name());
+    }
+
+    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>&) const
+    {
+        check_shapes{inputs, *this}.same_ndims();
+        if(inputs.size() < 3)
+            MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
+        auto a  = inputs[0];
+        auto b  = inputs[1];
+        auto b1 = inputs[2];
+        for(const auto& input : inputs)
+        {
+            check_gemm_shape(input);
+        }
+        return op.compute_shape({op.compute_shape({a, b}), b1});
+    }
+
+    static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
 };

 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
-#endif
+#endif // MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
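A worked shape example for the compute_shape chaining above (dimensions assumed): with a = {b, m, k}, b = {b, k, n}, and b1 = {b, n, o}, the inner op.compute_shape({a, b}) yields {b, m, n}, and chaining with b1 gives the fused operator's output {b, m, o}; the softmax between the two GEMMs leaves the shape unchanged, which is why it does not appear in the shape computation.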
@@ -127,7 +127,7 @@ inline void set_tensor_descriptor(miopenTensorArgumentId_t name,
 }
 #endif

-inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = false)
+inline tensor_descriptor make_tensor(const migraphx::shape& os)
 {
     auto s = os.normalize_standard();
     auto t = make_obj<tensor_descriptor>(&miopenCreateTensorDescriptor);
@@ -142,23 +142,9 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = fals
     else if(s.type() == shape::int32_type)
         d = miopenInt32;
     else if(s.type() == shape::int8_type)
-    {
-        if(pack)
-        {
-            // update the lens and corresponding strides
-            d          = miopenInt8x4;
-            lens[1]    = ((lens[1] + 3) / 4) * 4;
-            strides[0] = strides[1] * lens[1];
-        }
-        else
-        {
-            d = miopenInt8;
-        }
-    }
+        d = miopenInt8;
     else
-    {
         MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
-    }

     miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
     return t;
...
@@ -24,7 +24,7 @@
 #ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
 #define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP

-#include <migraphx/config.hpp>
+#include <migraphx/gpu/config.hpp>
 #include <string>

 namespace migraphx {
@@ -34,7 +34,7 @@ struct module_pass_manager;
 namespace gpu {

-struct prefuse_ops
+struct MIGRAPHX_GPU_EXPORT prefuse_ops
 {
     std::string name() const { return "gpu::prefuse_ops"; }
     void apply(module_pass_manager& mpm) const;
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,8 +40,6 @@ struct context;

 MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();

-MIGRAPHX_GPU_EXPORT bool get_int8_x4_format(context& ctx);
-
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...