Commit c75fb295 authored by Paul's avatar Paul
Browse files

Format

parent e69b4a33
......@@ -237,8 +237,11 @@ struct gemm_impl
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(
&rocblas_gemm_ex, common_args, rocblas_gemm_algo_standard, solution_idx, gemm_flags);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
}
......@@ -480,7 +483,7 @@ struct gemm_impl
std::function<const void*()> get_alpha{};
std::function<const void*()> get_beta{};
rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
rocblas_int lda = 0;
rocblas_int ldb = 0;
rocblas_int ldc = 0;
......@@ -510,8 +513,7 @@ void gemm_compute(context& ctx,
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
......@@ -528,8 +530,7 @@ void gemm_compute(context& ctx,
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item =
gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
......@@ -552,16 +553,14 @@ int32_t gemm_finalize(context& ctx,
if(solution_idx == 0)
{
auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
......@@ -590,16 +589,14 @@ int32_t gemm_finalize(context& ctx,
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, compute_fp32);
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, compute_fp32);
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
......
......@@ -114,18 +114,12 @@ struct rocblas_gemm
{
if(this->name() == "gpu::gemm")
{
gemm_compute(
ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
else
{
gemm_compute(ctx,
output_shape,
args,
int32_t(alpha),
int32_t(beta),
compute_fp32,
solution_idx);
gemm_compute(
ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32, solution_idx);
}
return args.back();
}
......@@ -142,13 +136,8 @@ struct rocblas_gemm
{
if(this->name() == "gpu::gemm")
{
solution_idx = gemm_finalize(ctx,
output_shape,
input_shapes,
alpha,
beta,
compute_fp32,
solution_idx);
solution_idx = gemm_finalize(
ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment