Commit 925e6ac3 authored by root's avatar root
Browse files

clean

parent 8f5690c4
...@@ -88,7 +88,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou ...@@ -88,7 +88,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>; using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< using GridwiseGemm = GridwiseGemmMultiABD_xdl_cshuffle_v3<
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -165,12 +165,6 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou ...@@ -165,12 +165,6 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}; };
...@@ -184,388 +178,181 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou ...@@ -184,388 +178,181 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
#if 0
if(arg.KBatch > 1) const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy,
TailNumber::One>;
Run(kernel); Run(kernel);
} }
else else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
#endif
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy,
TailNumber::Full>;
Run(kernel); Run(kernel);
} }
}
// Tail number could be One to Seven if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
#if 0
if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Full>; TailNumber::Two>;
Run(kernel); Run(kernel);
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
} }
else
#endif if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::One>; TailNumber::Three>;
Run(kernel); Run(kernel);
} }
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == }
TailNumber::Full)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Full>; TailNumber::Four>;
Run(kernel); Run(kernel);
} }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
} }
}
// Tail number could be Odd or Even if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
#if 0
if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Five>;
Run(kernel); Run(kernel);
} }
} }
else
#endif if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; TailNumber::Six>;
Run(kernel); Run(kernel);
} }
else }
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Seven>;
Run(kernel); Run(kernel);
} }
} }
} }
} // Tail number could be Odd or Even
else else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
// Tail number always 1 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) {
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{ {
#if 0 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
if(arg.KBatch > 1)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy,
TailNumber::Odd>;
Run(kernel); Run(kernel);
} }
else else
#endif
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy,
TailNumber::Even>;
Run(kernel); Run(kernel);
} }
} }
} }
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
return ave_time; return ave_time;
} }
......
...@@ -132,7 +132,7 @@ template <typename ALayout, ...@@ -132,7 +132,7 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA> typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_xdl_cshuffle_v3 struct GridwiseGemmMultiABD_xdl_cshuffle_v3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -670,94 +670,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -670,94 +670,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
const CElementwiseOperation c_element_op; const CElementwiseOperation c_element_op;
}; };
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead;
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
};
#if 0
struct SplitKBatchOffsetMultiABD
{
__device__ SplitKBatchOffsetMultiABD(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
Argument& karg)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout_>)
{
as_k_split_offset[i] = blockIdx.z * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout_>)
{
as_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideAs[i];
}
p_as_grid_(i) = p_as_grid[i] + as_k_split_offset[i];
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout_>)
{
bs_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideBs[i];
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout_>)
{
bs_k_split_offset[i] = blockIdx.z * karg.KRead;
}
p_bs_grid_(i) = p_bs_grid[i] + bs_k_split_offset[i];
});
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
}
AsGridPointer p_as_grid_;
BsGridPointer p_bs_grid_;
std::array<index_t, NumATensor> as_k_split_offset;
std::array<index_t, NumBTensor> bs_k_split_offset;
};
#endif
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
...@@ -1322,28 +1234,29 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1322,28 +1234,29 @@ struct GridwiseGemm_xdl_cshuffle_v3
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd> TailNumber TailNum = TailNumber::Odd,
typename AsGridDesc_AK0_M_K1,
typename BsGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap>
__device__ static void Run(AsGridPointer& p_as_grid, __device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid, BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid, DsGridPointer& p_ds_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
void* p_shared, void* p_shared,
const Problem& problem, const AsGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op)
{ {
// std::array<index_t, NumATensor> StrideAs = {problem.StrideA}; #if 0
// std::array<index_t, NumBTensor> StrideBs = {problem.StrideB};
// AsGridPointer p_as_grid;
// BsGridPointer p_bs_grid;
// DsGridPointer p_ds_grid;
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
// problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
// problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
...@@ -1358,21 +1271,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1358,21 +1271,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
#if 0
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_m_n(j) = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs[j]);
});
#endif
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock); ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
#endif
// const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
// const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_bs_grid[I0], b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto as_grid_buf = generate_tuple( const auto as_grid_buf = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -1394,12 +1296,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1394,12 +1296,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
const auto ds_grid_buf = generate_tuple( const auto ds_grid_buf = generate_tuple(
[&](auto i) { [&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
// divide block work by [M, N] // divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; // const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -1431,38 +1334,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1431,38 +1334,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
#if 0
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_as_block_begin = const auto idx_as_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{}); Number<NumATensor>{});
...@@ -1471,7 +1342,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1471,7 +1342,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock, ThisThreadBlock,
AsDataType, AsDataType,
Tuple<LDSTypeA>, Tuple<LDSTypeA>,
decltype(as_grid_desc_ak0_m_ak1), AsGridDesc_AK0_M_K1,
decltype(tie(a_block_desc_ak0_m_ak1)), decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation, AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
...@@ -1491,40 +1362,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1491,40 +1362,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
tie(a_block_desc_ak0_m_ak1), tie(a_block_desc_ak0_m_ak1),
make_tuple(make_multi_index(0, 0, 0)), make_tuple(make_multi_index(0, 0, 0)),
a_element_op}; a_element_op};
#endif
#if 0
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_bs_block_begin = const auto idx_bs_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{}); Number<NumBTensor>{});
...@@ -1533,7 +1371,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1533,7 +1371,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock, ThisThreadBlock,
BsDataType, BsDataType,
Tuple<LDSTypeB>, Tuple<LDSTypeB>,
decltype(bs_grid_desc_bk0_n_bk1), BsGridDesc_BK0_N_K1,
decltype(tie(b_block_desc_bk0_n_bk1)), decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation, BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
...@@ -1554,8 +1392,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1554,8 +1392,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(0, 0, 0)), make_tuple(make_multi_index(0, 0, 0)),
b_element_op}; b_element_op};
#endif
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
...@@ -1709,33 +1545,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1709,33 +1545,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
n_thread_data_on_block_idx[I2]), n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}}; ck::tensor_operation::element_wise::PassThrough{}};
#if 0
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
#else
using EDataType = CDataType; using EDataType = CDataType;
// tuple of reference to C/Ds tensor descriptors // tuple of reference to C/Ds tensor descriptors
...@@ -1804,8 +1613,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1804,8 +1613,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
c_element_op}; c_element_op};
#endif
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>, SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...@@ -1820,20 +1627,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1820,20 +1627,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
1>>{}; 1>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
#if 0
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
#else
// space filling curve for shuffled blockwise C/D/E // space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block = constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>, SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
...@@ -1844,7 +1638,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1844,7 +1638,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
#endif
static_for<0, num_access, 1>{}([&](auto access_id) { static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS // make sure it's safe to write to LDS
...@@ -1860,23 +1653,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1860,23 +1653,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to read from LDS // make sure it's safe to read from LDS
block_sync_lds(); block_sync_lds();
#if 0
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
#else
// each block copy its data from LDS to global // each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run( cde_block_copy_lds_and_global.Run(
c_ds_desc_refs, c_ds_desc_refs,
...@@ -1901,30 +1677,85 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1901,30 +1677,85 @@ struct GridwiseGemm_xdl_cshuffle_v3
I0, I0,
cde_lds_and_global_step); cde_lds_and_global_step);
} }
#endif
}); });
} }
} }
#if 1
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd> TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
const Problem& problem,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_as_grid,
p_bs_grid,
p_ds_grid,
p_c_grid,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
a_element_op,
b_element_op,
c_element_op);
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd,
typename AsGridDesc_AK0_M_K1,
typename BsGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap>
__device__ static void Run_2Lds(AsGridPointer& p_as_grid, __device__ static void Run_2Lds(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid, BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid, DsGridPointer& p_ds_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
void* p_shared_0, void* p_shared_0,
void* p_shared_1, void* p_shared_1,
const Problem& problem, const AsGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op)
{ {
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( #if 0
// problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
// problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
...@@ -1943,10 +1774,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1943,10 +1774,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock); ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
// const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( // divide block work by [M, N]
// p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
// const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( #endif
// p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto as_grid_buf = generate_tuple( const auto as_grid_buf = generate_tuple(
[&](auto i) { [&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1967,13 +1798,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1967,13 +1798,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
const auto ds_grid_buf = generate_tuple( const auto ds_grid_buf = generate_tuple(
[&](auto i) { [&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -2004,38 +1833,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2004,38 +1833,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
#if 0
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_as_block_begin = const auto idx_as_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{}); Number<NumATensor>{});
...@@ -2044,7 +1841,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2044,7 +1841,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock, ThisThreadBlock,
AsDataType, AsDataType,
Tuple<LDSTypeA>, Tuple<LDSTypeA>,
decltype(as_grid_desc_ak0_m_ak1), AsGridDesc_AK0_M_K1, // decltype(as_grid_desc_ak0_m_ak1),
decltype(tie(a_block_desc_ak0_m_ak1)), decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation, AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
...@@ -2065,40 +1862,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2065,40 +1862,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(0, 0, 0)), make_tuple(make_multi_index(0, 0, 0)),
a_element_op}; a_element_op};
#endif
#if 0
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_bs_block_begin = const auto idx_bs_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{}); Number<NumBTensor>{});
...@@ -2107,7 +1870,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2107,7 +1870,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock, ThisThreadBlock,
BsDataType, BsDataType,
Tuple<LDSTypeB>, Tuple<LDSTypeB>,
decltype(bs_grid_desc_bk0_n_bk1), BsGridDesc_BK0_N_K1, // decltype(bs_grid_desc_bk0_n_bk1),
decltype(tie(b_block_desc_bk0_n_bk1)), decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation, BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
...@@ -2127,7 +1890,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2127,7 +1890,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
tie(b_block_desc_bk0_n_bk1), tie(b_block_desc_bk0_n_bk1),
make_tuple(make_multi_index(0, 0, 0)), make_tuple(make_multi_index(0, 0, 0)),
b_element_op}; b_element_op};
#endif
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
...@@ -2292,33 +2054,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2292,33 +2054,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
n_thread_data_on_block_idx[I2]), n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}}; ck::tensor_operation::element_wise::PassThrough{}};
#if 0
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
#else
using EDataType = CDataType; using EDataType = CDataType;
// tuple of reference to C/Ds tensor descriptors // tuple of reference to C/Ds tensor descriptors
...@@ -2387,8 +2122,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2387,8 +2122,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
c_element_op}; c_element_op};
#endif
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>, SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...@@ -2415,7 +2148,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2415,7 +2148,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
#if 1
// space filling curve for shuffled blockwise C/D/E // space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block = constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>, SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
...@@ -2424,7 +2156,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2424,7 +2156,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
#endif
static_for<0, num_access, 1>{}([&](auto access_id) { static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS // make sure it's safe to write to LDS
...@@ -2440,23 +2171,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2440,23 +2171,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to read from LDS // make sure it's safe to read from LDS
block_sync_lds(); block_sync_lds();
#if 0
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
#else
// each block copy its data from LDS to global // each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run( cde_block_copy_lds_and_global.Run(
c_ds_desc_refs, c_ds_desc_refs,
...@@ -2481,11 +2195,61 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -2481,11 +2195,61 @@ struct GridwiseGemm_xdl_cshuffle_v3
I0, I0,
cde_lds_and_global_step); cde_lds_and_global_step);
} }
#endif
}); });
} }
} }
#endif
template <bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum = TailNumber::Odd>
/// Problem-based entry point for the double-LDS-buffer (ping-pong) pipeline.
///
/// Derives all tensor descriptors and the block-to-C-tile mapping from the
/// runtime \p problem parameters, then dispatches to the descriptor-taking
/// Run_2Lds overload with the two LDS scratch buffers.
///
/// \param p_as_grid    tuple of pointers to the A input tensors
/// \param p_bs_grid    tuple of pointers to the B input tensors
/// \param p_ds_grid    tuple of pointers to the D (bias/residual) tensors
/// \param p_c_grid     pointer to the C output tensor
/// \param p_shared_0 / p_shared_1 the two LDS buffers used alternately
/// \param problem      runtime GEMM sizes, paddings and strides
/// \param a_element_op / b_element_op / c_element_op elementwise functors
__device__ static void Run_2Lds(AsGridPointer& p_as_grid,
                                BsGridPointer& p_bs_grid,
                                DsGridPointer& p_ds_grid,
                                CDataType* p_c_grid,
                                void* p_shared_0,
                                void* p_shared_1,
                                const Problem& problem,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op)
{
    // Source-side descriptors for the A and B tensor tuples.
    const auto descs_a_k0_m_k1 = MakeAsGridDescriptor_AK0_M_AK1(
        problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
    const auto descs_b_k0_n_k1 = MakeBsGridDescriptor_BK0_N_BK1(
        problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);

    // Output-side descriptors, re-tiled per workgroup as
    // [MBlock, MPerBlock, NBlock, NPerBlock].
    const auto desc_c_blocked = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        MakeCGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC),
        problem.MBlock,
        problem.NBlock);
    const auto descs_d_blocked = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        MakeDsGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs),
        problem.MBlock,
        problem.NBlock);

    // Map each workgroup id onto its [M, N] output tile.
    const auto tile_map = Block2CTileMap{problem.M, problem.N, 4};

    Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(p_as_grid,
                                                                    p_bs_grid,
                                                                    p_ds_grid,
                                                                    p_c_grid,
                                                                    p_shared_0,
                                                                    p_shared_1,
                                                                    descs_a_k0_m_k1,
                                                                    descs_b_k0_n_k1,
                                                                    descs_d_blocked,
                                                                    desc_c_blocked,
                                                                    tile_map,
                                                                    a_element_op,
                                                                    b_element_op,
                                                                    c_element_op);
}
}; };
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment