Commit 333176c5 authored by Adam Osewski

Draft changes to run gridwise gemm through multiple SplitK tiles

parent be48abdb
@@ -37,10 +37,11 @@ using BDataType = F16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using DsDataType = ck::Tuple<>;
- using EDataType = F32;
+ using EDataType = F16;
 using ALayout = Row;
 using BLayout = Col;
+ // using BLayout = Row;
 using DsLayout = ck::Tuple<>;
 using ELayout = Row;
@@ -56,7 +57,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip
 //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
- < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+ < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, ck::PipelineVersion::v1>;
+ // < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+ // < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
 // clang-format on
 struct ProblemSize final
@@ -76,7 +79,8 @@ struct ExecutionConfig final
 {
 bool do_verification = true;
 int init_method = 1;
- int k_batch = 128;
+ // int k_batch = 128;
+ int k_batch = 1;
 bool time_kernel = false;
 };
@@ -158,9 +162,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
 a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
 b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
 break;
+ case 3:
+ ck::utils::FillConstant<ADataType>{1}(a_tensors[i]);
+ ck::utils::FillConstant<BDataType>{1}(b_tensors[i]);
+ break;
 default:
- a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
- b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+ // a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+ // b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+ ck::utils::FillMonotonicSeq<ADataType>{0, 1}(a_tensors[i]);
+ ck::utils::FillMonotonicSeq<BDataType>{1, 1}(b_tensors[i]);
 }
 }
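
The new init methods make the split-K path easy to reason about: with the constant-1 fill every output element should equal K, and each of the k_batch partial results should equal K / k_batch. The snippet below is an illustrative host-side check of that arithmetic only; it is not part of the commit, and the k_batch value is a made-up example.

// Illustrative host-side arithmetic for the constant-1 init (not part of the commit).
#include <cassert>
#include <vector>

int main()
{
    const int K = 4608, k_batch = 4; // K matches the example's default; k_batch is hypothetical

    std::vector<float> partials(k_batch, 0.f); // one partial per K-split
    for(int kb = 0; kb < k_batch; ++kb)
        for(int k = 0; k < K / k_batch; ++k)
            partials[kb] += 1.f * 1.f; // a == b == 1 with the constant-1 init

    float c = 0.f;
    for(float p : partials)
    {
        assert(p == static_cast<float>(K / k_batch)); // each partial is easy to spot in a dump
        c += p;
    }
    assert(c == static_cast<float>(K)); // final value of every C element
    return 0;
}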
@@ -309,17 +319,20 @@ int main(int argc, char* argv[])
 if(argc < 11)
 {
- std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
+ // std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
+ std::vector<ck::index_t> Ms{64};
 problem_size.group_count = Ms.size();
 for(int i = 0; i < problem_size.group_count; i++)
 {
 problem_size.Ms.push_back(Ms[i]);
- problem_size.Ns.push_back(252);
+ // problem_size.Ns.push_back(252);
+ problem_size.Ns.push_back(256);
 problem_size.Ks.push_back(4608);
 problem_size.stride_As.push_back(problem_size.Ks[i]);
 problem_size.stride_Bs.push_back(problem_size.Ks[i]);
+ // problem_size.stride_Bs.push_back(problem_size.Ns[i]);
 problem_size.stride_Cs.push_back(problem_size.Ns[i]);
 }
...
@@ -131,14 +131,21 @@ __global__ void
 const auto StrideA = gemm_desc_ptr[group_id].StrideA;
 const auto StrideB = gemm_desc_ptr[group_id].StrideB;
- results_buffer.Clear();
+ // results_buffer.Clear();
 b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
 // Iterate over K dimension for this [M,N] tile
 // still in the same GEMM && the same [M,N] tile
 // TODO: change desc so that few K-tiles will be done in single GEMM.
- do
- {
+ // do
+ // {
+ auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
+ // if (blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
+ // {
+ //     printf("bid: %d, k_tiles: %d\n",
+ //            static_cast<index_t>(blockIdx.x),
+ //            k_tiles);
+ // }
 // just accumulate results in registers!
 GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
 p_b_grid,
@@ -152,9 +159,12 @@ __global__ void
 StrideB,
 k_batch,
 b2c_tile_map,
- results_buffer);
+ results_buffer,
+ k_tiles);
+ // Move to the last processed k-tile
+ b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
- } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
+ // } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
 // if (changed group_id || next [M,N] tile)
 // With cshuffle at store partials all workgroups have to store
@@ -164,7 +174,7 @@ __global__ void
 // do CShuffle in flight with loading partials products of other peer workgroups.
 GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
- #if 0
+ #if 1
 // make sure all writes to gmem has finished.
 __builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
 // __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
@@ -212,6 +222,11 @@ __global__ void
 p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
 });
+ // if (threadIdx.x == 0)
+ // {
+ //     p_e_grid[blockIdx.x] = 0;
+ // }
 GridwiseGemm::template RunWrite(p_ds_grid,
 p_e_grid,
 acc_buff,
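
For orientation, here is a host-side model of the control flow this hunk moves towards; it is a sketch under assumptions, not the kernel itself. The scheduler grants a workgroup a run of consecutive K tiles for one [M,N] tile, RunGEMM accumulates all of them in registers, the tile map is advanced past the consumed tiles, and a single partial is stored. In the draft the outer do/while is commented out, so each pass currently consumes one grant; the loop below only illustrates the bookkeeping, and the grant size is invented.

// Host-side sketch of the intended multi-K-tile flow (names mirror the kernel,
// the scheduler granularity is a made-up example).
#include <algorithm>
#include <cstdio>

int main()
{
    const int k_batch        = 8; // K tiles per [M,N] tile
    const int grant_per_call = 3; // tiles handed out per GetNextKTiles call (hypothetical)

    int   k0_idx = 0;   // b2c_tile_map K-tile index
    float acc    = 0.f; // results_buffer kept in registers
    int   stores = 0;

    while(k0_idx < k_batch)
    {
        const int k_tiles = std::min(grant_per_call, k_batch - k0_idx); // GetNextKTiles
        for(int t = 0; t < k_tiles; ++t) // RunGEMM main loop runs k_tiles times
            acc += 1.f;
        k0_idx += k_tiles; // AdvanceTileKIdx(k_tiles - 1) plus the step to the next K tile
    }
    ++stores; // StorePartials once per [M,N] tile instead of once per K tile
    std::printf("accumulated %d K tiles, stored %d partial(s)\n", static_cast<int>(acc), stores);
    return 0;
}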
@@ -497,29 +512,29 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 {
 bool all_have_main_k_block_loop;
 {
- const auto a_grid_desc_kbatch_ak0_m_ak1 =
- GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(gemm_kernel_args_[0].M,
+ const auto a_grid_desc_ak0_m_ak1 =
+ GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
 gemm_kernel_args_[0].K,
 gemm_kernel_args_[0].StrideA,
 K_BATCH);
 all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+ a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
+ K_BATCH);
 }
 for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
 {
 const auto& gemm_arg = gemm_kernel_args_[i];
 auto kbatch = K_BATCH;
- const auto a_grid_desc_kbatch_ak0_m_ak1 =
- GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
+ const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
 gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
 bool not_all_have_main_k_block_loop_same =
- all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+ all_have_main_k_block_loop xor
+ GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
+ a_grid_desc_ak0_m_ak1.GetLength(I2) /
+ K_BATCH);
 if(not_all_have_main_k_block_loop_same)
 {
@@ -616,7 +631,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 void* dev_gemm_workspace,
 const StreamConfig& stream_config = StreamConfig{})
 {
- auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
+ [[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
 CheckArgument(arg, stream_config);
 if(dev_gemm_args == nullptr)
@@ -698,17 +713,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
 {
- const auto a_grid_desc_kbatch_ak0_m_ak1 =
- GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
- arg.gemm_kernel_args_[0].M,
+ const auto a_grid_desc_ak0_m_ak1 =
+ GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
 arg.gemm_kernel_args_[0].K,
 arg.gemm_kernel_args_[0].StrideA,
 arg.K_BATCH);
 all_have_kbatch_gt_one = arg.K_BATCH > 1;
 all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+ a_grid_desc_ak0_m_ak1.GetLength(I0) *
+ a_grid_desc_ak0_m_ak1.GetLength(I2) / arg.K_BATCH);
 }
 for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
@@ -737,14 +751,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 throw std::runtime_error(err.str());
 }
- const auto a_grid_desc_kbatch_ak0_m_ak1 =
- GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
+ const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
 gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
 bool not_all_have_main_k_block_loop_same =
- all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+ all_have_main_k_block_loop xor
+ GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
+ a_grid_desc_ak0_m_ak1.GetLength(I2) /
+ kbatch);
 bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
 if(not_all_have_main_k_block_loop_same)
@@ -853,8 +867,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 }
 auto preprocess = [&]() {
+ // std::cout << "[preprocess] p_flags: " << p_flags
+ //           << ", flag count: " << flag_count
+ //           << ", bytes: " << flag_count * sizeof(uint32_t)
+ //           << ", stream id: " << stream_config.stream_id_
+ //           << std::endl;
 hip_check_error(hipMemsetAsync(
 p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_));
+ // TODO: For debug only!
+ hip_check_error(hipMemsetAsync(
+ dev_gemm_workspace, 2, acc_workspace_size_bytes, stream_config.stream_id_));
 };
 return launch_and_time_kernel_with_preprocess(
@@ -890,11 +912,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
 arg.skipped_group_count_) != arg.group_count_)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "The group count is not equal to sum of skipped groups "
 "and kernel args size!"
 << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
@@ -913,11 +936,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 arg.K_BATCH);
 if(not group_arg_valid)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "[" << __func__ << "] group id: " << i
 << " has invalid GridwiseGemm settings!" << std::endl;
 gemm_arg.Print();
- #endif // DEBUG_LOG
+ }
 }
 supported = supported && group_arg_valid;
 }
@@ -1043,6 +1067,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
 size_t size_bytes =
 Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
 flag_count * sizeof(uint32_t);
+ std::cout << "[GetWorkspaceSize]: "
+           << "occ_grid_size: " << occ_grid_size << ", grid_size: " << grid_size
+           << ", tiles_per_block: " << tiles_per_block << ", flag_count: " << flag_count
+           << ", size_bytes: " << size_bytes << std::endl;
 return size_bytes;
 }
...
@@ -1531,6 +1531,8 @@ struct BlockToCTileMap_LinearKSplit
 return false;
 }
+ __host__ __device__ void AdvanceTileKIdx(index_t k_tiles) { K0_idx_ += k_tiles; }
 ///
 /// @brief Determines whether the current workgroup processed first tile in K dimension
 ///
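
The added helper simply moves the tile map's K-tile cursor forward by however many tiles the last RunGEMM call consumed. A toy stand-in, assuming K0_idx_ counts KPerBlock-wide tiles (this is not the CK class itself):

// Toy stand-in for the tile map's K bookkeeping (illustrative only).
#include <cstdio>

struct TileKCursor
{
    int K0_idx_ = 0;
    void AdvanceTileKIdx(int k_tiles) { K0_idx_ += k_tiles; }
};

int main()
{
    TileKCursor cursor;                  // starts at the first granted K tile
    const int k_tiles = 3;               // tiles consumed by one RunGEMM call
    cursor.AdvanceTileKIdx(k_tiles - 1); // as in the kernel: point at the last processed tile
    std::printf("K0_idx_ = %d\n", cursor.K0_idx_);
    return 0;
}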
...
@@ -57,7 +57,7 @@ template <typename ADataType,
 index_t NPerXdl,
 index_t MXdlPerWave,
 index_t NXdlPerWave,
- typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
+ typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 typename ABlockTransferThreadClusterArrangeOrder,
 typename ABlockTransferSrcAccessOrder,
 index_t ABlockTransferSrcVectorDim,
@@ -65,7 +65,7 @@ template <typename ADataType,
 index_t ABlockTransferDstScalarPerVector_AK1,
 bool AThreadTransferSrcResetCoordinateAfterRun,
 index_t ABlockLdsExtraM,
- typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
+ typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 typename BBlockTransferThreadClusterArrangeOrder,
 typename BBlockTransferSrcAccessOrder,
 index_t BBlockTransferSrcVectorDim,
@@ -81,13 +81,6 @@ template <typename ADataType,
 PipelineVersion PipelineVer>
 class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
- template <index_t... Ids>
- __device__ static bool is_thread_local_1d_id_idx()
- {
- const auto tid = get_thread_local_1d_id();
- return ((tid == Ids) || ...);
- }
 static constexpr index_t NumDTensor = DsDataType::Size();
 using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
@@ -132,28 +125,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 return math::integer_least_multiple(K, KPerBlock * K_Batch);
 }
- __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
- {
- // A matrix in LDS memory, dst of blockwise copy
- return make_naive_tensor_descriptor(
- make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1),
- make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
- Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
- AK1,
- I1));
- }
- __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
- {
- // B matrix in LDS memory, dst of blockwise copy
- return make_naive_tensor_descriptor(
- make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1),
- make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
- Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
- BK1,
- I1));
- }
 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
 {
 // A matrix in LDS memory, dst of blockwise copy
@@ -171,7 +142,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 }
 __host__ __device__ static auto
- MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
+ MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
 {
 const auto a_grid_desc_m_k = [&]() {
 if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -184,7 +155,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 }
 }();
- const auto MPad = CalculateMPadded(M);
 const auto KPad = CalculateKPadded(K, KBatch);
 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
@@ -193,33 +163,34 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
- const auto AK0 = KPad / (KBatch * AK1);
+ const auto AK0 = KPad / AK1;
 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 GemmSpec == GemmSpecialization::MNPadding ||
 GemmSpec == GemmSpecialization::MKPadding ||
 GemmSpec == GemmSpecialization::MNKPadding)
 {
+ const auto MPad = CalculateMPadded(M);
 return transform_tensor_descriptor(
 a_grid_desc_m_kpad,
- make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
+ make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
 make_right_pad_transform(M, MPad - M)),
 make_tuple(Sequence<1>{}, Sequence<0>{}),
- make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+ make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
 else
 {
 return transform_tensor_descriptor(
 a_grid_desc_m_kpad,
- make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
+ make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
 make_pass_through_transform(M)),
 make_tuple(Sequence<1>{}, Sequence<0>{}),
- make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+ make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
 }
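
With the KBatch dimension folded away, the A descriptor is now a plain [AK0, M, AK1] view where AK0 spans the whole padded K, and the split bookkeeping becomes simple arithmetic. The sketch below only illustrates that arithmetic; KPerBlock and AK1 are example numbers, not the instance's values, and the sketch is not part of the commit.

// Rough arithmetic behind the 3-D [AK0, M, AK1] descriptor (illustrative values).
#include <cstdio>

int main()
{
    const int K = 4608, KBatch = 4, KPerBlock = 32, AK1 = 8;

    // CalculateKPadded: least multiple of KPerBlock * KBatch
    const int unit = KPerBlock * KBatch;
    const int KPad = ((K + unit - 1) / unit) * unit;

    const int AK0             = KPad / AK1;              // old code used KPad / (KBatch * AK1)
    const int total_k_tiles   = KPad / KPerBlock;        // KPerBlock-wide K tiles over the full K
    const int tiles_per_split = total_k_tiles / KBatch;  // main-loop count per K-split

    std::printf("KPad = %d, AK0 = %d, K tiles = %d, tiles per split = %d\n",
                KPad, AK0, total_k_tiles, tiles_per_split);
    return 0;
}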
 __host__ __device__ static auto
- MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
+ MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
 {
 const auto b_grid_desc_k_n = [&]() {
 if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -241,7 +212,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
- const auto BK0 = KPad / (KBatch * BK1);
+ const auto BK0 = KPad / BK1;
 if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 GemmSpec == GemmSpecialization::MNPadding ||
@@ -251,32 +222,30 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
 return transform_tensor_descriptor(
 b_grid_desc_kpad_n,
- make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
+ make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
 make_right_pad_transform(N, NPad - N)),
 make_tuple(Sequence<0>{}, Sequence<1>{}),
- make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+ make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
 else
 {
 return transform_tensor_descriptor(
 b_grid_desc_kpad_n,
- make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
+ make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
 make_pass_through_transform(N)),
 make_tuple(Sequence<0>{}, Sequence<1>{}),
- make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+ make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
 }
 private:
- using AGridDesc_KBatch_AK0_M_AK1 =
- remove_cvref_t<decltype(MakeAGridDescriptor_KBatch_AK0_M_AK1(1, 1, 1, 1))>;
- using BGridDesc_KBatch_BK0_N_BK1 =
- remove_cvref_t<decltype(MakeBGridDescriptor_KBatch_BK0_N_BK1(1, 1, 1, 1))>;
+ using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1))>;
+ using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1))>;
- using ABlockDesc_KBatch_AK0PerB_MPerB_AK1 =
- remove_cvref_t<decltype(GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1())>;
+ using ABlockDesc_AK0PerB_MPerB_AK1 =
+ remove_cvref_t<decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())>;
- using BBlockDesc_KBatch_BK0PerB_NPerB_BK1 =
- remove_cvref_t<decltype(GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1())>;
+ using BBlockDesc_BK0PerB_NPerB_BK1 =
+ remove_cvref_t<decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())>;
 public:
 __host__ __device__ static constexpr auto
@@ -423,10 +392,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 const index_t StrideE,
 const index_t KBatch)
 {
- const auto a_grid_desc_kbatch_ak0_m_ak1 =
- MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
- const auto b_grid_desc_kbatch_bk0_n_bk1 =
- MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
+ const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
+ const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);
 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
@@ -436,12 +403,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(!(M % MPerBlock == 0))
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << " "
 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
 << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -453,12 +420,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(!(N % NPerBlock == 0))
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << N << " "
 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
 << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -471,12 +439,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 auto K_t = KBatch * KPerBlock;
 if(!(K % K_t == 0))
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
- std::cout << "Arg K value is not a multiple of ! KBatch * KPerBlock: " << K << " "
- << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+ std::cout << "Arg K value is not a multiple of ! KBatch * KPerBlock: " << K
+ << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
 << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -485,13 +453,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(K % ABlockTransferSrcScalarPerVector != 0)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg K (" << K
 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
 << __LINE__ << ", in function: " << __func__ << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -499,13 +467,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(M % ABlockTransferSrcScalarPerVector != 0)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg M (" << M
 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
 << __LINE__ << ", in function: " << __func__ << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -514,13 +482,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(N % BBlockTransferSrcScalarPerVector != 0)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg N (" << N
 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
 << __LINE__ << ", in function: " << __func__ << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -528,13 +496,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(K % BBlockTransferSrcScalarPerVector != 0)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg K (" << K
 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
 << __LINE__ << ", in function: " << __func__ << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -543,14 +511,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(N % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg N (" << N
 << ") value is not a multiple of "
 "CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
- << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
- << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+ << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
+ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+ << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
@@ -558,31 +527,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
 if(M % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "Arg M (" << M
 << ") value is not a multiple of "
 "CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
- << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
- << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+ << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
+ << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+ << std::endl;
- #endif // DEBUG_LOG
+ }
 return false;
 }
 }
 // check gridwise gemm pipeline
- const auto num_k_loop = (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
- KPerBlock;
+ const auto num_k_loop =
+ (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
+ (KPerBlock * KBatch);
 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 {
- #if DEBUG_LOG
+ if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+ {
 std::cout << "The number of k loops (" << num_k_loop
 << ") value is not supported by GridwiseGemm Pipeline."
- << " K0Padded: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) << __FILE__
- << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
- #endif // DEBUG_LOG
+ << " K0Padded: " << a_grid_desc_ak0_m_ak1.GetLength(I1) << __FILE__ << ":"
+ << __LINE__ << ", in function: " << __func__ << std::endl;
+ }
 return false;
 }
@@ -590,8 +561,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 // check tensor size: cannot be larger than 2GB each
 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
- if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
- b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
+ if(!(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
+ b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
 {
 return false;
@@ -681,16 +652,17 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 void* __restrict__ p_shared,
 const AElementwiseOperation& a_element_op,
 const BElementwiseOperation& b_element_op,
- const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
- const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
+ const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+ const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
 const Block2ETileMap& block_2_etile_map,
- CThreadBuf& c_thread_buf)
+ CThreadBuf& c_thread_buf,
+ const index_t k_tiles)
 {
 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
+ p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());
+ p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 // divide block work by [M, N, K]
 const auto block_work_idx = block_2_etile_map.GetBottomIndex();
@@ -701,33 +673,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 const index_t n_block_data_idx_on_grid =
 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
- // lds max alignment
- constexpr auto max_lds_align = math::lcm(AK1, BK1);
 // A matrix in LDS memory, dst of blockwise copy
- constexpr auto a_block_desc_kbatch_ak0_m_ak1 =
- GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
+ constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
 // B matrix in LDS memory, dst of blockwise copy
- constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
- GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
+ constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
 using ABlockwiseCopy =
 ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 AElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
- Sequence<1, AK0PerBlock, MPerBlock, AK1>,
- ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
+ Sequence<AK0PerBlock, MPerBlock, AK1>,
+ ABlockTransferThreadClusterLengths_AK0_M_AK1,
 ABlockTransferThreadClusterArrangeOrder,
 ADataType,
 ComputeType,
- AGridDesc_KBatch_AK0_M_AK1,
- ABlockDesc_KBatch_AK0PerB_MPerB_AK1,
+ AGridDesc_AK0_M_AK1,
+ ABlockDesc_AK0PerB_MPerB_AK1,
 ABlockTransferSrcAccessOrder,
- Sequence<2, 0, 1, 3>,
+ Sequence<1, 0, 2>,
 ABlockTransferSrcVectorDim,
- 3,
+ 2,
 ABlockTransferSrcScalarPerVector,
 ABlockTransferDstScalarPerVector_AK1,
 1,
@@ -741,17 +707,17 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 BElementwiseOperation,
 ck::tensor_operation::element_wise::PassThrough,
 InMemoryDataOperationEnum::Set,
- Sequence<1, BK0PerBlock, NPerBlock, BK1>,
- BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
+ Sequence<BK0PerBlock, NPerBlock, BK1>,
+ BBlockTransferThreadClusterLengths_BK0_N_BK1,
 BBlockTransferThreadClusterArrangeOrder,
 BDataType,
 ComputeType,
- BGridDesc_KBatch_BK0_N_BK1,
- BBlockDesc_KBatch_BK0PerB_NPerB_BK1,
+ BGridDesc_BK0_N_BK1,
+ BBlockDesc_BK0PerB_NPerB_BK1,
 BBlockTransferSrcAccessOrder,
- Sequence<2, 0, 1, 3>,
+ Sequence<1, 0, 2>,
 BBlockTransferSrcVectorDim,
- 3,
+ 2,
 BBlockTransferSrcScalarPerVector,
 BBlockTransferDstScalarPerVector_BK1,
 1,
@@ -760,30 +726,35 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 true,
 NumGemmKPrefetchStage>;
+ const index_t ak0_start_idx = kbatch_id * AK0PerBlock;
+ const index_t bk0_start_idx = kbatch_id * BK0PerBlock;
+ if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
+ {
+ printf("[RunGEMM] bid: %d, ak0_start_idx: %d, bk0_start_idx: %d\n",
+ static_cast<index_t>(blockIdx.x),
+ ak0_start_idx,
+ bk0_start_idx);
+ }
 // A matrix blockwise copy
 auto a_blockwise_copy =
- ABlockwiseCopy(a_grid_desc_kbatch_ak0_m_ak1,
- make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0),
+ ABlockwiseCopy(a_grid_desc_ak0_m_ak1,
+ make_multi_index(ak0_start_idx, m_block_data_idx_on_grid, 0),
 a_element_op,
- a_block_desc_kbatch_ak0_m_ak1,
- make_multi_index(0, 0, 0, 0),
+ a_block_desc_ak0_m_ak1,
+ make_multi_index(0, 0, 0),
 ck::tensor_operation::element_wise::PassThrough{});
 // B matrix blockwise copy
 auto b_blockwise_copy =
- BBlockwiseCopy(b_grid_desc_kbatch_bk0_n_bk1,
- make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0),
+ BBlockwiseCopy(b_grid_desc_bk0_n_bk1,
+ make_multi_index(bk0_start_idx, n_block_data_idx_on_grid, 0),
 b_element_op,
- b_block_desc_kbatch_bk0_n_bk1,
- make_multi_index(0, 0, 0, 0),
+ b_block_desc_bk0_n_bk1,
+ make_multi_index(0, 0, 0),
 ck::tensor_operation::element_wise::PassThrough{});
- // A matrix in LDS memory, dst of blockwise copy
- constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
- // B matrix in LDS memory, dst of blockwise copy
- constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
 // GEMM definition
 // c_mtx += transpose(a_mtx) * b_mtx
 // a_mtx[K0PerBlock, MPerBlock] is in LDS
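
Since a K-split no longer has its own descriptor dimension, RunGEMM now starts its blockwise copies at an AK0/BK0 offset instead of at a KBatch coordinate. A small sketch of that mapping follows; it assumes AK0PerBlock = KPerBlock / AK1 and that kbatch_id indexes KPerBlock-wide K tiles as in the draft kernel, with example sizes that are not the instance's.

// Sketch of the copy start-offset mapping used above (illustrative sizes only).
#include <cstdio>

int main()
{
    const int KPerBlock = 32, AK1 = 8, BK1 = 8;
    const int AK0PerBlock = KPerBlock / AK1; // 4
    const int BK0PerBlock = KPerBlock / BK1; // 4

    for(int kbatch_id = 0; kbatch_id < 4; ++kbatch_id) // first few K tiles
        std::printf("K tile %d -> A copy starts at [%d, m, 0], B copy at [%d, n, 0]\n",
                    kbatch_id, kbatch_id * AK0PerBlock, kbatch_id * BK0PerBlock);
    return 0;
}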
@@ -792,6 +763,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 // register
 // auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
+ // lds max alignment
+ constexpr auto max_lds_align = math::lcm(AK1, BK1);
 // LDS allocation for A and B: be careful of alignment
 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -803,19 +777,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
- constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
- constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
+ constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
+ constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
 // gridwise GEMM pipeline
 const auto gridwise_gemm_pipeline =
 GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
- const index_t num_k_block_main_loop =
- __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
- a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
- KPerBlock);
+ // TODO: what if AK1 != BK1 ???
+ const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(k_tiles);
+ // __builtin_amdgcn_readfirstlane((a_grid_desc_ak0_m_ak1.GetLength(I1) *
+ // a_grid_desc_ak0_m_ak1.GetLength(I3)) /
+ // KPerBlock);
- bool clear_c_thread_buf = false;
+ if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
+ {
+ printf("[RunGEMM] bid: %d, num_k_block_main_loop %d\n",
+ static_cast<index_t>(blockIdx.x),
+ num_k_block_main_loop);
+ }
+ bool clear_c_thread_buf = true;
 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
 BlockSize,
@@ -831,14 +813,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 KPack,
 LoopSched>();
- gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
- a_block_desc_kbatch_ak0_m_ak1,
+ gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+ a_block_desc_ak0_m_ak1,
 a_blockwise_copy,
 a_grid_buf,
 a_block_buf,
 a_block_slice_copy_step,
- b_grid_desc_kbatch_bk0_n_bk1,
- b_block_desc_kbatch_bk0_n_bk1,
+ b_grid_desc_bk0_n_bk1,
+ b_block_desc_bk0_n_bk1,
 b_blockwise_copy,
 b_grid_buf,
 b_block_buf,
@@ -862,27 +844,26 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 const index_t StrideB,
 const index_t KBatch,
 const Block2ETileMap& block_2_etile_map,
- CThreadBuf& c_thread_buf)
+ CThreadBuf& c_thread_buf,
+ const index_t k_tiles)
 {
 const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
 const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
 // tensor descriptors for block/thread-wise copy
- const auto a_grid_desc_kbatch_ak0_m_ak1 =
- MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
- const auto b_grid_desc_kbatch_bk0_n_bk1 =
- MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
+ const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
+ const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);
 RunGEMM<HasMainKBlockLoop>(p_a_grid,
 p_b_grid,
 p_shared,
 a_element_op,
 b_element_op,
- a_grid_desc_kbatch_ak0_m_ak1,
- b_grid_desc_kbatch_bk0_n_bk1,
+ a_grid_desc_ak0_m_ak1,
+ b_grid_desc_bk0_n_bk1,
 block_2_etile_map,
- c_thread_buf);
+ c_thread_buf,
+ k_tiles);
 }
 template <typename CThreadBuf>
@@ -1247,6 +1228,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0m1_n0n1n2, partial_acc_load_step);
 }
+ // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
+ // {
+ //     printf("[bid: %d, tid: %d], {Accumulate Partials} AccBuf v[0, 0, 0, 0, 0-3]: [%f, %f,"
+ //            "%f, %f]\n",
+ //            static_cast<index_t>(blockIdx.x),
+ //            static_cast<index_t>(threadIdx.x),
+ //            static_cast<float>(acc_buff[Number<0>{}]),
+ //            static_cast<float>(acc_buff[Number<1>{}]),
+ //            static_cast<float>(acc_buff[Number<2>{}]),
+ //            static_cast<float>(acc_buff[Number<3>{}]));
+ //     printf("[bid: %d, tid: %d], {Accumulate Partials} AccBuf v[0, 0, 0, 1, 0-3]: [%f, %f,"
+ //            "%f, %f]\n",
+ //            static_cast<index_t>(blockIdx.x),
+ //            static_cast<index_t>(threadIdx.x),
+ //            static_cast<float>(acc_buff[Number<8>{}]),
+ //            static_cast<float>(acc_buff[Number<9>{}]),
+ //            static_cast<float>(acc_buff[Number<10>{}]),
+ //            static_cast<float>(acc_buff[Number<11>{}]));
+ // }
 }
 template <typename Block2ETileMap, typename AccumulationBuffer>
@@ -1411,6 +1413,21 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 unpack2(cde_element_op, tie(aux_vgpr_buf(I)), src_data_refs);
 });
+ // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
+ // {
+ //     printf("[bid: %d, tid: %d, m_iter: %d, n_iter: %d], {RunWrite} AuxBuf v[0-3]: "
+ //            " [%f, %f, %f, %f]\n",
+ //            static_cast<index_t>(blockIdx.x),
+ //            static_cast<index_t>(threadIdx.x),
+ //            m_idx.value,
+ //            n_idx.value,
+ //            static_cast<float>(aux_vgpr_buf[Number<0>{}]),
+ //            static_cast<float>(aux_vgpr_buf[Number<1>{}]),
+ //            static_cast<float>(aux_vgpr_buf[Number<2>{}]),
+ //            static_cast<float>(aux_vgpr_buf[Number<3>{}]));
+ // }
 e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
 make_tuple(I0, I0, I0, I0, I0),
 aux_vgpr_buf,
...