Commit 9bd44685 authored by aska-0096

Correctness OK, waiting for optimization

parent 73959956
@@ -22,14 +22,21 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma
+// using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmWmma
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NWMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
+// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
// clang-format on
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWmma|NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 8>, 8>;
using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
...
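A hedged usage sketch of how such an instance is driven (this mirrors the pattern of CK's example drivers; the device pointer and stride names and the time_kernel flag are assumptions, not part of this commit):

auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(p_a_device, p_b_device, p_c_device,
                                  M, N, K, StrideA, StrideB, StrideC,
                                  AElementOp{}, BElementOp{}, CElementOp{});
if(gemm.IsSupportedArgument(argument))
    invoker.Run(argument, StreamConfig{nullptr, time_kernel});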
@@ -32,8 +32,11 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
case 0: break;
case 1:
-    ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k.begin(), a_m_k.end());
-    ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n.begin(), b_k_n.end());
+    // CONFIRMED
+    // ck::utils::FillMNID<ADataType>{}(a_m_k.begin(), a_m_k.end());
+    // ck::utils::FillMNID<BDataType>{}(b_k_n.begin(), b_k_n.end());
+    ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(), a_m_k.end());
+    ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(), b_k_n.end());
    break;
default:
    ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
...
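Note on init method 1: integers in [-5, 5] keep every product and partial sum exactly representable in fp16, so a mismatch against ReferenceGemm points at an indexing bug rather than rounding noise. A standalone sketch of the effect (std-only; FillUniformDistributionIntegerValue is CK's own utility, this is not its implementation):

#include <random>
#include <vector>

template <typename T>
void fill_integer_values(std::vector<T>& buf, int lo = -5, int hi = 5)
{
    std::mt19937 gen{11939};
    std::uniform_int_distribution<int> dist{lo, hi};
    for(auto& x : buf)
        x = static_cast<T>(dist(gen)); // small integers are exact in fp16/fp32
}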
@@ -137,7 +137,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
              "wrong!");
}
-// Thread level, register descriptor.
+// Thread level, register descriptor. Vector-write
__host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
@@ -168,6 +168,51 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// Thread level, register descriptor. Per-pixel write
__host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave |NThreadPerSubGroup
make_tuple(Number<MRepeat>{}, I1, MSubGroup, MAccVgprs, Number<NRepeat>{}, I1, NThreadPerSubGroup));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
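// Worked example of the unmerge above, using the example instance's numbers
// (assumed here: MWaves = 2, MPerWMMA = 16): a global row index m = 77 splits as
//   mblockxrepeat = m / (MWaves * MPerWMMA) = 77 / 32 = 2
//   mwave         = (m / MPerWMMA) % MWaves = 4 % 2   = 0
//   mperwmma      = m % MPerWMMA            = 77 % 16 = 13
// and 2 * 32 + 0 * 16 + 13 reassembles m = 77; the N axis splits the same way
// with NWaves and NPerWMMA.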
// Provides all dimension sizes
__host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
{
return transform_tensor_descriptor(
@@ -205,8 +250,28 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
    b_thread_desc_.GetElementSpaceSize());
-constexpr auto RepeatDiff = MRepeat - NRepeat;
+// constexpr auto RepeatDiff = MRepeat - NRepeat;
// debug_hexprinter(0xffffffff, a_thread_buf[Number<a_thread_desc_.CalculateOffset( make_tuple(0, 0, 0, 0,0))>{}], "Avalue ");
/* First local prefetch, move out of blockwise operation.
static_for<0, NRepeat, 1>{}([&](auto iN){
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto iN){
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_thread_buf);
});
*/
/*
static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
    // Cut the MRepeat x NRepeat rectangle into squares, assume MRepeat > NRepeat
    static_for<0, RepeatDiff, 1>{}([&](auto iCut){
@@ -297,16 +362,77 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        b_thread_buf);
    });
});
*/
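// The fenced-out scheme above prefetched B and cut the MRepeat x NRepeat
// rectangle into squares to reduce redundant LDS reads; the loop below is the
// simpler correctness-first ordering: for each WmmaK slab of K, A is read once
// per m0, B is re-read once per (m0, n0), and one WMMA is issued per pair.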
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k*WmmaK/A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k*WmmaK/B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(i/A_K1, 0, 0, 0, i%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(i/B_K1, 0, 0, 0, i%B_K1))>{}];
});
using wmma_input_type =
typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'A';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, a_thread_buf[Number<i>{}], info);
// });
// static_for<0, 16, 1>{}([&](auto i){
// char info[4];
// info[0] = 'B';
// info[1] = i/10 + '0';
// info[2] = i%10 + '0';
// info[3] = '\0';
// debug_hexprinter(0xffffffff, b_thread_buf[Number<i>{}], info);
// });
}
protected:
// A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ =
-    make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
+    make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/A_K1>{}, I1, I1, I1, Number<A_K1>{}));
// B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ =
-    make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/B_K1>{}, Number<MRepeat>{}, I1, I1, Number<B_K1>{}));
+    make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/B_K1>{}, I1, I1, I1, Number<B_K1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
...
@@ -20,13 +20,14 @@ namespace ck {
namespace tensor_operation {
namespace device {
-template <typename ADataType,
+template <typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
-          typename ALayout,
-          typename BLayout,
-          typename CLayout,
+          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
@@ -54,12 +55,14 @@ template <typename ADataType,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
-ck::index_t CThreadTransferSrcDstVectorDim,
-ck::index_t CThreadTransferDstScalarPerVector,
+index_t CShuffleMRepeatPerShuffle,
+index_t CShuffleNRepeatPerShuffle,
+typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
-struct DeviceGemmWmma : public DeviceGemm<ALayout,
+struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
@@ -200,6 +203,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
+CShuffleDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
@@ -232,9 +236,10 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
-Sequence<0, 1, 2, 3, 4, 5, 6>, // CThreadTransferSrcDstAccessOrder,
-CThreadTransferSrcDstVectorDim,
-CThreadTransferDstScalarPerVector,
+CShuffleMRepeatPerShuffle,
+CShuffleNRepeatPerShuffle,
+CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+CShuffleBlockTransferScalarPerVector_NPerBlock,
NumPrefetch,
LoopSched,
PipelineVer>;
@@ -262,7 +267,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
-c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_{},
+c_grid_desc_mblock_mperblock_nblock_nperblock{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
@@ -270,9 +275,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
-a_grid_desc_k0_m_k1_ = DeviceGemmWmma::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
-b_grid_desc_k0_n_k1_ = DeviceGemmWmma::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
-c_grid_desc_m_n_ = DeviceGemmWmma::MakeCGridDescriptor_M_N(M, N, StrideC);
+a_grid_desc_k0_m_k1_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
+b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
    GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
@@ -282,8 +287,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
    c_grid_desc_m_n_,
    block_2_ctile_map_))
{
-    c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_ =
-        GridwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n_);
+    c_grid_desc_mblock_mperblock_nblock_nperblock =
+        GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
}
}
@@ -294,8 +299,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
-typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
-    c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_;
+typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+    c_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
@@ -307,7 +312,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
// Invoker
struct Invoker : public BaseInvoker
{
-using Argument = DeviceGemmWmma::Argument;
+using Argument = DeviceGemmWmma_CShuffle::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
@@ -350,9 +355,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
-remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
-remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
-remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
+remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -369,7 +374,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
-arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
+arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
@@ -381,9 +386,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
-remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
-remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
-remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
+remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
+remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -400,7 +405,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
-arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
+arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
@@ -428,8 +433,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
{
if(ck::get_device_name() == "gfx1100")
{
-    if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
-                   is_same_v<AccDataType, int32_t>))
+    if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
    {
        return false;
    }
@@ -530,7 +534,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
    {PipelineVersion::v2, "v2"}};
// clang-format off
-str << "DeviceGemmWmma"
+str << "DeviceGemmWmma_CShuffle"
    << "<"
    << BlockSize << ", "
    << MPerBlock << ", "
...
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
-#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
-typename CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs,
+typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
@@ -38,8 +38,9 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
-    c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
+const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock,
+// const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
+// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
@@ -55,7 +56,7 @@ __global__ void
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
-c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
+c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
@@ -66,7 +67,7 @@ __global__ void
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
-ignore = c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs;
+ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
@@ -78,6 +79,7 @@ template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
+typename FloatCShuffle,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
@@ -110,9 +112,10 @@ template <
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
-typename CThreadTransferSrcDstAccessOrder,
-index_t CThreadTransferSrcDstVectorDim,
-index_t CThreadTransferDstScalarPerVector,
+index_t CShuffleMRepeatPerShuffle,
+index_t CShuffleNRepeatPerShuffle,
+typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
@@ -179,6 +182,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return b_block_desc_k0perblock_nperblock_k1;
}
__host__ __device__ static constexpr auto
// *Caution: "Repeat" here means the shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
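// Sizing sketch with the example instance's parameters (assumed here:
// MPerBlock = NPerBlock = 128, MRepeat = 4, NRepeat = 2, M/NPerWmma = 16,
// CShuffleM/NRepeatPerShuffle = 1):
//   MWave = 128 / (4 * 16) = 2, NWave = 128 / (2 * 16) = 4
//   shuffle tile = (1 * 2 * 16) x (1 * 4 * 16) = 32 x 64
// so each of the MRepeat * NRepeat = 8 shuffle rounds stages a 32 x 64 tile in
// LDS (8 KiB, assuming a 4-byte CShuffleDataType).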
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
@@ -248,6 +268,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
// Vector write
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
    const CGridDesc_M_N& c_grid_desc_m_n)
@@ -301,6 +322,79 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n);
}
// Per-pixel write
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_m_n);
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
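// Example of the split above (MPerBlock = NPerBlock = 128 assumed): a C
// coordinate (m, n) = (300, 451) maps to
//   (MBlock, mperblock, NBlock, nperblock) = (300 / 128, 300 % 128, 451 / 128, 451 % 128)
//                                          = (2, 44, 3, 67),
// i.e. block tile (2, 3), offset (44, 67) within the tile.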
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
    const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
@@ -308,10 +402,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
    c_grid_desc_m_n);
}
-using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs =
-    remove_cvref_t<decltype(
-        MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
-            CGridDesc_M_N{}))>;
+// using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup = remove_cvref_t<decltype(
+//     MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
+//         CGridDesc_M_N{}))>;
+using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
+    remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
    remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
@@ -323,8 +418,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
-const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs&
-    c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
+const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+    c_grid_desc_mblock_mperblock_nblock_nperblock,
+// const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
+// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
@@ -338,15 +435,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
    p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-    p_c_grid, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize());
+    p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
    block_work_idx,
-    make_tuple(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0),
-               c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4))))
+    make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+               c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{ return; }
// Store BlockId into SGPR
@@ -360,8 +457,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(),
-       (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))());
+// printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(),
+//        (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))());
// A matrix blockwise copy
auto a_blockwise_copy =
    ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
@@ -391,7 +488,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
a_block_desc_k0perblock_mperblock_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
-printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1());
+// printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1());
// printf("a_block_wise_copy: %s\n", std::string(type_name<decltype(a_blockwise_copy)>()).c_str());
// B matrix blockwise copy
@@ -477,21 +574,38 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    c_thread_buf,
    K0BlockMainLoop);
/*******************************************************************************/
#ifdef CK_EXPERIMENTAL_ARBITRARY_WRITEOUT
// write out C matrix, c shuffle not implemented
{
static_for<0, 16, 1>{}([&](auto i){
char info[4];
info[0] = 'C';
info[1] = i/10 + '0';
info[2] = i%10 + '0';
info[3] = '\0';
debug_hexprinter(0xffffffff, c_thread_buf[Number<i>{}], info);
});
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
    blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
-constexpr auto MWave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
-constexpr auto MSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
-constexpr auto Nwave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
-constexpr auto NThreadPerSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
-constexpr auto MAccVgprs = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
+// This API provides all the dimension sizes you need
+constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
+    blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
// printf("MWave = %d, MSubGroup = %d, NWave = %d, NThreadPerSubGroup = %d, MAccVgprs = %d\n", MWave, MSubGroup, NWave, NThreadPerSubGroup, MAccVgprs);
// Mapping
const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
// Checked
// debug_hexprinter(0xffffffff, m_thread_data_on_grid, "c_m");
// debug_hexprinter(0xffffffff, n_thread_data_on_grid, "c_n");
const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
    make_single_stage_tensor_adaptor(
@@ -501,25 +615,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
    make_single_stage_tensor_adaptor(
-        make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, NThreadPerSubGroup))),
+        make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
    make_multi_index(m_thread_data_on_grid));
debug_hexprinter(0x4, MRepeat, "mblockxrepeat");
debug_hexprinter(0x2, MWave, "mwave");
debug_hexprinter(0x2, MSubGroup, "msubgroup");
debug_hexprinter(0x8, MAccVgprs, "maccvgprs");
debug_hexprinter(0x4, NWave, "nwave");
const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
    make_multi_index(n_thread_data_on_grid));
// printf("write out dimension access order = (%d, %d, %d, %d, %d, %d, %d)\n", CThreadTransferSrcDstAccessOrder{}[Number<0>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<1>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<2>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<3>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<4>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<5>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<6>{}].value);
auto c_thread_copy =
    ThreadwiseTensorSliceTransfer_v1r3<
        /* typename SrcData */ FloatAcc,
        /* typename DstData */ FloatC,
        /* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
-        /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs),
+        /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup),
        /* typename ElementwiseOperation */ CElementwiseOperation,
-        // Thread register Mapping
+        // Thread register Mapping 0 1 2 4 5 6 3
        /* typename SliceLengths */ Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
        /* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder,
        /* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim,
@@ -528,14 +648,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        /* index_t DstScalarStrideInVector */ 1,
        /* bool DstResetCoordinateAfterRun */ true>
    {
-        /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
+        /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
        /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0],
            m_thread_data_on_grid_idx[I1],
            m_thread_data_on_grid_idx[I2],
+            m_thread_data_on_grid_idx[I3],
            n_thread_data_on_grid_idx[I0],
            n_thread_data_on_grid_idx[I1],
-            n_thread_data_on_grid_idx[I2],
-            m_thread_data_on_grid_idx[I3]),
+            n_thread_data_on_grid_idx[I2]),
        /* element_op */ c_element_op
    };
@@ -543,9 +663,193 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_register_beginning*/ make_tuple(I0, I0, I0, I0, I0, I0, I0),
/* c_local(register) */ c_thread_buf,
-/* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
+/* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
/* c_grid_buf */ c_grid_buf);
}
#endif
{
// write out to C, implement shuffle
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// This API provides all the dimension sizes you need
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
MWave, // MWave
MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
MAccVgprs)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
NWave, // NWave
NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
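// The two freeze transforms pin the MShRepeat/NShRepeat indices to 0, so every
// shuffle round reuses the same LDS tile; the unmerges split MPerShRepeat into
// (MRepeatPerShuffle, MWave, MSubGroup, MAccVgprs) and NPerShRepeat into
// (NRepeatPerShuffle, NWave, NThreadPerSubGroup), matching the VGPR layout of
// c_thread_desc_ so the per-thread copy below can run as a plain 7-d slice.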
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatCShuffle,
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
I1,
I1,
CShuffleNRepeatPerShuffle,
I1,
I1,
MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1, // DstScalarPerVector = 1: per-pixel write
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
make_multi_index(0,
m_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
0,
n_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
1,
1,
CShuffleNRepeatPerShuffle,
1,
1,
MAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// CONFIRMED
// printf("c_global_step = (%d, %d, %d, %d)\n",
// c_global_step[Number<0>{}],
// c_global_step[Number<1>{}],
// c_global_step[Number<2>{}],
// c_global_step[Number<3>{}]);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
// clang-format on
}
};
...
@@ -119,7 +119,29 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                            DimAccessOrder,
                                            remove_cv_t<decltype(dst_scalar_per_access)>>;
// printf("SpaceFillingCurve access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::access_lengths[Number<0>{}].value,
// SpaceFillingCurve::access_lengths[Number<1>{}].value,
// SpaceFillingCurve::access_lengths[Number<2>{}].value,
// SpaceFillingCurve::access_lengths[Number<3>{}].value,
// SpaceFillingCurve::access_lengths[Number<4>{}].value,
// SpaceFillingCurve::access_lengths[Number<5>{}].value,
// SpaceFillingCurve::access_lengths[Number<6>{}].value);
//
// // printf("SpaceFillingCurve dim_access_order = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::dim_access_order[Number<0>{}].value,
// SpaceFillingCurve::dim_access_order[Number<1>{}].value,
// SpaceFillingCurve::dim_access_order[Number<2>{}].value,
// SpaceFillingCurve::dim_access_order[Number<3>{}].value,
// SpaceFillingCurve::dim_access_order[Number<4>{}].value,
// SpaceFillingCurve::dim_access_order[Number<5>{}].value,
// SpaceFillingCurve::dim_access_order[Number<6>{}].value);
//
// // // printf("SpaceFillingCurve ordered_access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::ordered_access_lengths[Number<0>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<1>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<2>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<3>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<4>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<5>{}].value,
// SpaceFillingCurve::ordered_access_lengths[Number<6>{}].value);
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
              "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
...
@@ -136,7 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// debug_hexprinter(0xffffffff, src_offset, "src_coord_iteration");
SrcData v;
// apply element-wise operation
...
@@ -154,11 +176,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
// debug_hexprinter(0xffffffff, dst_coord_.GetOffset(), "dst_coord_iteration");
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
// printf("move forward = (%d, %d, %d, %d, %d, %d, %d)\n", forward_step[Number<0>{}], forward_step[Number<1>{}], forward_step[Number<2>{}], forward_step[Number<3>{}], forward_step[Number<4>{}], forward_step[Number<5>{}], forward_step[Number<6>{}]);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
...
@@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
// printf("global desc: %s\n", __PRETTY_FUNCTION__);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
@@ -128,12 +128,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
// printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]());
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]());
// make forward steps
const auto src_forward_steps = generate_tuple(
...
@@ -147,9 +147,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// printf("src_forward_steps: %d, %d, %d\n", (src_forward_steps.GetIndexDiff()[Number<0>{}])(),
//        (src_forward_steps.GetIndexDiff()[Number<1>{}])(),
//        (src_forward_steps.GetIndexDiff()[Number<2>{}])());
// make backward steps
const auto src_backward_steps = generate_tuple(
...
@@ -213,7 +210,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
// debug_hexprinter(0xffffffff, src_coord_.GetOffset());
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
SrcData src_v;
...
@@ -205,6 +205,7 @@ struct WmmaGemm
}
// WMMA output supporting C = A * B
// Vector Write
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
...
@@ -239,6 +240,40 @@ struct WmmaGemm
Sequence<5>{}));
}
// Per-Pixel write
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}));
}
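// Index-mapping sketch for the unmerge above (numbers illustrative, not taken
// from a particular wmma_instr): with num_subgroups = 2 and
// num_acc_vgprs_per_wave = 8, an MPerWMMA index m in [0, 16) splits as
//   MSubGroup = m / 8, MAccVgprs = m % 8,
// turning the 6-D (MBlockxRepeat, MWave, MPerWMMA, NBlockxRepeat, NWave,
// NPerWMMA) view into the 7-D per-pixel view named by this function.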
__device__ static constexpr index_t GetRegSizePerWmma()
{
return wmma_instr.num_acc_vgprs_per_wave;
...
@@ -73,10 +73,26 @@ constexpr auto type_name() {
return name;
}
// Accept int, float, and Number<> as input
template <typename T>
__host__ __device__
void debug_hexprinter(const uint32_t v_target, const T v_val, const char* info){
if constexpr(std::is_same_v<T, int> || std::is_same_v<T, float>)
{
const uint32_t v_dbg = *(reinterpret_cast<const uint32_t*>(&v_val));
if(v_dbg != v_target)
printf("%s@Thread: %d, Val: %08x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target);
}
}
else if constexpr(std::is_same_v<T, _Float16>)
{
const uint16_t v_dbg = *(reinterpret_cast<const uint16_t*>(&v_val));
if(v_dbg != v_target)
printf("%s@Thread: %d, Val: %04x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target);
}
else
{
const uint32_t v_dbg = *(reinterpret_cast<const uint32_t*>(&(v_val.value)));
if(v_dbg != v_target)
printf("%s@Thread: %d, Val: %08x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target);
}
}
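// Usage sketch (bit patterns illustrative): compare a value's bits against an
// expected constant and report only mismatching threads, e.g.
//   float acc = /* ... */;
//   debug_hexprinter(0x3f800000u, acc, "acc_after_wmma"); // expect 1.0f
//   _Float16 h = /* ... */;
//   debug_hexprinter(0x00003c00u, h, "half_lane");        // expect 1.0 (fp16)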
...
@@ -49,7 +49,7 @@ check_err(const std::vector<T>& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 16384)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
...
@@ -59,6 +59,7 @@ check_err(const std::vector<T>& out,
}
if(!res)
{
std::cerr << "err count: " << err_count << std::endl;
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
...
@@ -93,7 +94,7 @@ check_err(const std::vector<T>& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 16384)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...
@@ -103,6 +104,7 @@ check_err(const std::vector<T>& out,
}
if(!res)
{
std::cerr << "err count: " << err_count << std::endl;
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
...
@@ -136,7 +138,7 @@ check_err(span<const T> out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 16384)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...
@@ -146,6 +148,7 @@ check_err(span<const T> out,
}
if(!res)
{
std::cerr << "err count: " << err_count << std::endl;
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
...
@@ -196,7 +199,7 @@ check_err(const std::vector<T>& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 16384)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
...
@@ -206,6 +209,7 @@ check_err(const std::vector<T>& out,
}
if(!res)
{
std::cerr << "err count: " << err_count << std::endl;
std::cerr << "max err: " << max_err << std::endl;
}
return res;
...
@@ -103,5 +103,23 @@ struct FillConstant
}
};
template <typename T>
struct FillMNID
{
T step_{0.1};
int k_num_{32};
int mn_num_{128};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, iter = 0]() mutable {
auto tmp = ((iter / k_num_) % mn_num_) * step_;
iter++;
return tmp;
});
}
};
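// Usage sketch (host-side, names illustrative): every run of k_num_
// consecutive elements shares one value, and the value steps by step_ per
// run, wrapping after mn_num_ runs:
//   std::vector<float> a(64);
//   ck::utils::FillMNID<float>{0.1f, 32, 128}(a.begin(), a.end());
//   // a[0..31] == 0.0f, a[32..63] == 0.1f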
} // namespace utils
} // namespace ck