Commit 3eaadd61 authored by letaoqin's avatar letaoqin
Browse files

first

parent a0ae1c61
...@@ -17,3 +17,12 @@ struct StreamConfig ...@@ -17,3 +17,12 @@ struct StreamConfig
bool flush_cache = false; bool flush_cache = false;
int rotating_count = 1; int rotating_count = 1;
}; };
struct GemmConfig
{
int tile_m = 1;
int tile_n = 1;
int split_k = 1;
int stages = 1;
std::string op_name = "";
};
...@@ -33,7 +33,7 @@ struct BaseInvoker ...@@ -33,7 +33,7 @@ struct BaseInvoker
{ {
return float{0}; return float{0};
} }
virtual int GetOccupancy(const BaseArgument*) { return 1; }
virtual ~BaseInvoker() {} virtual ~BaseInvoker() {}
}; };
...@@ -67,6 +67,8 @@ struct BaseOperator ...@@ -67,6 +67,8 @@ struct BaseOperator
p_arg->p_workspace_ = p_workspace; p_arg->p_workspace_ = p_workspace;
} }
//virtual int GetOccupancy() { return 1; }
virtual GemmConfig GetConfig() { return GemmConfig{1, 1, 1, 1, ""}; }
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
......
...@@ -129,6 +129,430 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -129,6 +129,430 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
int GetOccupancy(const BaseArgument* p_arg) override
{
int occupancy = 0;
auto arg = *dynamic_cast<const Argument*>(p_arg);
ignore = arg;
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
hipError_t rtn;
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
hip_check_error(rtn);
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return occupancy > 0 ?occupancy : 1;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
...@@ -741,6 +1165,49 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -741,6 +1165,49 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return str.str(); return str.str();
} }
// static int GetOccupancy2()
// {
// int occupancy = 1;
// constexpr index_t minimum_occupancy =
// BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
// true,
// InMemoryDataOperationEnum::Set,
// minimum_occupancy>;
// hipError_t rtn;
// rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
// &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
// hip_check_error(rtn);
// return ++occupancy;
// }
// int GetOccupancy() override
// {
// int occupancy = 3;
// // constexpr index_t minimum_occupancy =
// // BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// // const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
// // true,
// // InMemoryDataOperationEnum::Set,
// // minimum_occupancy>;
// // hipError_t rtn;
// // rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
// // &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
// // hip_check_error(rtn);
// return ++occupancy;
// }
GemmConfig GetConfig() override
{
return GemmConfig{MPerBlock,
NPerBlock,
1,
GridwiseGemm::BlockwiseGemmPipe::PrefetchStages,
GetTypeString()};
}
}; };
} // namespace device } // namespace device
......
...@@ -146,7 +146,27 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -146,7 +146,27 @@ bool profile_gemm_universal_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0; float best_kbatch = 0;
int best_occupancy = 0;
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hipError_t rtn;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
int num_cu = dev_prop.multiProcessorCount;
float config_score = 1;
int config_waves = INT_MAX;
int current_tile_m = 0;
int current_occupancy = 0;
float current_tflops = 0;
GemmConfig best_config;
ignore = config_score;
ignore = config_waves;
ignore = current_tile_m;
ignore = best_config;
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
...@@ -157,6 +177,10 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -157,6 +177,10 @@ bool profile_gemm_universal_impl(int do_verification,
kbatch_list = {KBatch}; kbatch_list = {KBatch};
} }
auto candidate_config = op_ptr->GetConfig();
int num_tile_m = (M + candidate_config.tile_m - 1) / candidate_config.tile_m;
int num_tile_n = (N + candidate_config.tile_n - 1) / candidate_config.tile_n;
for(std::size_t i = 0; i < kbatch_list.size(); i++) for(std::size_t i = 0; i < kbatch_list.size(); i++)
{ {
auto kbatch_curr = kbatch_list[i]; auto kbatch_curr = kbatch_list[i];
...@@ -180,7 +204,22 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -180,7 +204,22 @@ bool profile_gemm_universal_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
int occupancy = invoker_ptr->GetOccupancy(argument_ptr.get());
if(occupancy == 0)
continue;
int ctas_per_wave = occupancy * num_cu;
int ctas_for_problem = num_tile_m * num_tile_n * kbatch_curr;
const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
const float current_score = float(num_waves_total) - num_waves_fractional;
std::cout << "tile_m: " << num_tile_m << " tile_n: " << num_tile_n
<< " occupancy: " << occupancy << " current_score:" << current_score
<< " ctas_per_wave: " << ctas_per_wave
<< " ctas_for_problem: " << ctas_for_problem
<< " num_waves_total: " << num_waves_total
<< " num_waves_fractional: " << num_waves_fractional
<< " kbatch_curr: " << kbatch_curr << std::endl;
// re-init C to zero before profiling next kernel // re-init C to zero before profiling next kernel
c_device_buf.SetZero(); c_device_buf.SetZero();
...@@ -227,7 +266,8 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -227,7 +266,8 @@ bool profile_gemm_universal_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << " TFlops, " << gb_per_sec << " GB/s, "
<< ", occupancy: " << occupancy << " " << op_name << ", KBatch "
<< kbatch_curr << std::endl; << kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
...@@ -256,6 +296,42 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -256,6 +296,42 @@ bool profile_gemm_universal_impl(int do_verification,
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr; best_kbatch = kbatch_curr;
best_occupancy = occupancy;
}
if(num_waves_total > 1 && num_waves_total < 10)
{
if((current_score < config_score) ||
((config_waves > num_waves_total) && (current_score < config_score + 0.1f)))
{
best_config.tile_m = candidate_config.tile_m;
best_config.tile_n = candidate_config.tile_n;
best_config.stages = candidate_config.stages;
best_config.split_k = kbatch_curr;
best_config.op_name = op_name;
config_score = current_score;
current_tile_m = candidate_config.tile_m;
config_waves = num_waves_total;
current_occupancy = occupancy;
current_tflops = tflops;
}
// else if(abs(current_score - config_score) < 0.001f &&
// (best_config.stages < candidate_config.stages ||
// kbatch_curr < best_config.split_k ||
// current_tile_m < candidate_config.tile_m))
// {
// best_config.tile_m = candidate_config.tile_m;
// best_config.tile_n = candidate_config.tile_n;
// best_config.stages = candidate_config.stages;
// best_config.split_k = kbatch_curr;
// best_config.op_name = op_name;
// current_tile_m = candidate_config.tile_m;
// config_waves = num_waves_total;
// current_occupancy = occupancy;
// current_tflops = tflops;
// }
} }
} }
else else
...@@ -303,8 +379,14 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -303,8 +379,14 @@ bool profile_gemm_universal_impl(int do_verification,
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " : "
<< " GB/s, " << best_op_name << std::endl; << " occupancy: " << best_occupancy << " " << best_ave_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
std::cout << "tile_m: " << best_config.tile_m << " tile_n: " << best_config.tile_n
<< " split_k: " << best_config.split_k << " stages: " << best_config.stages
<< ", config_score: " << config_score << ", tflops: " << current_tflops
<< ", current_occupancy: " << current_occupancy << " name: " << best_config.op_name
<< ", KBatch " << best_config.split_k << std::endl;
return pass; return pass;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment