Commit 333176c5 authored by Adam Osewski

Draft changes to run gridwise gemm through multiple SplitK tiles

parent be48abdb
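The core of the draft is visible in the kernel diff below: instead of looping over one K-tile per GridwiseGemm::RunGEMM call, each persistent workgroup now asks the work scheduler how many consecutive K-tiles of the current [M,N] tile it owns (GetNextKTiles), runs them in a single RunGEMM invocation, and advances the tile map by that count (AdvanceTileKIdx). Below is a minimal host-side model of that scheduling, with illustrative names only (mn_tiles, k_splits, num_blocks and the contiguous tiles_per_block split are assumptions, not CK API), assuming the linear tile id enumerates K-splits fastest:

#include <algorithm>
#include <iostream>

int main()
{
    constexpr int mn_tiles   = 4; // [M,N] output tiles
    constexpr int k_splits   = 8; // K-split tiles per output tile (k_batch)
    constexpr int num_blocks = 3; // persistent workgroups
    constexpr int total           = mn_tiles * k_splits;
    constexpr int tiles_per_block = (total + num_blocks - 1) / num_blocks;

    for(int block = 0; block < num_blocks; ++block)
    {
        int tile      = block * tiles_per_block;
        const int end = std::min(tile + tiles_per_block, total);
        while(tile < end)
        {
            const int mn_idx = tile / k_splits; // which [M,N] tile this linear id maps to
            const int k_idx  = tile % k_splits; // first K-split handled in this call
            // "GetNextKTiles": consecutive K-splits that stay within the same [M,N] tile
            const int k_tiles = std::min(end - tile, k_splits - k_idx);
            std::cout << "block " << block << ": RunGEMM on [M,N] tile " << mn_idx
                      << ", K-splits [" << k_idx << ", " << k_idx + k_tiles << ")\n";
            tile += k_tiles; // "AdvanceTileKIdx(k_tiles - 1)", then move to the next tile
        }
    }
    return 0;
}

As the in-kernel comments suggest, fusing the run this way lets partial results stay in registers across the whole run, so partials are stored to the workspace only once per [M,N] tile per workgroup.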
@@ -37,10 +37,11 @@ using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using ALayout = Row;
using BLayout = Col;
// using BLayout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
@@ -56,7 +57,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, ck::PipelineVersion::v1>;
// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
struct ProblemSize final
@@ -76,8 +79,9 @@ struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
int k_batch = 128;
bool time_kernel = false;
// int k_batch = 128;
int k_batch = 1;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
@@ -158,9 +162,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
case 3:
ck::utils::FillConstant<ADataType>{1}(a_tensors[i]);
ck::utils::FillConstant<BDataType>{1}(b_tensors[i]);
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
// a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
// b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
ck::utils::FillMonotonicSeq<ADataType>{0, 1}(a_tensors[i]);
ck::utils::FillMonotonicSeq<BDataType>{1, 1}(b_tensors[i]);
}
}
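Init method 3 above (FillConstant{1} for both A and B) gives the simplest possible split-K check: with all-ones inputs every output element should equal K, so a dropped or double-counted K chunk is immediately visible. A stand-alone scalar check of that expectation (plain C++, independent of the library; K = 4608 is the value hard-coded in main() below):

#include <iostream>

int main()
{
    constexpr int K = 4608; // K used by the hard-coded problem in main()
    float e = 0.0f;
    for(int k = 0; k < K; ++k)
        e += 1.0f * 1.0f;   // a[m][k] * b[k][n] with all-ones inputs
    std::cout << "expected E[m][n] = " << e << '\n'; // 4608 for every element
    return 0;
}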
@@ -309,17 +319,20 @@ int main(int argc, char* argv[])
if(argc < 11)
{
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
// std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
std::vector<ck::index_t> Ms{64};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(252);
// problem_size.Ns.push_back(252);
problem_size.Ns.push_back(256);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
// problem_size.stride_Bs.push_back(problem_size.Ns[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
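For orientation, the hard-coded shape above (M = 64, N = 256, K = 4608, k_batch = 1) gives a very small grid. The per-block tile sizes below are read off the selected instance and its commented column header (BlockSize 128, MPerBlock 64, NPerBlock 64, KPerBlock 32); treat them as assumptions rather than verified values:

#include <iostream>

int main()
{
    constexpr int M = 64, N = 256, K = 4608;
    constexpr int MPerBlock = 64, NPerBlock = 64, KPerBlock = 32; // assumed from the instance
    std::cout << (M / MPerBlock) * (N / NPerBlock) << " [M,N] tiles, " // 1 * 4 = 4
              << K / KPerBlock << " K-tiles per [M,N] tile\n";         // 144
    return 0;
}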
@@ -131,30 +131,40 @@ __global__ void
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
results_buffer.Clear();
// results_buffer.Clear();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
// Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile
// TODO: change desc so that a few K-tiles can be processed in a single GEMM call.
do
{
// just accumulate results in registers!
GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
M,
N,
K,
StrideA,
StrideB,
k_batch,
b2c_tile_map,
results_buffer);
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
// do
// {
auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
// if (blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
// {
// printf("bid: %d, k_tiles: %d\n",
// static_cast<index_t>(blockIdx.x),
// k_tiles);
// }
// just accumulate results in registers!
GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
M,
N,
K,
StrideA,
StrideB,
k_batch,
b2c_tile_map,
results_buffer,
k_tiles);
// Move to the last processed k-tile
b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
// } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
// if (changed group_id || next [M,N] tile)
// With cshuffle at store partials all workgroups have to store
@@ -164,7 +174,7 @@ __global__ void
// do CShuffle in flight with loading partial products of other peer workgroups.
GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
#if 0
#if 1
// make sure all writes to gmem have finished.
__builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
// __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
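The comments around StorePartials describe the intended cross-workgroup handshake: every peer workgroup that produced a K-split of the same [M,N] tile writes its partial accumulators to the workspace, waits for its global-memory writes (the s_waitcnt above), and presumably raises one of the uint32_t flags that Run() clears in its preprocess step, so that a designated writer can reduce the partials and run the CShuffle/RunWrite epilogue. Here is a minimal host-side model of that handshake, using std::thread and std::atomic as stand-ins for workgroups and the flag buffer; none of the names are CK API, and the "last peer reduces" policy is an assumption:

#include <atomic>
#include <iostream>
#include <numeric>
#include <thread>
#include <vector>

int main()
{
    constexpr int peers = 4;                   // workgroups sharing one [M,N] tile
    std::vector<float> workspace(peers, 0.0f); // per-peer partial results ("StorePartials")
    std::atomic<unsigned> flag{0};             // models one uint32_t flag, cleared in preprocess
    float e = 0.0f;                            // final output element

    std::vector<std::thread> wgs;
    for(int wg = 0; wg < peers; ++wg)
        wgs.emplace_back([&, wg] {
            workspace[wg] = 1.0f * (wg + 1); // this peer's partial product
            // the last peer to arrive sees the full count and performs the reduction ("RunWrite")
            if(flag.fetch_add(1) + 1 == peers)
                e = std::accumulate(workspace.begin(), workspace.end(), 0.0f);
        });
    for(auto& t : wgs)
        t.join();

    std::cout << "E = " << e << '\n'; // expect 1 + 2 + 3 + 4 = 10
    return 0;
}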
@@ -212,6 +222,11 @@ __global__ void
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
// if (threadIdx.x == 0)
// {
// p_e_grid[blockIdx.x] = 0;
// }
GridwiseGemm::template RunWrite(p_ds_grid,
p_e_grid,
acc_buff,
@@ -497,29 +512,29 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
{
bool all_have_main_k_block_loop;
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(gemm_kernel_args_[0].M,
gemm_kernel_args_[0].K,
gemm_kernel_args_[0].StrideA,
K_BATCH);
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
gemm_kernel_args_[0].K,
gemm_kernel_args_[0].StrideA,
K_BATCH);
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
K_BATCH);
}
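The has-main-K-block-loop check above now works on the flattened [AK0, M, AK1] descriptor, so GetLength(I0) * GetLength(I2) recovers the (possibly padded) K, and dividing by K_BATCH yields the per-split K length that the old [KBatch, AK0, M, AK1] descriptor exposed as GetLength(I1) * GetLength(I3). A quick numeric check with the example's K = 4608, assuming AK1 = 8 as in the selected instance:

#include <iostream>

int main()
{
    constexpr int K = 4608, AK1 = 8, K_BATCH = 1; // values taken from the example; AK1 assumed
    constexpr int AK0 = K / AK1;                  // 576 in the flattened descriptor
    std::cout << "per-split K = " << AK0 * AK1 / K_BATCH << '\n'; // 4608 with k_batch = 1
    return 0;
}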
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = gemm_kernel_args_[i];
auto kbatch = K_BATCH;
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
const auto& gemm_arg = gemm_kernel_args_[i];
auto kbatch = K_BATCH;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) /
K_BATCH);
if(not_all_have_main_k_block_loop_same)
{
@@ -616,7 +631,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{})
{
auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
[[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
CheckArgument(arg, stream_config);
if(dev_gemm_args == nullptr)
@@ -698,17 +713,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
arg.gemm_kernel_args_[0].M,
arg.gemm_kernel_args_[0].K,
arg.gemm_kernel_args_[0].StrideA,
arg.K_BATCH);
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
arg.gemm_kernel_args_[0].K,
arg.gemm_kernel_args_[0].StrideA,
arg.K_BATCH);
all_have_kbatch_gt_one = arg.K_BATCH > 1;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) / arg.K_BATCH);
}
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
@@ -737,14 +751,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
throw std::runtime_error(err.str());
}
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2) /
kbatch);
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
if(not_all_have_main_k_block_loop_same)
@@ -853,8 +867,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
}
auto preprocess = [&]() {
// std::cout << "[preprocess] p_flags: " << p_flags
// << ", flag count: " << flag_count
// << ", bytes: " << flag_count * sizeof(uint32_t)
// << ", stream id: " << stream_config.stream_id_
// << std::endl;
hip_check_error(hipMemsetAsync(
p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_));
// TODO: For debug only!
hip_check_error(hipMemsetAsync(
dev_gemm_workspace, 2, acc_workspace_size_bytes, stream_config.stream_id_));
};
return launch_and_time_kernel_with_preprocess(
@@ -890,11 +912,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "The group count is not equal to the sum of skipped groups "
"and kernel args size!"
<< std::endl;
}
return false;
}
@@ -913,11 +936,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
arg.K_BATCH);
if(not group_arg_valid)
{
#if DEBUG_LOG
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
gemm_arg.Print();
#endif // DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
gemm_arg.Print();
}
}
supported = supported && group_arg_valid;
}
@@ -1043,6 +1067,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
size_t size_bytes =
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
flag_count * sizeof(uint32_t);
std::cout << "[GetWorkspaceSize]: "
<< "occ_grid_size: " << occ_grid_size << ", grid_size: " << grid_size
<< ", tiles_per_block: " << tiles_per_block << ", flag_count: " << flag_count
<< ", size_bytes: " << size_bytes << std::endl;
return size_bytes;
}
@@ -1531,6 +1531,8 @@ struct BlockToCTileMap_LinearKSplit
return false;
}
__host__ __device__ void AdvanceTileKIdx(index_t k_tiles) { K0_idx_ += k_tiles; }
///
/// @brief Determines whether the current workgroup processed the first tile in the K dimension
///