Commit 3eaadd61 authored by letaoqin's avatar letaoqin
Browse files

first

parent a0ae1c61
...@@ -17,3 +17,12 @@ struct StreamConfig ...@@ -17,3 +17,12 @@ struct StreamConfig
bool flush_cache = false; bool flush_cache = false;
int rotating_count = 1; int rotating_count = 1;
}; };
struct GemmConfig
{
int tile_m = 1;
int tile_n = 1;
int split_k = 1;
int stages = 1;
std::string op_name = "";
};
...@@ -33,7 +33,7 @@ struct BaseInvoker ...@@ -33,7 +33,7 @@ struct BaseInvoker
{ {
return float{0}; return float{0};
} }
virtual int GetOccupancy(const BaseArgument*) { return 1; }
virtual ~BaseInvoker() {} virtual ~BaseInvoker() {}
}; };
...@@ -67,6 +67,8 @@ struct BaseOperator ...@@ -67,6 +67,8 @@ struct BaseOperator
p_arg->p_workspace_ = p_workspace; p_arg->p_workspace_ = p_workspace;
} }
//virtual int GetOccupancy() { return 1; }
virtual GemmConfig GetConfig() { return GemmConfig{1, 1, 1, 1, ""}; }
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
......
...@@ -146,7 +146,27 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -146,7 +146,27 @@ bool profile_gemm_universal_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0; float best_kbatch = 0;
int best_occupancy = 0;
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hipError_t rtn;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
int num_cu = dev_prop.multiProcessorCount;
float config_score = 1;
int config_waves = INT_MAX;
int current_tile_m = 0;
int current_occupancy = 0;
float current_tflops = 0;
GemmConfig best_config;
ignore = config_score;
ignore = config_waves;
ignore = current_tile_m;
ignore = best_config;
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
...@@ -157,6 +177,10 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -157,6 +177,10 @@ bool profile_gemm_universal_impl(int do_verification,
kbatch_list = {KBatch}; kbatch_list = {KBatch};
} }
auto candidate_config = op_ptr->GetConfig();
int num_tile_m = (M + candidate_config.tile_m - 1) / candidate_config.tile_m;
int num_tile_n = (N + candidate_config.tile_n - 1) / candidate_config.tile_n;
for(std::size_t i = 0; i < kbatch_list.size(); i++) for(std::size_t i = 0; i < kbatch_list.size(); i++)
{ {
auto kbatch_curr = kbatch_list[i]; auto kbatch_curr = kbatch_list[i];
...@@ -180,7 +204,22 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -180,7 +204,22 @@ bool profile_gemm_universal_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
int occupancy = invoker_ptr->GetOccupancy(argument_ptr.get());
if(occupancy == 0)
continue;
int ctas_per_wave = occupancy * num_cu;
int ctas_for_problem = num_tile_m * num_tile_n * kbatch_curr;
const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
const float current_score = float(num_waves_total) - num_waves_fractional;
std::cout << "tile_m: " << num_tile_m << " tile_n: " << num_tile_n
<< " occupancy: " << occupancy << " current_score:" << current_score
<< " ctas_per_wave: " << ctas_per_wave
<< " ctas_for_problem: " << ctas_for_problem
<< " num_waves_total: " << num_waves_total
<< " num_waves_fractional: " << num_waves_fractional
<< " kbatch_curr: " << kbatch_curr << std::endl;
// re-init C to zero before profiling next kernel // re-init C to zero before profiling next kernel
c_device_buf.SetZero(); c_device_buf.SetZero();
...@@ -227,7 +266,8 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -227,7 +266,8 @@ bool profile_gemm_universal_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << " TFlops, " << gb_per_sec << " GB/s, "
<< ", occupancy: " << occupancy << " " << op_name << ", KBatch "
<< kbatch_curr << std::endl; << kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
...@@ -256,6 +296,42 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -256,6 +296,42 @@ bool profile_gemm_universal_impl(int do_verification,
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr; best_kbatch = kbatch_curr;
best_occupancy = occupancy;
}
if(num_waves_total > 1 && num_waves_total < 10)
{
if((current_score < config_score) ||
((config_waves > num_waves_total) && (current_score < config_score + 0.1f)))
{
best_config.tile_m = candidate_config.tile_m;
best_config.tile_n = candidate_config.tile_n;
best_config.stages = candidate_config.stages;
best_config.split_k = kbatch_curr;
best_config.op_name = op_name;
config_score = current_score;
current_tile_m = candidate_config.tile_m;
config_waves = num_waves_total;
current_occupancy = occupancy;
current_tflops = tflops;
}
// else if(abs(current_score - config_score) < 0.001f &&
// (best_config.stages < candidate_config.stages ||
// kbatch_curr < best_config.split_k ||
// current_tile_m < candidate_config.tile_m))
// {
// best_config.tile_m = candidate_config.tile_m;
// best_config.tile_n = candidate_config.tile_n;
// best_config.stages = candidate_config.stages;
// best_config.split_k = kbatch_curr;
// best_config.op_name = op_name;
// current_tile_m = candidate_config.tile_m;
// config_waves = num_waves_total;
// current_occupancy = occupancy;
// current_tflops = tflops;
// }
} }
} }
else else
...@@ -303,8 +379,14 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -303,8 +379,14 @@ bool profile_gemm_universal_impl(int do_verification,
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " : "
<< " GB/s, " << best_op_name << std::endl; << " occupancy: " << best_occupancy << " " << best_ave_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
std::cout << "tile_m: " << best_config.tile_m << " tile_n: " << best_config.tile_n
<< " split_k: " << best_config.split_k << " stages: " << best_config.stages
<< ", config_score: " << config_score << ", tflops: " << current_tflops
<< ", current_occupancy: " << current_occupancy << " name: " << best_config.op_name
<< ", KBatch " << best_config.split_k << std::endl;
return pass; return pass;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment