Commit a82ea1d6 authored by Brian Pickrell
Browse files

Correct the logic that uses the environment variable MIGRAPHX_ENABLE_GEMM_TUNING to trigger GEMM tuning

parent b49f5599
......@@ -32,10 +32,6 @@
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/time.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
using microseconds = std::chrono::duration<double, std::micro>;
#if ROCBLAS_VERSION_MAJOR > 2 or (ROCBLAS_VERSION_MAJOR == 2 and ROCBLAS_VERSION_MINOR >= 38)
......@@ -141,7 +137,7 @@ static rocblas_int get_batch_stride(const argument& a)
* these calls based on data shapes and other values contained in the associated
* instruction and operation.
*
* The template parameter T is not the type of the input data but of the weighting
* The template parameter T is not the type of the matrix data but of the weighting
* coefficients alpha and beta (these are float in rocBLAS internals)
*/
template <typename T>
......@@ -163,7 +159,7 @@ struct gemm_impl
beta = 0;
}
// Create lambdas that will cast alpha, beta to the output shape type
// Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
......@@ -471,20 +467,19 @@ struct gemm_impl
// and increasing cold_calls makes little or no difference. Why?
host_time /= hot_calls;
// debug only: track time for first solution.
// dev/evaluation only: track time for first solution.
if(first_time < 0)
first_time = host_time;
// track current best
if(host_time < best_time)
{
std::cout << " current best index " << sol << ", time " << host_time << std::endl;
best_sol = sol;
best_time = host_time;
}
}
std::cout << "Winner: " << best_sol << " in " << best_time << " us, beats " << first_time
<< std::endl;
std::cout << "Winning GEMM solution: " << best_sol << " in " << best_time << " us, beats "
<< first_time << std::endl;
return best_sol;
}
#endif
......@@ -541,6 +536,10 @@ void gemm_compute(context& ctx,
gemm_item.run(ctx, args, solution_idx);
}
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
......@@ -551,14 +550,17 @@ int32_t gemm_finalize(context& ctx,
int32_t solution_idx)
{
#ifdef ROCBLAS_BETA_FEATURES_API
if((enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag()) and
solution_idx == 0)
// This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{
auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else if(solution_idx != 0)
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
......@@ -576,7 +578,7 @@ int32_t gemm_finalize(context& ctx,
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
......@@ -589,14 +591,15 @@ int32_t gemm_finalize(context& ctx,
{
#ifdef ROCBLAS_BETA_FEATURES_API
if((enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag()) and
solution_idx == 0)
// This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else if(solution_idx != 0)
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
......
......@@ -139,7 +139,7 @@ struct rocblas_gemm
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes)
{
#ifdef ROCBLAS_BETA_FEATURES_API
if(ctx.get_exhaustive_tune_flag() && solution_idx == 0)
if(enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag())
{
if(this->name() == "gpu::gemm")
{
......
......@@ -33,6 +33,10 @@
#include <migraphx/gpu/hip.hpp>
#include <migraphx/time.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
using milliseconds = std::chrono::duration<double, std::milli>;
using microseconds = std::chrono::duration<double, std::micro>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment