Commit a82ea1d6 authored by Brian Pickrell
Browse files

Correction to the logic that uses the environment variable MIGRAPHX_ENABLE_GEMM_TUNING to trigger GEMM tuning

parent b49f5599
...@@ -32,10 +32,6 @@ ...@@ -32,10 +32,6 @@
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/time.hpp> #include <migraphx/time.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
using microseconds = std::chrono::duration<double, std::micro>; using microseconds = std::chrono::duration<double, std::micro>;
#if ROCBLAS_VERSION_MAJOR > 2 or (ROCBLAS_VERSION_MAJOR == 2 and ROCBLAS_VERSION_MINOR >= 38) #if ROCBLAS_VERSION_MAJOR > 2 or (ROCBLAS_VERSION_MAJOR == 2 and ROCBLAS_VERSION_MINOR >= 38)
...@@ -141,7 +137,7 @@ static rocblas_int get_batch_stride(const argument& a) ...@@ -141,7 +137,7 @@ static rocblas_int get_batch_stride(const argument& a)
* these calls based on data shapes and other values contained in the associated * these calls based on data shapes and other values contained in the associated
* instruction and operation. * instruction and operation.
* *
* The template parameter T is not the type of the input data but of the weighting * The template parameter T is not the type of the matrix data but of the weighting
* coefficients alpha and beta (these are float in rocBLAS internals) * coefficients alpha and beta (these are float in rocBLAS internals)
*/ */
template <typename T> template <typename T>
...@@ -163,7 +159,7 @@ struct gemm_impl ...@@ -163,7 +159,7 @@ struct gemm_impl
beta = 0; beta = 0;
} }
// Create lambdas that will cast alpha, beta to the output shape type // Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to // and retain the values being pointed to
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha); auto alpha_r = as(alpha);
...@@ -471,20 +467,19 @@ struct gemm_impl ...@@ -471,20 +467,19 @@ struct gemm_impl
// and increasing cold_calls makes little or no difference. Why? // and increasing cold_calls makes little or no difference. Why?
host_time /= hot_calls; host_time /= hot_calls;
// debug only: track time for first solution. // dev/evaluation only: track time for first solution.
if(first_time < 0) if(first_time < 0)
first_time = host_time; first_time = host_time;
// track current best // track current best
if(host_time < best_time) if(host_time < best_time)
{ {
std::cout << " current best index " << sol << ", time " << host_time << std::endl;
best_sol = sol; best_sol = sol;
best_time = host_time; best_time = host_time;
} }
} }
std::cout << "Winner: " << best_sol << " in " << best_time << " us, beats " << first_time std::cout << "Winning GEMM solution: " << best_sol << " in " << best_time << " us, beats "
<< std::endl; << first_time << std::endl;
return best_sol; return best_sol;
} }
#endif #endif
...@@ -541,6 +536,10 @@ void gemm_compute(context& ctx, ...@@ -541,6 +536,10 @@ void gemm_compute(context& ctx,
gemm_item.run(ctx, args, solution_idx); gemm_item.run(ctx, args, solution_idx);
} }
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx, int32_t gemm_finalize(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<shape>& input_shapes, const std::vector<shape>& input_shapes,
...@@ -551,14 +550,17 @@ int32_t gemm_finalize(context& ctx, ...@@ -551,14 +550,17 @@ int32_t gemm_finalize(context& ctx,
int32_t solution_idx) int32_t solution_idx)
{ {
#ifdef ROCBLAS_BETA_FEATURES_API #ifdef ROCBLAS_BETA_FEATURES_API
if((enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag()) and
solution_idx == 0) // This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{ {
auto gemm_item = auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes); solution_idx = gemm_item.tune(ctx, input_shapes);
} }
else if(solution_idx != 0) else
{ {
// If a tuned solution index is already given, don't tune again but validate // If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version // in case the data was tuned with a different rocBLAS version
...@@ -576,7 +578,7 @@ int32_t gemm_finalize(context& ctx, ...@@ -576,7 +578,7 @@ int32_t gemm_finalize(context& ctx,
/** /**
* Decides if the tune() or validate() method is appropriate and calls it. * Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index. * Return value is the chosen solution index, or 0 to let picker choose it.
*/ */
int32_t gemm_finalize(context& ctx, int32_t gemm_finalize(context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -589,14 +591,15 @@ int32_t gemm_finalize(context& ctx, ...@@ -589,14 +591,15 @@ int32_t gemm_finalize(context& ctx,
{ {
#ifdef ROCBLAS_BETA_FEATURES_API #ifdef ROCBLAS_BETA_FEATURES_API
if((enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag()) and // This code should be called only if either the environment var.
solution_idx == 0) // MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{ {
auto gemm_item = gemm_impl<int32_t>( auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes); solution_idx = gemm_item.tune(ctx, input_shapes);
} }
else if(solution_idx != 0) else
{ {
// If a tuned solution index is already given, don't tune again but validate // If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version // in case the data was tuned with a different rocBLAS version
......
...@@ -139,7 +139,7 @@ struct rocblas_gemm ...@@ -139,7 +139,7 @@ struct rocblas_gemm
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes) void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes)
{ {
#ifdef ROCBLAS_BETA_FEATURES_API #ifdef ROCBLAS_BETA_FEATURES_API
if(ctx.get_exhaustive_tune_flag() && solution_idx == 0) if(enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag())
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
......
...@@ -33,6 +33,10 @@ ...@@ -33,6 +33,10 @@
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/time.hpp> #include <migraphx/time.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
using milliseconds = std::chrono::duration<double, std::milli>; using milliseconds = std::chrono::duration<double, std::milli>;
using microseconds = std::chrono::duration<double, std::micro>; using microseconds = std::chrono::duration<double, std::micro>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment