Commit 925e6ac3 authored by root

clean

parent 8f5690c4
@@ -88,7 +88,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
using GridwiseGemm = GridwiseGemmMultiABD_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
@@ -165,12 +165,6 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
};
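            // A hedged annotation, not part of this commit: the KBatch > 1
            // branch deleted above enforced the split-K invariant that C must
            // be zeroed before launch, because partial K-slices are combined
            // with InMemoryDataOperationEnum::AtomicAdd. A minimal sketch of
            // that pre-launch clear (not called here, shown for reference):
            const auto zero_c_for_splitk = [&](hipStream_t s) {
                // stale C contents would otherwise be folded into the
                // atomically accumulated result
                return hipMemsetAsync(
                    arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), s);
            };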
@@ -184,146 +178,16 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
@@ -334,8 +198,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -362,8 +225,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -377,8 +239,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -392,8 +253,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -421,8 +281,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -434,37 +293,9 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
}
}
}
}
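        // Annotation (a free-standing sketch, not CK code): the if/else
        // ladder above promotes a runtime tail count into a compile-time
        // template argument, so each instantiated kernel can fully unroll its
        // epilogue. The idiom in isolation, with illustrative names:
        //
        //     enum class TailSketch { One, Two, Full };
        //
        //     template <TailSketch T>
        //     __global__ void gemm_tail_kernel() {} // epilogue unrolled for T
        //
        //     void dispatch_tail(TailSketch tail, hipStream_t stream)
        //     {
        //         // one branch per runtime value; inside each kernel the
        //         // tail count is a compile-time constant
        //         if(tail == TailSketch::One)
        //             gemm_tail_kernel<TailSketch::One><<<dim3(1), dim3(64), 0, stream>>>();
        //         else if(tail == TailSketch::Two)
        //             gemm_tail_kernel<TailSketch::Two><<<dim3(1), dim3(64), 0, stream>>>();
        //         else
        //             gemm_tail_kernel<TailSketch::Full><<<dim3(1), dim3(64), 0, stream>>>();
        //     }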
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
@@ -486,35 +317,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
Run(kernel);
}
}
}
else
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -538,34 +341,18 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
@@ -132,7 +132,7 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_xdl_cshuffle_v3
struct GridwiseGemmMultiABD_xdl_cshuffle_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -670,94 +670,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
const CElementwiseOperation c_element_op;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead;
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
};
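    // A worked example for the offsets above, with illustrative numbers (this
    // helper is a sketch, not used by the kernel): splitting K = 256 into
    // KBatch = 4 slices gives KRead = 64, so the workgroup at blockIdx.z == 2
    // starts reading A at element 2 * 64 = 128 when A is row-major (K is the
    // contiguous dimension) and at 2 * 64 * M when A is column-major (one K
    // step costs M elements). The last slice keeps K - KRead * (KBatch - 1)
    // elements of K, so a K that does not divide evenly is still fully reduced.
    __device__ static index_t
    ASplitOffsetSketch(index_t z, index_t k_read, index_t m, bool a_row_major)
    {
        return a_row_major ? z * k_read : z * k_read * m;
    }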
#if 0
struct SplitKBatchOffsetMultiABD
{
__device__ SplitKBatchOffsetMultiABD(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
Argument& karg)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout_>)
{
as_k_split_offset[i] = blockIdx.z * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout_>)
{
as_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideAs[i];
}
p_as_grid_(i) = p_as_grid[i] + as_k_split_offset[i];
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout_>)
{
bs_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideBs[i];
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout_>)
{
bs_k_split_offset[i] = blockIdx.z * karg.KRead;
}
p_bs_grid_(i) = p_bs_grid[i] + bs_k_split_offset[i];
});
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
}
AsGridPointer p_as_grid_;
BsGridPointer p_bs_grid_;
std::array<index_t, NumATensor> as_k_split_offset;
std::array<index_t, NumBTensor> bs_k_split_offset;
};
#endif
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
@@ -1322,28 +1234,29 @@ struct GridwiseGemm_xdl_cshuffle_v3
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd,
typename AsGridDesc_AK0_M_K1,
typename BsGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
const Problem& problem,
const AsGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
// std::array<index_t, NumATensor> StrideAs = {problem.StrideA};
// std::array<index_t, NumBTensor> StrideBs = {problem.StrideB};
// AsGridPointer p_as_grid;
// BsGridPointer p_bs_grid;
// DsGridPointer p_ds_grid;
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
// problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
// problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
#if 0
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
@@ -1358,21 +1271,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
#if 0
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_m_n(j) = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs[j]);
});
#endif
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
// const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
// const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_bs_grid[I0], b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
#endif
const auto as_grid_buf = generate_tuple(
[&](auto i) {
@@ -1394,12 +1296,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
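        // Annotation (standard-C++ sketch; CK's Tuple/Number machinery
        // differs): generate_tuple acts as a compile-time "map", calling the
        // lambda once per tensor index, each index a distinct compile-time
        // constant, and packing the possibly heterogeneous results into one
        // tuple:
        //
        //     template <typename F, std::size_t... Is>
        //     auto generate_tuple_sketch(F f, std::index_sequence<Is...>)
        //     {
        //         return std::make_tuple(
        //             f(std::integral_constant<std::size_t, Is>{})...);
        //     }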
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
// const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -1431,38 +1334,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
#if 0
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_as_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
@@ -1471,7 +1342,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock,
AsDataType,
Tuple<LDSTypeA>,
decltype(as_grid_desc_ak0_m_ak1),
AsGridDesc_AK0_M_K1,
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
@@ -1491,40 +1362,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
tie(a_block_desc_ak0_m_ak1),
make_tuple(make_multi_index(0, 0, 0)),
a_element_op};
#endif
#if 0
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_bs_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
@@ -1533,7 +1371,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock,
BsDataType,
Tuple<LDSTypeB>,
decltype(bs_grid_desc_bk0_n_bk1),
BsGridDesc_BK0_N_K1,
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
@@ -1554,8 +1392,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(0, 0, 0)),
b_element_op};
#endif
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
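        // math::integer_least_multiple rounds the A tile's LDS footprint up
        // to max_lds_align; the arithmetic is the usual round-up, checked
        // here with illustrative values (a sketch, not part of this commit):
        constexpr auto round_up_sketch = [](index_t n, index_t align) {
            return (n + align - 1) / align * align;
        };
        static_assert(round_up_sketch(1000, 128) == 1024, "e.g. 1000 -> 8 * 128");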
@@ -1709,33 +1545,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
#if 0
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
#else
using EDataType = CDataType;
// tuple of reference to C/Ds tensor descriptors
@@ -1804,8 +1613,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
c_element_op};
#endif
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -1820,20 +1627,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
1>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
#if 0
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
#else
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
@@ -1844,7 +1638,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
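        // The two curves above advance in lockstep: each access shuffles one
        // tile of (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) x
        // (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl), so the access
        // count is the block tile volume divided by the per-access slice
        // volume; a sketch with illustrative numbers:
        constexpr auto sfc_num_access_sketch =
            [](index_t m_blk, index_t n_blk, index_t m_slice, index_t n_slice) {
                return (m_blk * n_blk) / (m_slice * n_slice);
            };
        static_assert(sfc_num_access_sketch(128, 128, 32, 64) == 8,
                      "e.g. 8 shuffle rounds per block tile");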
#endif
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
@@ -1860,23 +1653,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to read from LDS
block_sync_lds();
#if 0
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
#else
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
@@ -1901,30 +1677,85 @@ struct GridwiseGemm_xdl_cshuffle_v3
I0,
cde_lds_and_global_step);
}
#endif
});
}
}
#if 1
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
const Problem& problem,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_as_grid,
p_bs_grid,
p_ds_grid,
p_c_grid,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
a_element_op,
b_element_op,
c_element_op);
}
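    // The overload above is a thin adapter: it rebuilds every grid descriptor
    // from the scalar fields of Problem on the device and forwards to the
    // descriptor-taking Run, so only the compact Problem crosses the kernel
    // argument boundary. A stripped-down sketch of the same two-layer shape
    // (illustrative names, not used by the kernels):
    struct ProblemSketch
    {
        index_t M, N, K;
    };

    template <typename Desc>
    __device__ static void RunImplSketch(const Desc& /*desc*/)
    {
        // the real Run consumes the precomputed descriptors here
    }

    __device__ static void RunSketch(const ProblemSketch& p)
    {
        // descriptor construction is cheap integer arithmetic, so it is
        // recomputed per workgroup instead of being passed from the host
        const auto desc = p.M * p.N; // stand-in for MakeXGridDescriptor_* calls
        RunImplSketch(desc);
    }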
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd,
typename AsGridDesc_AK0_M_K1,
typename BsGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap>
__device__ static void Run_2Lds(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem,
const AsGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
const BsGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
// problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
// problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
#if 0
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
@@ -1943,10 +1774,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
// const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
// const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
#endif
const auto as_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1967,13 +1798,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -2004,38 +1833,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
#if 0
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_as_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
@@ -2044,7 +1841,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock,
AsDataType,
Tuple<LDSTypeA>,
decltype(as_grid_desc_ak0_m_ak1),
AsGridDesc_AK0_M_K1, // decltype(as_grid_desc_ak0_m_ak1),
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
@@ -2065,40 +1862,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(0, 0, 0)),
a_element_op};
#endif
#if 0
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#else
const auto idx_bs_block_begin =
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
......@@ -2107,7 +1870,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
ThisThreadBlock,
BsDataType,
Tuple<LDSTypeB>,
decltype(bs_grid_desc_bk0_n_bk1),
BsGridDesc_BK0_N_K1, // decltype(bs_grid_desc_bk0_n_bk1),
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
@@ -2127,7 +1890,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
tie(b_block_desc_bk0_n_bk1),
make_tuple(make_multi_index(0, 0, 0)),
b_element_op};
#endif
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -2292,33 +2054,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
#if 0
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
#else
using EDataType = CDataType;
// tuple of reference to C/Ds tensor descriptors
@@ -2387,8 +2122,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
c_element_op};
#endif
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -2415,7 +2148,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
#if 1
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
@@ -2424,7 +2156,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
#endif
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
@@ -2440,23 +2171,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to read from LDS
block_sync_lds();
#if 0
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
#else
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
@@ -2481,11 +2195,61 @@ struct GridwiseGemm_xdl_cshuffle_v3
I0,
cde_lds_and_global_step);
}
#endif
});
}
}
#endif
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_as_grid,
p_bs_grid,
p_ds_grid,
p_c_grid,
p_shared_0,
p_shared_1,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
a_element_op,
b_element_op,
c_element_op);
}
};
} // namespace ck