"vscode:/vscode.git/clone" did not exist on "8ff6431c844ba4971889ddbc0e73d802a05bff4c"
Commit 345dd037 authored by Shucai Xiao

Merge branch 'rocblas_api_opt' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into print_matmul_perf_flops
parents 8d6f2370 d45bd3ba
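
In summary, the merged changes add a compute_fp32 option to the rocBLAS GEMM path: lowering detects the device architecture once in miopen_apply::init(), sets compute_fp32 for gfx908 and gfx90a, and threads the flag through the rocblas_gemm operator into gemm_impl, where fp16 GEMMs are then asked to accumulate in fp32 (rocblas_datatype_f32_r) instead of the fp16 output type.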
@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
                const std::vector<argument>& args,
                T alpha,
                T beta,
-               bool int8_x4_format)
+               bool int8_x4_format,
+               bool compute_fp32)
 {
     bool transa = args[0].get_shape().transposed();
     bool transb = args[1].get_shape().transposed();
@@ -65,6 +66,11 @@ void gemm_impl(context& ctx,
         output_type = rocblas_datatype_i32_r;
     }
     auto compute_type = output_type;
+    if(compute_fp32)
+    {
+        if(arg_type == rocblas_datatype_f16_r)
+            compute_type = rocblas_datatype_f32_r;
+    }
 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag =
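
The compute-type selection above is small enough to restate as a standalone helper; this is an illustrative sketch (the function name is made up), not code from this commit:

```cpp
#include <rocblas.h>

// Sketch of the selection logic in the hunk above: accumulate in the
// output type by default, but widen fp16 GEMMs to an fp32 accumulator
// when compute_fp32 is requested.
rocblas_datatype pick_compute_type(rocblas_datatype arg_type,
                                   rocblas_datatype output_type,
                                   bool compute_fp32)
{
    rocblas_datatype compute_type = output_type;
    if(compute_fp32 && arg_type == rocblas_datatype_f16_r)
        compute_type = rocblas_datatype_f32_r;
    return compute_type;
}
```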
@@ -79,6 +85,7 @@ void gemm_impl(context& ctx,
     output_shape.visit_type([&](auto as) {
         auto alpha_r = as(alpha);
         auto beta_r  = as(beta);
+        auto out_lens = output_shape.lens();
         rocblas_int m = out_lens[dim_0];
         rocblas_int n = out_lens[dim_1];
@@ -97,6 +104,34 @@ void gemm_impl(context& ctx,
         // column-major format. When doing a C = A * B, we actually do
         // C^T = (B^T) * (A^T). That is the reason we input args[1] as
         // A and args[0] as B in calling the rocblas_gemm.
+        if(compute_fp32)
+            rocblas_invoke(&rocblas_gemm_ex,
+                           ctx.get_stream().get_rocblas(),
+                           transb ? rocblas_operation_transpose : rocblas_operation_none,
+                           transa ? rocblas_operation_transpose : rocblas_operation_none,
+                           n,
+                           m,
+                           k,
+                           &alpha,
+                           to_pointer(args.at(1)),
+                           arg_type,
+                           ldb,
+                           to_pointer(args.at(0)),
+                           arg_type,
+                           lda,
+                           &beta,
+                           to_pointer(args[2]),
+                           output_type,
+                           ldc,
+                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
+                           output_type,
+                           ldc,
+                           compute_type,
+                           rocblas_gemm_algo_standard,
+                           0,
+                           flag);
+        else
             rocblas_invoke(&rocblas_gemm_ex,
                            ctx.get_stream().get_rocblas(),
                            transb ? rocblas_operation_transpose : rocblas_operation_none,
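
The branch is duplicated rather than parameterized because rocblas_gemm_ex takes alpha and beta in the compute type's precision: the new branch passes the float values directly (&alpha, &beta), whereas the pre-existing call (truncated in this view) passes the alpha_r/beta_r values converted through the type visitor above. A minimal self-contained sketch of the fp16-in/fp32-accumulate call shape, outside MIGraphX (the function name and buffer names are illustrative):

```cpp
#include <rocblas.h>

// Standalone sketch (not MIGraphX code): C = A * B with fp16 inputs and
// outputs but fp32 accumulation. d_a, d_b, d_c are assumed to be valid
// device pointers to column-major fp16 matrices of the right sizes.
rocblas_status gemm_f16_with_f32_accum(rocblas_handle handle,
                                       rocblas_int m,
                                       rocblas_int n,
                                       rocblas_int k,
                                       const void* d_a,
                                       const void* d_b,
                                       void* d_c)
{
    // With compute_type = f32, alpha and beta are given as float even
    // though all matrix element types are fp16.
    float alpha = 1.0f;
    float beta  = 0.0f;
    return rocblas_gemm_ex(handle,
                           rocblas_operation_none,
                           rocblas_operation_none,
                           m, n, k,
                           &alpha,
                           d_a, rocblas_datatype_f16_r, m, // lda (column-major)
                           d_b, rocblas_datatype_f16_r, k, // ldb
                           &beta,
                           d_c, rocblas_datatype_f16_r, m, // C read with ldc
                           d_c, rocblas_datatype_f16_r, m, // D written in place
                           rocblas_datatype_f32_r,         // fp32 accumulator
                           rocblas_gemm_algo_standard,
                           0,  // solution_index
                           0); // flags
}
```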
@@ -125,6 +160,38 @@ void gemm_impl(context& ctx,
         }
         else
         {
+            if(compute_fp32)
+                rocblas_invoke(&rocblas_gemm_strided_batched_ex,
+                               ctx.get_stream().get_rocblas(),
+                               transb ? rocblas_operation_transpose : rocblas_operation_none,
+                               transa ? rocblas_operation_transpose : rocblas_operation_none,
+                               n,
+                               m,
+                               k,
+                               &alpha,
+                               to_pointer(args.at(1)),
+                               arg_type,
+                               ldb,
+                               k * n,
+                               to_pointer(args.at(0)),
+                               arg_type,
+                               lda,
+                               m * k,
+                               &beta,
+                               to_pointer(args[2]),
+                               output_type,
+                               ldc,
+                               m * n,
+                               is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
+                               output_type,
+                               ldc,
+                               m * n,
+                               num_matrices,
+                               compute_type,
+                               rocblas_gemm_algo_standard,
+                               0,
+                               flag);
+            else
                 rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                                ctx.get_stream().get_rocblas(),
                                transb ? rocblas_operation_transpose : rocblas_operation_none,
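
The batch strides follow the same transpose trick described in the comment earlier: args[1] (the k x n matrix B, passed to rocBLAS as "A") advances by k * n elements per batch item, args[0] (the m x k matrix A, passed as "B") by m * k, and the output by m * n, so consecutive matrices are densely packed. As a sketch of that arithmetic:

```cpp
// Illustrative stride computation for a densely packed GEMM batch,
// matching the literal k * n / m * k / m * n arguments above.
rocblas_stride stride_a = static_cast<rocblas_stride>(k) * n; // args[1]: k x n per matrix
rocblas_stride stride_b = static_cast<rocblas_stride>(m) * k; // args[0]: m x k per matrix
rocblas_stride stride_c = static_cast<rocblas_stride>(m) * n; // output:  m x n per matrix
```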
@@ -164,9 +231,10 @@ void gemm(context& ctx,
           const std::vector<argument>& args,
           float alpha,
           float beta,
-          bool int8_x4_format)
+          bool int8_x4_format,
+          bool compute_fp32)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
 }
 
 void gemm(context& ctx,
@@ -174,9 +242,10 @@ void gemm(context& ctx,
           const std::vector<argument>& args,
           int32_t alpha,
           int32_t beta,
-          bool int8_x4_format)
+          bool int8_x4_format,
+          bool compute_fp32)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
 }
 
 } // namespace gpu
......
@@ -25,6 +25,7 @@ struct rocblas_gemm
     float alpha         = 1;
     float beta          = 0;
     bool int8_x4_format = true;
+    bool compute_fp32   = false;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -80,11 +81,17 @@ struct rocblas_gemm
     {
         if(this->name() == "gpu::gemm")
         {
-            gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
+            gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
         }
         else
         {
-            gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format);
+            gemm(ctx,
+                 output_shape,
+                 args,
+                 int32_t(alpha),
+                 int32_t(beta),
+                 int8_x4_format,
+                 compute_fp32);
         }
         return args.back();
     }
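
For compute_fp32 to survive serialization and appear in perf output it also has to be registered in the operator's reflect(). That part of the change is not shown in this hunk; a sketch of what the registration would typically look like under MIGraphX's reflect pattern (the exact body in the source may differ):

```cpp
template <class Self, class F>
static auto reflect(Self& self, F f)
{
    return pack(f(self.op, "op"),
                f(self.alpha, "alpha"),
                f(self.beta, "beta"),
                f(self.int8_x4_format, "int8_x4_format"),
                f(self.compute_fp32, "compute_fp32"));
}
```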
......
@@ -14,13 +14,15 @@ void gemm(context& ctx,
           const std::vector<argument>& args,
           float alpha,
           float beta,
-          bool int8_x4_format);
+          bool int8_x4_format,
+          bool compute_fp32);
 
 void gemm(context& ctx,
           const shape& output_shape,
           const std::vector<argument>& args,
           int32_t alpha,
           int32_t beta,
-          bool int8_x4_format);
+          bool int8_x4_format,
+          bool compute_fp32);
 
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
......
@@ -24,6 +24,7 @@
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/convolution.hpp>
 #include <migraphx/gpu/deconvolution.hpp>
+#include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/elu.hpp>
 #include <migraphx/gpu/equal.hpp>
 #include <migraphx/gpu/gemm.hpp>
@@ -60,6 +61,7 @@ struct miopen_apply
     std::unordered_map<instruction_ref, std::string> prog_output_names{};
     bool offload_copy   = false;
     bool int8_x4_format = true;
+    bool compute_fp32   = false;
 
     context& get_context() const
     {
@@ -96,6 +98,12 @@ struct miopen_apply
         }
     }
 
+    const std::unordered_set<std::string>& get_rocblas_fp32_archs()
+    {
+        static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
+        return supported_archs;
+    }
+
     void init()
     {
         assert(mod != nullptr);
@@ -103,6 +111,9 @@
 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
         auto& ctx              = get_context();
         const auto device_name = trim(split_string(get_device_name(), ':').front());
+        if(contains(get_rocblas_fp32_archs(), device_name))
+            compute_fp32 = true;
+
         rocblas_gemm_flags flag;
         rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
         int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
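
get_device_name() can return a name with feature suffixes (for example "gfx90a:sramecc+:xnack-"), which is why only the portion before the first ':' is compared. A standalone sketch of the same check (the function name is hypothetical):

```cpp
#include <string>

// Hypothetical standalone version of the arch check above: strip feature
// suffixes such as ":sramecc+" / ":xnack-" and compare the bare gfx name.
bool arch_supports_fp32_accum(const std::string& device_name)
{
    const auto arch = device_name.substr(0, device_name.find(':'));
    return arch == "gfx908" || arch == "gfx90a";
}
```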
@@ -337,7 +348,7 @@ struct miopen_apply
             }
         }
 
         return mod->replace_instruction(
-            ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format}, refs);
+            ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
     });
 }
......
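
One last note on the lowering hunk: the operator is constructed with alpha = 1 and beta = 0, so the rocBLAS call computes a plain product, and compute_fp32 only widens the accumulator; argument and output types and shapes are unchanged, which is why the diff touches no shape computation.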