"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "67048d04a8096fc89b9e24452a705a8a59bf5acc"
Commit b49f5599 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Moved the environment var. that can run one kind of tuning without the other...

Moved the environment var. that can run one kind of tuning without the other from convolution.hpp to gemm_impl.cpp.  Also misc. style cleanup.
parent a00d8b5f
...@@ -32,9 +32,13 @@ ...@@ -32,9 +32,13 @@
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/time.hpp> #include <migraphx/time.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_GEMM_TUNING);
using microseconds = std::chrono::duration<double, std::micro>; using microseconds = std::chrono::duration<double, std::micro>;
#if ROCBLAS_VERSION_MAJOR > 2 || (ROCBLAS_VERSION_MAJOR == 2 && ROCBLAS_VERSION_MINOR >= 38) #if ROCBLAS_VERSION_MAJOR > 2 or (ROCBLAS_VERSION_MAJOR == 2 and ROCBLAS_VERSION_MINOR >= 38)
using flag_type = rocblas_gemm_flags; using flag_type = rocblas_gemm_flags;
#else #else
using flag_type = int; using flag_type = int;
...@@ -44,6 +48,7 @@ namespace migraphx { ...@@ -44,6 +48,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type) rocblas_datatype get_type(shape::type_t type)
{ {
switch(type) switch(type)
...@@ -252,6 +257,7 @@ struct gemm_impl ...@@ -252,6 +257,7 @@ struct gemm_impl
#ifdef ROCBLAS_BETA_FEATURES_API #ifdef ROCBLAS_BETA_FEATURES_API
auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
{ {
// Create dummy arguments for the shapes, and call the overloaded method
std::vector<argument> input_args; std::vector<argument> input_args;
std::transform(input_shapes.begin(), std::transform(input_shapes.begin(),
input_shapes.end(), input_shapes.end(),
...@@ -385,8 +391,7 @@ struct gemm_impl ...@@ -385,8 +391,7 @@ struct gemm_impl
int tune(context& ctx, const std::vector<shape>& input_shapes) const int tune(context& ctx, const std::vector<shape>& input_shapes) const
{ {
// tuning meta parameters // tuning meta parameters
const int hot_calls = 40; const int hot_calls = 40;
const int cold_calls = 1;
std::vector<argument> input_args; std::vector<argument> input_args;
std::transform(input_shapes.begin(), std::transform(input_shapes.begin(),
...@@ -439,10 +444,10 @@ struct gemm_impl ...@@ -439,10 +444,10 @@ struct gemm_impl
&list_size); &list_size);
} }
double bestTime = std::numeric_limits<double>::max(); double best_time = std::numeric_limits<double>::max();
double first_time = -1; double first_time = -1;
// Initialize to default solution index // Initialize to default solution index
rocblas_int bestSol = 0; rocblas_int best_sol = 0;
for(auto sol : solution_indices) for(auto sol : solution_indices)
{ {
// Define the function to be timed // Define the function to be timed
...@@ -451,12 +456,9 @@ struct gemm_impl ...@@ -451,12 +456,9 @@ struct gemm_impl
ctx.finish(); ctx.finish();
}; };
// Warmup: the first few calls to an op. may not be representative since there is // Warmup: the first call to an op. may not be representative since there is
// more time taken initializing caches, etc. so we won't time them. // more time taken initializing caches, etc. so we won't time it.
for(int cc = 0; cc < cold_calls; ++cc) run_func();
{
run_func();
}
double host_time = 0.0; double host_time = 0.0;
for(int hc = 0; hc < hot_calls; ++hc) for(int hc = 0; hc < hot_calls; ++hc)
...@@ -474,16 +476,16 @@ struct gemm_impl ...@@ -474,16 +476,16 @@ struct gemm_impl
first_time = host_time; first_time = host_time;
// track current best // track current best
if(host_time < bestTime) if(host_time < best_time)
{ {
std::cout << " current best index " << sol << ", time " << host_time << std::endl; std::cout << " current best index " << sol << ", time " << host_time << std::endl;
bestSol = sol; best_sol = sol;
bestTime = host_time; best_time = host_time;
} }
} }
std::cout << "Winner: " << bestSol << " in " << bestTime << " us, beats " << first_time std::cout << "Winner: " << best_sol << " in " << best_time << " us, beats " << first_time
<< std::endl; << std::endl;
return bestSol; return best_sol;
} }
#endif #endif
private: private:
...@@ -549,9 +551,8 @@ int32_t gemm_finalize(context& ctx, ...@@ -549,9 +551,8 @@ int32_t gemm_finalize(context& ctx,
int32_t solution_idx) int32_t solution_idx)
{ {
#ifdef ROCBLAS_BETA_FEATURES_API #ifdef ROCBLAS_BETA_FEATURES_API
if((enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag()) and
if(ctx.get_exhaustive_tune_flag() && solution_idx == 0) solution_idx == 0)
// if((true))
{ {
auto gemm_item = auto gemm_item =
gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); gemm_impl<float>(output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
...@@ -573,6 +574,10 @@ int32_t gemm_finalize(context& ctx, ...@@ -573,6 +574,10 @@ int32_t gemm_finalize(context& ctx,
return solution_idx; return solution_idx;
} }
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index.
*/
int32_t gemm_finalize(context& ctx, int32_t gemm_finalize(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<shape>& input_shapes, const std::vector<shape>& input_shapes,
...@@ -584,7 +589,8 @@ int32_t gemm_finalize(context& ctx, ...@@ -584,7 +589,8 @@ int32_t gemm_finalize(context& ctx,
{ {
#ifdef ROCBLAS_BETA_FEATURES_API #ifdef ROCBLAS_BETA_FEATURES_API
if(ctx.get_exhaustive_tune_flag() && solution_idx == 0) if((enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag()) and
solution_idx == 0)
{ {
auto gemm_item = gemm_impl<int32_t>( auto gemm_item = gemm_impl<int32_t>(
output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32); output_shape, input_shapes, alpha, beta, int8_x4_format, compute_fp32);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
...@@ -36,9 +35,6 @@ ...@@ -36,9 +35,6 @@
#include <unordered_map> #include <unordered_map>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CONV_TUNING);
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
...@@ -167,20 +163,6 @@ struct miopen_convolution ...@@ -167,20 +163,6 @@ struct miopen_convolution
auto w_desc = make_tensor(reshape_if_1d(inputs[1]), int8_x4_format); auto w_desc = make_tensor(reshape_if_1d(inputs[1]), int8_x4_format);
auto y_desc = make_tensor(reshape_if_1d(output_shape)); auto y_desc = make_tensor(reshape_if_1d(output_shape));
std::size_t workspace_size = 0; std::size_t workspace_size = 0;
// TODO: Using an environment variable to disable convolution tuning while still allowing
// other (i.e. GEMM) tuning with --exhaustive-tune, is a workaround for a problem in which
// convolution tuning takes multiple hours.
const bool convo_tune_enabled = enabled(MIGRAPHX_ENABLE_CONV_TUNING{});
if(ctx.get_exhaustive_tune_flag() and not convo_tune_enabled)
{
std::cerr
<< "WARNING: MIGraphX convolution tuning not enabled. To run tuning, set both the "
"env "
"var MIGRAPHX_ENABLE_CONV_TUNING and the command argument --exhaustive-tune."
<< std::endl;
}
#ifdef MIGRAPHX_HAS_FIND_2_API #ifdef MIGRAPHX_HAS_FIND_2_API
{ {
auto conv_problem = make_obj<miopen_problem>( auto conv_problem = make_obj<miopen_problem>(
...@@ -192,10 +174,9 @@ struct miopen_convolution ...@@ -192,10 +174,9 @@ struct miopen_convolution
auto* miopen_stream_handle = ctx.get_stream().get_miopen(); auto* miopen_stream_handle = ctx.get_stream().get_miopen();
solution_ptr = find_solution(miopen_stream_handle, solution_ptr = find_solution(
conv_problem.get(), miopen_stream_handle, conv_problem.get(), ctx.get_exhaustive_tune_flag());
ctx.get_exhaustive_tune_flag() and convo_tune_enabled); auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen" + op.name() + " : failed to get solution's workspace size"); MIGRAPHX_THROW("MIOpen" + op.name() + " : failed to get solution's workspace size");
...@@ -252,8 +233,7 @@ struct miopen_convolution ...@@ -252,8 +233,7 @@ struct miopen_convolution
&perf, &perf,
workspace.implicit(), workspace.implicit(),
workspace_size, workspace_size,
ctx.get_exhaustive_tune_flag() and ctx.get_exhaustive_tune_flag());
convo_tune_enabled);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed"); MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo; algo = perf.fwd_algo;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment