Commit c75fb295 authored by Paul
Browse files

Format

parent e69b4a33
...@@ -237,8 +237,11 @@ struct gemm_impl ...@@ -237,8 +237,11 @@ struct gemm_impl
else else
{ {
auto common_args = create_gemm_ex_args_common(ctx, input_args); auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke( rocblas_invoke(&rocblas_gemm_ex,
&rocblas_gemm_ex, common_args, rocblas_gemm_algo_standard, solution_idx, gemm_flags); common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
} }
} }
...@@ -510,8 +513,7 @@ void gemm_compute(context& ctx, ...@@ -510,8 +513,7 @@ void gemm_compute(context& ctx,
args.end(), args.end(),
std::back_inserter(input_shapes), std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); }); [](const argument& x) { return x.get_shape(); });
auto gemm_item = auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx); gemm_item.run(ctx, args, solution_idx);
} }
...@@ -528,8 +530,7 @@ void gemm_compute(context& ctx, ...@@ -528,8 +530,7 @@ void gemm_compute(context& ctx,
args.end(), args.end(),
std::back_inserter(input_shapes), std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); }); [](const argument& x) { return x.get_shape(); });
auto gemm_item = auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx); gemm_item.run(ctx, args, solution_idx);
} }
...@@ -552,16 +553,14 @@ int32_t gemm_finalize(context& ctx, ...@@ -552,16 +553,14 @@ int32_t gemm_finalize(context& ctx,
if(solution_idx == 0) if(solution_idx == 0)
{ {
auto gemm_item = auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes); solution_idx = gemm_item.tune(ctx, input_shapes);
} }
else else
{ {
// If a tuned solution index is already given, don't tune again but validate // If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version // in case the data was tuned with a different rocBLAS version
auto gemm_item = auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx); solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
} }
#else #else
...@@ -590,16 +589,14 @@ int32_t gemm_finalize(context& ctx, ...@@ -590,16 +589,14 @@ int32_t gemm_finalize(context& ctx,
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set // MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0) if(solution_idx == 0)
{ {
auto gemm_item = gemm_impl<int32_t>( auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes); solution_idx = gemm_item.tune(ctx, input_shapes);
} }
else else
{ {
// If a tuned solution index is already given, don't tune again but validate // If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version // in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<int32_t>( auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx); solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
} }
#else #else
......
...@@ -114,18 +114,12 @@ struct rocblas_gemm ...@@ -114,18 +114,12 @@ struct rocblas_gemm
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
gemm_compute( gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
} }
else else
{ {
gemm_compute(ctx, gemm_compute(
output_shape, ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32, solution_idx);
args,
int32_t(alpha),
int32_t(beta),
compute_fp32,
solution_idx);
} }
return args.back(); return args.back();
} }
...@@ -142,13 +136,8 @@ struct rocblas_gemm ...@@ -142,13 +136,8 @@ struct rocblas_gemm
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
solution_idx = gemm_finalize(ctx, solution_idx = gemm_finalize(
output_shape, ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
input_shapes,
alpha,
beta,
compute_fp32,
solution_idx);
} }
else else
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment