Commit 82f98478 authored by Umang Yadav's avatar Umang Yadav
Browse files

add comments

parent cf91c2b1
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h> #include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
...@@ -36,6 +37,20 @@ namespace migraphx { ...@@ -36,6 +37,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
/*
The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3" BETA
API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that implicitly
casts the integer enum value to the required type so that it can be used inside the `common_args` generator.
*/
struct rb_compute_type
{
int type = 0;
rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};
// Convert rocBLAS datatypes to equivalent Migraphx data types // Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type) rocblas_datatype get_type(shape::type_t type)
{ {
...@@ -185,12 +200,17 @@ struct gemm_impl ...@@ -185,12 +200,17 @@ struct gemm_impl
{ {
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
} }
compute_type = output_type; compute_type = rb_compute_type{output_type};
if(compute_fp32) if(compute_fp32)
{ {
if(arg_type == rocblas_datatype_f16_r) if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
else if(arg_type == rocblas_datatype_f8_r)
{
assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
compute_type = rocblas_compute_type_f32;
}
auto a_lens = input_shapes[0].lens(); auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens(); auto b_lens = input_shapes[1].lens();
...@@ -230,7 +250,6 @@ struct gemm_impl ...@@ -230,7 +250,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3, rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args, common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
...@@ -240,7 +259,6 @@ struct gemm_impl ...@@ -240,7 +259,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3, rocblas_invoke(&rocblas_gemm_ex3,
common_args, common_args,
rocblas_compute_type_f32,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
...@@ -254,7 +272,6 @@ struct gemm_impl ...@@ -254,7 +272,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex, rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
...@@ -264,7 +281,6 @@ struct gemm_impl ...@@ -264,7 +281,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex, rocblas_invoke(&rocblas_gemm_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
gemm_flags); gemm_flags);
...@@ -304,7 +320,6 @@ struct gemm_impl ...@@ -304,7 +320,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex, check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
rocblas_gemm_flags_check_solution_index); rocblas_gemm_flags_check_solution_index);
...@@ -314,7 +329,6 @@ struct gemm_impl ...@@ -314,7 +329,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex, check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
solution_idx, solution_idx,
rocblas_gemm_flags_check_solution_index); rocblas_gemm_flags_check_solution_index);
...@@ -365,7 +379,8 @@ struct gemm_impl ...@@ -365,7 +379,8 @@ struct gemm_impl
output_type, output_type,
ldd, ldd,
d_stride, d_stride,
num_matrices); num_matrices,
compute_type);
} }
/** /**
* Helper method to create that subset of a long rocBLAS argument list that is common * Helper method to create that subset of a long rocBLAS argument list that is common
...@@ -398,7 +413,8 @@ struct gemm_impl ...@@ -398,7 +413,8 @@ struct gemm_impl
ldc, ldc,
is_3inputs ? args[3].data() : args[2].data(), is_3inputs ? args[3].data() : args[2].data(),
output_type, output_type,
ldd); ldd,
compute_type);
} }
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
...@@ -428,7 +444,6 @@ struct gemm_impl ...@@ -428,7 +444,6 @@ struct gemm_impl
auto common_args = create_strided_batched_args_common(ctx, input_args); auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
nullptr, nullptr,
...@@ -438,7 +453,6 @@ struct gemm_impl ...@@ -438,7 +453,6 @@ struct gemm_impl
auto common_sol_args = create_strided_batched_args_common(ctx, input_args); auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args, common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
solution_indices.data(), solution_indices.data(),
...@@ -449,7 +463,6 @@ struct gemm_impl ...@@ -449,7 +463,6 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions, rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args, common_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
nullptr, nullptr,
...@@ -459,7 +472,6 @@ struct gemm_impl ...@@ -459,7 +472,6 @@ struct gemm_impl
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args); auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions, rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args, common_sol_args,
compute_type,
rocblas_gemm_algo_solution_index, rocblas_gemm_algo_solution_index,
gemm_flags, gemm_flags,
solution_indices.data(), solution_indices.data(),
...@@ -521,7 +533,7 @@ struct gemm_impl ...@@ -521,7 +533,7 @@ struct gemm_impl
rocblas_int c_stride = 0; rocblas_int c_stride = 0;
rocblas_int d_stride = 0; rocblas_int d_stride = 0;
rocblas_datatype arg_type = rocblas_datatype_f32_r; rocblas_datatype arg_type = rocblas_datatype_f32_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r; rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r; rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true; bool strided_batched = true;
bool is_3inputs = true; bool is_3inputs = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment