Commit 8713ade3 authored by mtgu0705

Enable splitK

parent 7a17ead7
......@@ -45,6 +45,17 @@ using DeviceGemmV2Instance =
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 8,
// 128, Scale_Block_N, Scale_Block_K,
// 16, 128,
// KPerBlock, 8, 32,
// 16, 16,
// 1, 4,
// S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 32, 32, 0,
// 1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>;
// clang-format on
......@@ -273,6 +284,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
StrideA,
StrideB,
StrideC,
Scale_Stride_BN,
static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()),
KBatch,
a_element_op,
......
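For a densely packed B-scale tensor, the new `Scale_Stride_BN` argument can be derived from the padded K-block count, which matches the stride the kernel previously hard-coded in its descriptor. A minimal sketch, assuming a packed layout; `packed_scale_stride_bn` is a hypothetical helper, not part of this change:

```cpp
#include <cstdint>

// Equivalent to ck::math::integer_divide_ceil(K, ScaleBlockK), which the old
// code used as the fixed row stride of the B-scale descriptor.
inline std::int64_t packed_scale_stride_bn(std::int64_t K, std::int64_t ScaleBlockK)
{
    return (K + ScaleBlockK - 1) / ScaleBlockK;
}

// e.g. K = 4096, ScaleBlockK = 128  ->  Scale_Stride_BN = 32
```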
......@@ -100,6 +100,7 @@ struct DeviceGemmV2BScale : public BaseOperator
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideScaleB,
const void* p_b_scale,
ck::index_t KSplit,
AElementwiseOperation a_element_op,
......
......@@ -663,6 +663,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideScaleB,
const BScaleDataType* p_b_scale,
index_t KBatch,
AElementwiseOperation a_element_op,
......@@ -678,6 +679,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
StrideA,
StrideB,
StrideC,
StrideScaleB,
p_b_scale,
KBatch,
a_element_op,
......@@ -697,6 +699,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideScaleB,
const void* p_b_scale,
index_t KBatch,
AElementwiseOperation a_element_op,
......@@ -712,6 +715,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
StrideA,
StrideB,
StrideC,
StrideScaleB,
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
......
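The device-op overloads above simply thread the new stride through to the gridwise `Argument`. A hedged sketch of where `StrideScaleB` now sits in the argument order; `gemm` and `MakeArgument` are stand-in names for whatever entry point the caller uses, and parameters hidden by the truncated hunks are elided here as well:

```cpp
auto argument = gemm.MakeArgument(/* ..., */
                                  StrideA,
                                  StrideB,
                                  StrideC,
                                  StrideScaleB, // new: B-scale leading dimension
                                  p_b_scale,
                                  KBatch,
                                  a_element_op /* , ... */);
```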
......@@ -37,18 +37,16 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
// GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.scale_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
// p_shared,
// karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, karg.p_b_scale_grid, p_shared, karg);
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,
karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -72,24 +70,17 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
// GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
// p_shared_0,
// p_shared_1,
// karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_b_scale_grid,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared_0,
p_shared_1,
karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
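Both kernel entry points now compute a `SplitKBatchOffset` and shift all four base pointers, including the B-scale pointer, before running the GEMM body. A simplified sketch of the dispatch pattern; `run_gemm_slice` and the `Offsets` template parameter are stand-ins, not the library's actual names:

```cpp
// Each blockIdx.z works on one K-slice of A, B and the B-scale tensor, and
// accumulates (or reduces) into its region of C.
template <typename Karg, typename Offsets>
__device__ void run_splitk_slice(const Karg& karg)
{
    const Offsets off(karg); // offsets derived from blockIdx.z, KRead, strides

    run_gemm_slice(karg.p_a_grid + off.a_k_split_offset,
                   karg.p_b_grid + off.b_k_split_offset,
                   karg.p_c_grid + off.c_reduce_offset,
                   karg.p_b_scale_grid + off.scale_k_split_offset,
                   karg);
}
```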
......@@ -533,6 +524,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
N{N_},
......@@ -540,6 +532,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
......@@ -561,6 +554,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
......@@ -577,6 +571,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
index_t NPadded;
......@@ -600,13 +595,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t StrideScaleB_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
......@@ -670,15 +666,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
// // Calculate B scale offset
// if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// {
// scale_k_split_offset = blockIdx.z * (karg.K / 64) * karg.StrideB;
// }
// else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
// {
// scale_k_split_offset = blockIdx.z * (karg.K / 64) * karg.N;
// }
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
......@@ -701,7 +697,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t a_k_split_offset;
index_t b_k_split_offset;
// index_t scale_k_split_offset; // New member for scale matrix offset
index_t scale_k_split_offset; // New member for scale matrix offset
index_t c_reduce_offset;
};
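Each split-K batch reads `KRead` elements along K, which spans `KRead / ScaleBlockK` scale blocks, so the scale pointer advances by that many columns per `blockIdx.z` (times the row stride when B is row-major). A worked example with illustrative numbers only; none of these values come from the commit, and `KRead` is shown as a plain `K / KBatch` while the kernel actually computes it with padding:

```cpp
constexpr int K = 4096, KBatch = 4, ScaleBlockK = 128;
constexpr int KRead = K / KBatch;                     // 1024 K-elements per batch
constexpr int blocks_per_batch = KRead / ScaleBlockK; // 8 scale blocks skipped
// column-major B: scale_k_split_offset = blockIdx.z * blocks_per_batch
// row-major B:    scale_k_split_offset = blockIdx.z * blocks_per_batch * StrideB
```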
......@@ -1273,6 +1269,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename BScaleGridDesc_BN_AK,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
......@@ -1285,6 +1282,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
const Problem& problem,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
{
......@@ -1295,12 +1293,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// B Scale grid and buffer
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
// B Scale buffer
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
......@@ -1703,8 +1696,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_scale_grid_desc_bn_ak),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
......@@ -1716,11 +1716,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
problem,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b_scale_grid_desc_bn_ak,
c_grid_desc_mblock_mperblock_nblock_nperblock);
}
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename BScaleGridDesc_BN_AK,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
......@@ -1734,6 +1736,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
const Problem& problem,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
{
......@@ -1744,12 +1747,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// B Scale grid and buffer
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
// B Scale buffer
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
......@@ -2164,8 +2162,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_scale_grid_desc_bn_ak),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
......@@ -2178,6 +2182,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
problem,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b_scale_grid_desc_bn_ak,
c_grid_desc_mblock_mperblock_nblock_nperblock);
}
};
......
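Net effect in the gridwise kernel: the B-scale descriptor is built once in the wrappers from `problem.StrideScaleB` instead of being recreated inside `Run`/`Run_2Lds` with a hard-coded packed stride, so non-packed (for example padded) scale layouts become describable. A sketch of the addressing the descriptor encodes; `b_scale_offset` is a hypothetical helper:

```cpp
#include <cstdint>

// shape  = (ceil(N / ScaleBlockN), ceil(K / ScaleBlockK))
// stride = (StrideScaleB, 1)
inline std::int64_t b_scale_offset(std::int64_t n_blk, std::int64_t k_blk,
                                   std::int64_t StrideScaleB)
{
    return n_blk * StrideScaleB + k_blk; // row: N-block, column: K-block
}
```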
......@@ -301,6 +301,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
StrideA,
StrideB,
StrideC,
Scale_Stride_BN,
static_cast<BScaleDataType*>(b1_device_buf.GetDeviceBuffer()),
kbatch_curr,
a_element_op,
......