Commit 333176c5 authored by Adam Osewski's avatar Adam Osewski
Browse files

Draft changes to run gridwise gemm through multiple SplitK tiles

parent be48abdb
...@@ -37,10 +37,11 @@ using BDataType = F16; ...@@ -37,10 +37,11 @@ using BDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>; using DsDataType = ck::Tuple<>;
using EDataType = F32; using EDataType = F16;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Col;
// using BLayout = Row;
using DsLayout = ck::Tuple<>; using DsLayout = ck::Tuple<>;
using ELayout = Row; using ELayout = Row;
...@@ -56,7 +57,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip ...@@ -56,7 +57,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, ck::PipelineVersion::v1>;
// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on // clang-format on
struct ProblemSize final struct ProblemSize final
...@@ -76,8 +79,9 @@ struct ExecutionConfig final ...@@ -76,8 +79,9 @@ struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
int k_batch = 128; // int k_batch = 128;
bool time_kernel = false; int k_batch = 1;
bool time_kernel = false;
}; };
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...@@ -158,9 +162,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -158,9 +162,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
case 3:
ck::utils::FillConstant<ADataType>{1}(a_tensors[i]);
ck::utils::FillConstant<BDataType>{1}(b_tensors[i]);
break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); // a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); // b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
ck::utils::FillMonotonicSeq<ADataType>{0, 1}(a_tensors[i]);
ck::utils::FillMonotonicSeq<BDataType>{1, 1}(b_tensors[i]);
} }
} }
...@@ -309,17 +319,20 @@ int main(int argc, char* argv[]) ...@@ -309,17 +319,20 @@ int main(int argc, char* argv[])
if(argc < 11) if(argc < 11)
{ {
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77}; // std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
std::vector<ck::index_t> Ms{64};
problem_size.group_count = Ms.size(); problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ms.push_back(Ms[i]); problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(252); // problem_size.Ns.push_back(252);
problem_size.Ns.push_back(256);
problem_size.Ks.push_back(4608); problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
// problem_size.stride_Bs.push_back(problem_size.Ns[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]);
} }
......
...@@ -131,30 +131,40 @@ __global__ void ...@@ -131,30 +131,40 @@ __global__ void
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB; const auto StrideB = gemm_desc_ptr[group_id].StrideB;
results_buffer.Clear(); // results_buffer.Clear();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset); b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
// Iterate over K dimension for this [M,N] tile // Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile // still in the same GEMM && the same [M,N] tile
// TODO: change desc so that few K-tiles will be done in single GEMM. // TODO: change desc so that few K-tiles will be done in single GEMM.
do // do
{ // {
// just accumulate results in registers! auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid, // if (blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
p_b_grid, // {
static_cast<void*>(p_shared), // printf("bid: %d, k_tiles: %d\n",
a_element_op, // static_cast<index_t>(blockIdx.x),
b_element_op, // k_tiles);
M, // }
N, // just accumulate results in registers!
K, GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
StrideA, p_b_grid,
StrideB, static_cast<void*>(p_shared),
k_batch, a_element_op,
b2c_tile_map, b_element_op,
results_buffer); M,
N,
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx()); K,
StrideA,
StrideB,
k_batch,
b2c_tile_map,
results_buffer,
k_tiles);
// Move to the last processed k-tile
b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
// } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
// if (changed group_id || next [M,N] tile) // if (changed group_id || next [M,N] tile)
// With cshuffle at store partials all workgroups have to store // With cshuffle at store partials all workgroups have to store
...@@ -164,7 +174,7 @@ __global__ void ...@@ -164,7 +174,7 @@ __global__ void
// do CShuffle in flight with loading partials products of other peer workgroups. // do CShuffle in flight with loading partials products of other peer workgroups.
GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer); GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
#if 0 #if 1
// make sure all writes to gmem has finished. // make sure all writes to gmem has finished.
__builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0) __builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
// __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0) // __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
...@@ -212,6 +222,11 @@ __global__ void ...@@ -212,6 +222,11 @@ __global__ void
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]); p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
}); });
// if (threadIdx.x == 0)
// {
// p_e_grid[blockIdx.x] = 0;
// }
GridwiseGemm::template RunWrite(p_ds_grid, GridwiseGemm::template RunWrite(p_ds_grid,
p_e_grid, p_e_grid,
acc_buff, acc_buff,
...@@ -497,29 +512,29 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -497,29 +512,29 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
{ {
bool all_have_main_k_block_loop; bool all_have_main_k_block_loop;
{ {
const auto a_grid_desc_kbatch_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(gemm_kernel_args_[0].M, GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
gemm_kernel_args_[0].K, gemm_kernel_args_[0].K,
gemm_kernel_args_[0].StrideA, gemm_kernel_args_[0].StrideA,
K_BATCH); K_BATCH);
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop( all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); K_BATCH);
} }
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{ {
const auto& gemm_arg = gemm_kernel_args_[i]; const auto& gemm_arg = gemm_kernel_args_[i];
auto kbatch = K_BATCH; auto kbatch = K_BATCH;
const auto a_grid_desc_kbatch_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1( gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same = bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop( all_have_main_k_block_loop xor
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); a_grid_desc_ak0_m_ak1.GetLength(I2) /
K_BATCH);
if(not_all_have_main_k_block_loop_same) if(not_all_have_main_k_block_loop_same)
{ {
...@@ -616,7 +631,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -616,7 +631,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void* dev_gemm_workspace, void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{}) const StreamConfig& stream_config = StreamConfig{})
{ {
auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = [[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
CheckArgument(arg, stream_config); CheckArgument(arg, stream_config);
if(dev_gemm_args == nullptr) if(dev_gemm_args == nullptr)
...@@ -698,17 +713,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -698,17 +713,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
bool all_have_kbatch_gt_one, all_have_main_k_block_loop; bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
{ {
const auto a_grid_desc_kbatch_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1( GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
arg.gemm_kernel_args_[0].M, arg.gemm_kernel_args_[0].K,
arg.gemm_kernel_args_[0].K, arg.gemm_kernel_args_[0].StrideA,
arg.gemm_kernel_args_[0].StrideA, arg.K_BATCH);
arg.K_BATCH);
all_have_kbatch_gt_one = arg.K_BATCH > 1; all_have_kbatch_gt_one = arg.K_BATCH > 1;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop( all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); a_grid_desc_ak0_m_ak1.GetLength(I2 / kbatch);
} }
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
...@@ -737,14 +751,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -737,14 +751,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
const auto a_grid_desc_kbatch_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1( gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same = bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop( all_have_main_k_block_loop xor
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); a_grid_desc_ak0_m_ak1.GetLength(I2) /
kbatch);
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
if(not_all_have_main_k_block_loop_same) if(not_all_have_main_k_block_loop_same)
...@@ -853,8 +867,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -853,8 +867,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
} }
auto preprocess = [&]() { auto preprocess = [&]() {
// std::cout << "[preprocess] p_flags: " << p_flags
// << ", flag count: " << flag_count
// << ", bytes: " << flag_count * sizeof(uint32_t)
// << ", stream id: " << stream_config.stream_id_
// << std::endl;
hip_check_error(hipMemsetAsync( hip_check_error(hipMemsetAsync(
p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_)); p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_));
// TODO: For debug only!
hip_check_error(hipMemsetAsync(
dev_gemm_workspace, 2, acc_workspace_size_bytes, stream_config.stream_id_));
}; };
return launch_and_time_kernel_with_preprocess( return launch_and_time_kernel_with_preprocess(
...@@ -890,11 +912,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -890,11 +912,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_) arg.skipped_group_count_) != arg.group_count_)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "The group count is not equal to sum of skipped groups " {
"and kernel args size!" std::cout << "The group count is not equal to sum of skipped groups "
<< std::endl; "and kernel args size!"
#endif // DEBUG_LOG << std::endl;
}
return false; return false;
} }
...@@ -913,11 +936,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -913,11 +936,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
arg.K_BATCH); arg.K_BATCH);
if(not group_arg_valid) if(not group_arg_valid)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "[" << __func__ << "] group id: " << i {
<< " has invalid GridwiseGemm settings!" << std::endl; std::cout << "[" << __func__ << "] group id: " << i
gemm_arg.Print(); << " has invalid GridwiseGemm settings!" << std::endl;
#endif // DEBUG_LOG gemm_arg.Print();
}
} }
supported = supported && group_arg_valid; supported = supported && group_arg_valid;
} }
...@@ -1043,6 +1067,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -1043,6 +1067,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
size_t size_bytes = size_t size_bytes =
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) + Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
flag_count * sizeof(uint32_t); flag_count * sizeof(uint32_t);
std::cout << "[GetWorkspaceSize]: "
<< "occ_grid_size: " << occ_grid_size << ", grid_size: " << grid_size
<< ", tiles_per_block: " << tiles_per_block << ", flag_count: " << flag_count
<< ", size_bytes: " << size_bytes << std::endl;
return size_bytes; return size_bytes;
} }
......
...@@ -1531,6 +1531,8 @@ struct BlockToCTileMap_LinearKSplit ...@@ -1531,6 +1531,8 @@ struct BlockToCTileMap_LinearKSplit
return false; return false;
} }
__host__ __device__ void AdvanceTileKIdx(index_t k_tiles) { K0_idx_ += k_tiles; }
/// ///
/// @brief Determines whether the current workgroup processed first tile in K dimension /// @brief Determines whether the current workgroup processed first tile in K dimension
/// ///
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment