Commit 345dd037 authored by Shucai Xiao

Merge branch 'rocblas_api_opt' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into print_matmul_perf_flops
parents 8d6f2370 d45bd3ba
@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
                const std::vector<argument>& args,
                T alpha,
                T beta,
-               bool int8_x4_format)
+               bool int8_x4_format,
+               bool compute_fp32)
 {
     bool transa = args[0].get_shape().transposed();
     bool transb = args[1].get_shape().transposed();
@@ -65,6 +66,11 @@ void gemm_impl(context& ctx,
         output_type = rocblas_datatype_i32_r;
     }
     auto compute_type = output_type;
+    if(compute_fp32)
+    {
+        if(arg_type == rocblas_datatype_f16_r)
+            compute_type = rocblas_datatype_f32_r;
+    }
 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag =
@@ -79,6 +85,7 @@ void gemm_impl(context& ctx,
     output_shape.visit_type([&](auto as) {
         auto alpha_r = as(alpha);
         auto beta_r = as(beta);
         auto out_lens = output_shape.lens();
         rocblas_int m = out_lens[dim_0];
         rocblas_int n = out_lens[dim_1];
@@ -97,6 +104,34 @@ void gemm_impl(context& ctx,
             // column-major format. When doing a C = A * B, we actually do
             // C^T = (B^T) * (A^T). That is the reason we input args[1] as
             // A and args[0] as B in calling the rocblas_gemm.
+            if(compute_fp32)
+                rocblas_invoke(&rocblas_gemm_ex,
+                               ctx.get_stream().get_rocblas(),
+                               transb ? rocblas_operation_transpose : rocblas_operation_none,
+                               transa ? rocblas_operation_transpose : rocblas_operation_none,
+                               n,
+                               m,
+                               k,
+                               &alpha,
+                               to_pointer(args.at(1)),
+                               arg_type,
+                               ldb,
+                               to_pointer(args.at(0)),
+                               arg_type,
+                               lda,
+                               &beta,
+                               to_pointer(args[2]),
+                               output_type,
+                               ldc,
+                               is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
+                               output_type,
+                               ldc,
+                               compute_type,
+                               rocblas_gemm_algo_standard,
+                               0,
+                               flag);
+            else
             rocblas_invoke(&rocblas_gemm_ex,
                            ctx.get_stream().get_rocblas(),
                            transb ? rocblas_operation_transpose : rocblas_operation_none,
@@ -125,6 +160,38 @@ void gemm_impl(context& ctx,
         }
         else
         {
+            if(compute_fp32)
+                rocblas_invoke(&rocblas_gemm_strided_batched_ex,
+                               ctx.get_stream().get_rocblas(),
+                               transb ? rocblas_operation_transpose : rocblas_operation_none,
+                               transa ? rocblas_operation_transpose : rocblas_operation_none,
+                               n,
+                               m,
+                               k,
+                               &alpha,
+                               to_pointer(args.at(1)),
+                               arg_type,
+                               ldb,
+                               k * n,
+                               to_pointer(args.at(0)),
+                               arg_type,
+                               lda,
+                               m * k,
+                               &beta,
+                               to_pointer(args[2]),
+                               output_type,
+                               ldc,
+                               m * n,
+                               is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
+                               output_type,
+                               ldc,
+                               m * n,
+                               num_matrices,
+                               compute_type,
+                               rocblas_gemm_algo_standard,
+                               0,
+                               flag);
+            else
             rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                            ctx.get_stream().get_rocblas(),
                            transb ? rocblas_operation_transpose : rocblas_operation_none,
@@ -164,9 +231,10 @@ void gemm(context& ctx,
          const std::vector<argument>& args,
          float alpha,
          float beta,
-         bool int8_x4_format)
+         bool int8_x4_format,
+         bool compute_fp32)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
 }
 
 void gemm(context& ctx,
@@ -174,9 +242,10 @@ void gemm(context& ctx,
          const std::vector<argument>& args,
          int32_t alpha,
          int32_t beta,
-         bool int8_x4_format)
+         bool int8_x4_format,
+         bool compute_fp32)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
 }
 
 } // namespace gpu
...
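
The comment retained in the hunk above explains the operand swap: rocBLAS expects column-major matrices while MIGraphX buffers are row-major, so a row-major C = A * B is computed as a column-major C^T = (B^T) * (A^T), with args[1] passed as A, args[0] passed as B, and m/n exchanged. Below is a minimal CPU sketch of that identity (plain C++, no rocBLAS; the helper names are illustrative, not MIGraphX code):

// Sketch only: shows how a column-major GEMM routine yields a row-major
// C = A * B when asked for C^T = B^T * A^T (operands swapped, m/n exchanged).
#include <cassert>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n); leading dimensions
// are the row counts of each column-major matrix.
void gemm_col_major(int m, int n, int k, const float* a, int lda,
                    const float* b, int ldb, float* c, int ldc)
{
    for(int j = 0; j < n; ++j)
        for(int i = 0; i < m; ++i)
        {
            float acc = 0.0f;
            for(int p = 0; p < k; ++p)
                acc += a[i + p * lda] * b[p + j * ldb]; // column-major indexing
            c[i + j * ldc] = acc;
        }
}

// Row-major C = A * B through the column-major routine: a row-major M x N
// buffer is the same memory as a column-major N x M buffer, so pass B first,
// A second, and swap m and n.
void gemm_row_major(int m, int n, int k, const float* a, const float* b, float* c)
{
    gemm_col_major(n, m, k, b, n, a, k, c, n);
}

int main()
{
    // 2x3 * 3x2 = 2x2, all buffers row-major.
    std::vector<float> a = {1, 2, 3, 4, 5, 6};
    std::vector<float> b = {7, 8, 9, 10, 11, 12};
    std::vector<float> c(4, 0.0f);
    gemm_row_major(2, 2, 3, a.data(), b.data(), c.data());
    assert(c[0] == 58 && c[1] == 64 && c[2] == 139 && c[3] == 154);
}
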
@@ -25,6 +25,7 @@ struct rocblas_gemm
     float alpha         = 1;
     float beta          = 0;
     bool int8_x4_format = true;
+    bool compute_fp32   = false;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -80,11 +81,17 @@ struct rocblas_gemm
     {
         if(this->name() == "gpu::gemm")
         {
-            gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
+            gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
         }
         else
         {
-            gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format);
+            gemm(ctx,
+                 output_shape,
+                 args,
+                 int32_t(alpha),
+                 int32_t(beta),
+                 int8_x4_format,
+                 compute_fp32);
         }
         return args.back();
     }
...
@@ -14,13 +14,15 @@ void gemm(context& ctx,
          const std::vector<argument>& args,
          float alpha,
          float beta,
-         bool int8_x4_format);
+         bool int8_x4_format,
+         bool compute_fp32);
 
 void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          int32_t alpha,
          int32_t beta,
-         bool int8_x4_format);
+         bool int8_x4_format,
+         bool compute_fp32);
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -24,6 +24,7 @@
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/convolution.hpp>
 #include <migraphx/gpu/deconvolution.hpp>
+#include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/elu.hpp>
 #include <migraphx/gpu/equal.hpp>
 #include <migraphx/gpu/gemm.hpp>
@@ -60,6 +61,7 @@ struct miopen_apply
     std::unordered_map<instruction_ref, std::string> prog_output_names{};
     bool offload_copy   = false;
     bool int8_x4_format = true;
+    bool compute_fp32   = false;
 
     context& get_context() const
     {
@@ -96,6 +98,12 @@ struct miopen_apply
         }
     }
 
+    const std::unordered_set<std::string>& get_rocblas_fp32_archs()
+    {
+        static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
+        return supported_archs;
+    }
+
     void init()
     {
         assert(mod != nullptr);
@@ -103,6 +111,9 @@ struct miopen_apply
 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
         auto& ctx = get_context();
+        const auto device_name = trim(split_string(get_device_name(), ':').front());
+        if(contains(get_rocblas_fp32_archs(), device_name))
+            compute_fp32 = true;
         rocblas_gemm_flags flag;
         rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
         int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
@@ -337,7 +348,7 @@ struct miopen_apply
             }
         }
         return mod->replace_instruction(
-            ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format}, refs);
+            ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
     });
 }
...
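
The init() change above gates compute_fp32 on the GPU architecture: the device name reported by the runtime can carry feature suffixes (for example "gfx90a:sramecc+:xnack-"), so only the token before the first ':' is compared against the architectures where fp32 accumulation for fp16 GEMMs is enabled. Below is a standalone sketch of that check; the function is an illustrative stand-in for the MIGraphX helpers trim/split_string/get_device_name, not the library code itself:

// Sketch only: decide whether to promote fp16 GEMMs to fp32 accumulation
// based on the architecture token of a HIP device name.
#include <cassert>
#include <string>
#include <unordered_set>

bool use_fp32_compute(const std::string& full_device_name)
{
    static const std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
    // Keep only the architecture token before any ':'-separated feature flags.
    const std::string arch = full_device_name.substr(0, full_device_name.find(':'));
    return supported_archs.count(arch) > 0;
}

int main()
{
    assert(use_fp32_compute("gfx908:sramecc+:xnack-"));
    assert(use_fp32_compute("gfx90a"));
    assert(!use_fp32_compute("gfx1030"));
}
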