Unverified Commit 36b01ba5 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Use fp32 compute_type when calling rocBLAS API (#1085)

Using an fp32 compute type for fp16 GEMMs gives better performance when calling the rocBLAS API this way.
parent 832f28c6
...@@ -42,7 +42,8 @@ void gemm_impl(context& ctx, ...@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
T alpha, T alpha,
T beta, T beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
...@@ -65,6 +66,11 @@ void gemm_impl(context& ctx, ...@@ -65,6 +66,11 @@ void gemm_impl(context& ctx,
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
} }
auto compute_type = output_type; auto compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag = rocblas_gemm_flags flag =
...@@ -77,8 +83,19 @@ void gemm_impl(context& ctx, ...@@ -77,8 +83,19 @@ void gemm_impl(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha); auto alpha_r = as(alpha);
auto beta_r = as(beta); auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
}
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
...@@ -104,14 +121,14 @@ void gemm_impl(context& ctx, ...@@ -104,14 +121,14 @@ void gemm_impl(context& ctx,
n, n,
m, m,
k, k,
&alpha_r, alpha_v,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
&beta_r, beta_v,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
...@@ -132,7 +149,7 @@ void gemm_impl(context& ctx, ...@@ -132,7 +149,7 @@ void gemm_impl(context& ctx,
n, n,
m, m,
k, k,
&alpha_r, alpha_v,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
...@@ -141,7 +158,7 @@ void gemm_impl(context& ctx, ...@@ -141,7 +158,7 @@ void gemm_impl(context& ctx,
arg_type, arg_type,
lda, lda,
m * k, m * k,
&beta_r, beta_v,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
...@@ -164,9 +181,10 @@ void gemm(context& ctx, ...@@ -164,9 +181,10 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
void gemm(context& ctx, void gemm(context& ctx,
...@@ -174,9 +192,10 @@ void gemm(context& ctx, ...@@ -174,9 +192,10 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
} // namespace gpu } // namespace gpu
......
...@@ -25,6 +25,7 @@ struct rocblas_gemm ...@@ -25,6 +25,7 @@ struct rocblas_gemm
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -80,11 +81,17 @@ struct rocblas_gemm ...@@ -80,11 +81,17 @@ struct rocblas_gemm
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
else else
{ {
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format); gemm(ctx,
output_shape,
args,
int32_t(alpha),
int32_t(beta),
int8_x4_format,
compute_fp32);
} }
return args.back(); return args.back();
} }
......
...@@ -14,13 +14,15 @@ void gemm(context& ctx, ...@@ -14,13 +14,15 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format); bool int8_x4_format,
bool compute_fp32);
void gemm(context& ctx, void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format); bool int8_x4_format,
bool compute_fp32);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp> #include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp> #include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp> #include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
...@@ -61,6 +62,7 @@ struct miopen_apply ...@@ -61,6 +62,7 @@ struct miopen_apply
std::unordered_map<instruction_ref, std::string> prog_output_names{}; std::unordered_map<instruction_ref, std::string> prog_output_names{};
bool offload_copy = false; bool offload_copy = false;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false;
context& get_context() const context& get_context() const
{ {
...@@ -97,13 +99,22 @@ struct miopen_apply ...@@ -97,13 +99,22 @@ struct miopen_apply
} }
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
void init() void init()
{ {
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto& ctx = get_context(); auto& ctx = get_context();
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag; rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag); rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4); int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
...@@ -339,7 +350,7 @@ struct miopen_apply ...@@ -339,7 +350,7 @@ struct miopen_apply
} }
} }
return mod->replace_instruction( return mod->replace_instruction(
ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format}, refs); ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment