Unverified Commit ed6542ee authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

update rocBLAS version check to support 3.0 and above (#1716)

Update the rocBLAS version check to support rocBLAS 3.0 and above, using simplified logic.
parent b4cba0b8
......@@ -140,13 +140,8 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
int flag = 0;
#endif
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
......
......@@ -55,24 +55,15 @@ const std::unordered_set<std::string>& get_rocblas_fp32_archs()
bool get_compute_fp32_flag()
{
bool compute_fp32 = false;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
#endif
return compute_fp32;
return contains(get_rocblas_fp32_archs(), device_name);
}
/// Asks rocBLAS which int8 layout flag it reports for the given context.
///
/// @param ctx context whose stream supplies the rocBLAS handle
///            (ctx.get_stream().get_rocblas()).
/// @return true when rocblas_query_int8_layout_flag reports
///         rocblas_gemm_flags_pack_int8x4, false otherwise.
bool get_int8_x4_format(context& ctx)
{
    // No rocBLAS version guard here: the former condition
    // `ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38` is false for
    // 3.0 (minor == 0), which skipped the query and hard-coded `true`.
    //
    // Value-initialize so the comparison below never reads an indeterminate
    // value should the query leave `flag` untouched.
    // NOTE(review): the status returned by the query is still ignored, as before.
    rocblas_gemm_flags flag{};
    rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
    return flag == rocblas_gemm_flags_pack_int8x4;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment