Commit 82f98478 authored by Umang Yadav's avatar Umang Yadav
Browse files

add comments

parent cf91c2b1
......@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
......@@ -36,6 +37,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast interger enum
value to required type that can be used inside `common_args` generator.
*/
struct rb_compute_type
{
int type = 0;
rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
......@@ -185,12 +200,17 @@ struct gemm_impl
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
else if(arg_type == rocblas_datatype_f8_r)
{
assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
compute_type = rocblas_compute_type_f32;
}
auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
......@@ -230,7 +250,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
......@@ -240,7 +259,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
......@@ -254,7 +272,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
......@@ -264,7 +281,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
......@@ -304,7 +320,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
......@@ -314,7 +329,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
......@@ -365,7 +379,8 @@ struct gemm_impl
output_type,
ldd,
d_stride,
num_matrices);
num_matrices,
compute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
......@@ -398,7 +413,8 @@ struct gemm_impl
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd);
ldd,
compute_type);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
......@@ -428,7 +444,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
......@@ -438,7 +453,6 @@ struct gemm_impl
auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
......@@ -449,7 +463,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
......@@ -459,7 +472,6 @@ struct gemm_impl
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
......@@ -521,7 +533,7 @@ struct gemm_impl
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment