Commit d4adc71a authored by aska-0096

Mat-A LDS Bypass sanity pass

parent c811a0e9
......@@ -19,15 +19,49 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // MRepeat
8, // NRepeat
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
4, // C shuffle (N Repeat) Per store
S<1, 64, 1, 4>,
8>;
// clang-format on
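With these tuning parameters the new Mat-A LDS bypass path is what gets exercised: the wave counts computed in the blockwise/device code below decide whether each operand is staged through LDS. A quick sanity check of the arithmetic, shown as comments only (it mirrors the formulas in the files changed by this commit and adds no code):
// Illustrative arithmetic for the instance above:
//   MWaves = MPerBlock / (MRepeat * MPerWmma) = 128 / (1 * 16) = 8
//   NWaves = NPerBlock / (NRepeat * NPerWmma) = 128 / (8 * 16) = 1
//   AEnableLds = (NWaves == 1) ? false : true  -> false  (Mat-A stays in VGPRs, LDS bypassed)
//   BEnableLds = (MWaves == 1) ? false : true  -> true   (Mat-B is still staged through LDS)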
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -35,6 +35,18 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
......
......@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
// warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const int nrepeat = 10;
const int nrepeat = 100;
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
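Only the number of timed repetitions changes here. For reference, below is a minimal, hedged sketch of the usual "time nrepeat launches and divide" pattern with HIP events; the actual body of launch_and_time_kernel is elided in this diff and may be structured differently.
// Hedged sketch only: averages nrepeat back-to-back launches.
hipEvent_t start, stop;
hipEventCreate(&start);
hipEventCreate(&stop);
hipEventRecord(start, stream_config.stream_id_);
for(int i = 0; i < nrepeat; ++i)
{
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
}
hipEventRecord(stop, stream_config.stream_id_);
hipEventSynchronize(stop);
float total_ms = 0.f;
hipEventElapsedTime(&total_ms, start, stop);
const float ave_time_ms = total_ms / nrepeat; // average milliseconds per launch
hipEventDestroy(start);
hipEventDestroy(stop);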
......
......@@ -62,12 +62,33 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
// FIX it, workaround
using ABlockDesc_temp = decltype(
make_naive_tensor_descriptor(make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
make_tuple(A_temp1* A_temp2* A_temp3* A_temp4,
A_temp2* A_temp3* A_temp4,
A_temp3* A_temp4,
A_temp4,
I1)));
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
static constexpr bool AEnableLds = NWaves == 1 ? false : true;
static constexpr bool BEnableLds = MWaves == 1 ? false : true;
// Read from LDS: data is duplicated twice; read from VGPR: no duplication.
static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
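// Note (illustrative only, assuming WmmaK = 16 and A_K1 = 8 as in the fp16 example instance):
// the duplicated rate scales the K0 source offset used by the thread copy in Run() below,
//   via LDS     : k * WmmaK / A_K1 * 2 / 2 = 2 * k   (LDS holds the data duplicated twice)
//   LDS bypassed: k * WmmaK / A_K1 * 1 / 2 = k       (the VGPR copy holds each element once)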
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
......@@ -91,26 +112,38 @@ struct BlockwiseGemmWMMA
// Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
......@@ -269,7 +302,7 @@ struct BlockwiseGemmWMMA
// Describe how data is allocated in the thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
......@@ -285,8 +318,10 @@ struct BlockwiseGemmWMMA
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(
Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
......@@ -294,8 +329,13 @@ struct BlockwiseGemmWMMA
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
n0,
I0,
I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
......@@ -324,6 +364,7 @@ struct BlockwiseGemmWMMA
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
......@@ -340,7 +381,13 @@ struct BlockwiseGemmWMMA
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
template <bool EnableLds>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
......@@ -349,8 +396,33 @@ struct BlockwiseGemmWMMA
4,
A_K1,
A_K1>;
};
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
0x76543210,
0xfedcba98,
true>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
......@@ -359,9 +431,28 @@ struct BlockwiseGemmWMMA
4,
B_K1,
B_K1>;
};
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
0x76543210,
0xfedcba98,
false>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
// block wise level pipe designed for inline asm
......@@ -407,8 +498,14 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC, AssemblyBackend>{};
static constexpr auto wmma_gemm = WmmaGemm<FloatA,
FloatB,
FloatAcc,
MPerWMMA,
NPerWMMA,
KPack,
TransposeC,
AssemblyBackend>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......
......@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
namespace tensor_operation {
......@@ -35,10 +36,10 @@ template <typename ALayout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
......@@ -75,19 +76,31 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
const index_t K0 = K / K1;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
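MatrixPadder rounds each raw dimension up to the next multiple of the corresponding per-block size (subject to GemmSpec), which is what lets the MNKPadding instance in the example accept arbitrary problem sizes. A small worked example with hypothetical raw sizes, as comments:
// Illustrative pad-to-multiple arithmetic only (hypothetical MRaw/KRaw values):
//   MRaw = 1000, MPerBlock = 128 -> PadM = (128 - 1000 % 128) % 128 = 24 -> padded M = 1024
//   KRaw = 100,  KPerBlock = 64  -> PadK = (64  - 100  % 64 ) % 64  = 28 -> padded K = 128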
// Describe how data is read from Global memory
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
const auto a_grid_desc_mraw_kraw =
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
......@@ -97,104 +110,88 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
#endif
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(AEnableLds)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const index_t K0 = K / K1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);
const index_t K0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
......@@ -207,7 +204,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CShuffleDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
AGridDesc,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
......@@ -215,9 +212,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
KPerBlock,
MPerWmma,
NPerWmma,
K1,
MRepeat,
NRepeat,
......@@ -228,6 +225,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -236,6 +234,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......@@ -265,7 +264,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{},
a_grid_desc_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
......@@ -276,8 +275,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ =
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
......@@ -285,10 +283,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
if(GridwiseGemm::CheckValidity(
a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -300,7 +296,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
AGridDesc a_grid_desc_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -322,9 +318,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.a_grid_desc_{" << arg.a_grid_desc_.GetLength(I0)
<< ", " << arg.a_grid_desc_.GetLength(I1) << ", "
<< arg.a_grid_desc_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
......@@ -336,7 +332,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
......@@ -348,8 +344,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const auto GetK = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
};
const auto K = GetK();
float ave_time = 0;
......@@ -360,7 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
......@@ -378,7 +384,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
......@@ -393,7 +399,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
......@@ -411,7 +417,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
......@@ -443,15 +449,17 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
printf("DeviceOp err: AccDataType");
return false;
}
}
else
{
printf("DeviceOp err: Arch");
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
......@@ -547,10 +555,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
......
......@@ -15,6 +15,8 @@ enum struct PipelineVersion
};
template <PipelineVersion PipelineVer,
bool AEnableLds = true,
bool BEnableLds = true,
index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default>
constexpr auto GridwiseGemmPipeline_Selector()
......@@ -23,7 +25,7 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......
......@@ -8,12 +8,12 @@
namespace ck {
template <index_t NumPrefetch>
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1;
// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<1>
struct GridwiseGemmPipeline_v1<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1>
// 2-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipeline_v1<2, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -253,6 +253,123 @@ struct GridwiseGemmPipeline_v1<2>
}
};
template <>
struct GridwiseGemmPipeline_v1<1, false, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
#if 0
constexpr auto a_block_origin_idx = generate_sequence_v2(
[]() constexpr {
return Number<0>{};
},
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_block_buf = a_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
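In words, the <1, false, true> specialization above keeps A entirely in registers and only fences LDS around B. The following summary is just a reading of the code, not extra functionality:
// Per iteration of the main loop in GridwiseGemmPipeline_v1<1, false, true>:
//   1. prefetch the next A tile from global memory straight into a_block_buf_switch (VGPRs)
//   2. block_sync_lds(); issue the global read of the next B tile into register staging
//   3. run the blockwise GEMM on the current a_block_buf / b_block_buf
//   4. block_sync_lds(); advance both source windows and write the staged B tile into LDS
//   5. adopt a_block_buf_switch as the new a_block_buf and repeat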
// placeholder
template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
};
template <>
struct GridwiseGemmPipeline_v1<1, false, false>
{
};
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
......@@ -348,7 +465,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
{
};
......@@ -358,7 +475,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......
......@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename AGridDesc,
typename BGridDesc_K0_N_K1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
......@@ -33,30 +33,26 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_wmma(
const FloatA* __restrict__ p_a_grid,
kernel_gemm_wmma(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const AGridDesc a_grid_desc,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
a_grid_desc,
b_grid_desc_k0_n_k1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
......@@ -67,7 +63,7 @@ __global__ void
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = a_grid_desc;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
......@@ -84,7 +80,7 @@ template <index_t BlockSize,
typename FloatCShuffle,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename AGridDesc,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename AElementwiseOperation,
......@@ -92,7 +88,7 @@ template <index_t BlockSize,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t K1Value,
......@@ -105,6 +101,7 @@ template <index_t BlockSize,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool AEnableLds,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
......@@ -113,6 +110,7 @@ template <index_t BlockSize,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BEnableLds,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
......@@ -132,39 +130,149 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto B_K0 = BGridDesc_K0_N_K1{}.GetLength(I0);
static constexpr auto B_K1 = BGridDesc_K0_N_K1{}.GetLength(I2);
// FIX ME: To be deprecated
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
using GridwiseGemmPipe =
remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
AEnableLds,
BEnableLds,
NumGemmKPrefetchStage,
LoopSched>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
// Describe how data is stored to the (LDS/VGPR) buffer from Global memory
__host__ __device__ static constexpr auto MakeABlockDescriptor()
{
constexpr auto a_block_desc = [&]() {
if constexpr(AEnableLds)
{
// K0->M->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto max_lds_align = K1;
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
}
}();
return a_block_desc;
}
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
if constexpr(AEnableLds)
{
constexpr auto K0PerBlock = KPerBlock / K1;
return make_multi_index(K0PerBlock, 0, 0);
}
else
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
}
}();
return a_block_copy_step;
}
__host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
{
constexpr auto b_block_copy_step = [&]() {
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock / K1;
return make_multi_index(K0PerBlock, 0, 0);
}
else
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
}
}();
return b_block_copy_step;
}
// Describe how data is read from the (LDS/VGPR) buffer
template <typename ABlockDesc_>
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
{
constexpr index_t A_K0 = ABlockDesc_AK0_M_AK1{}.GetLength(I0);
constexpr index_t A_K1 = ABlockDesc_AK0_M_AK1{}.GetLength(I2);
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
return transform_tensor_descriptor(
ABlockDesc_AK0_M_AK1{},
ABlockDesc_{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
else
{
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0, 3>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
}
}();
return a_wave_desc;
}
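Concrete lengths may help for the bypass path; assuming the example instance in this commit (KPerBlock = 64, WmmaK = 16, MRepeat = 1, K1 = 8), the per-thread A descriptors work out as shown in these comments:
// Illustrative shapes for the AEnableLds == false path (values taken from the example instance):
//   MakeABlockDescriptor() : (KWmmaPerblock, MRepeat, MWave, KRow, MPerWmma, K1) = (4, 1, 1, 1, 1, 8)
//                            -> 32 A elements held per thread in VGPRs
//   MakeAWaveDescriptor()  : merges (KWmma, KRow = 1) into K0, giving
//                            (K0, MRepeat, MWave, MPerWmma, K1) = (4, 1, 1, 1, 8)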
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
constexpr index_t B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
return transform_tensor_descriptor(
BBlockDesc_BK0_N_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
......@@ -175,32 +283,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
return a_block_desc_k0perblock_mperblock_k1;
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = K1;
constexpr auto K0PerBlock = KPerBlock / K1;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
......@@ -223,44 +309,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// *Caution: here, "repeat" means the shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0perblock_mperblock_k1 =
GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned * sizeof(FloatA) +
b_block_space_size_aligned * sizeof(FloatB));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
......@@ -272,23 +334,68 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
(NPerBlock % (NRepeat * NPerWmma)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto GetAProblemsizeMK = [&]() {
if constexpr(AEnableLds)
{
return make_tuple(a_grid_desc.GetLength(I1),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
}
else
{
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I4),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I5));
}
};
const auto GetBProblemsizeNK = [&]() {
if constexpr(BEnableLds)
{
return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1),
b_grid_desc_k0_n_k1.GetLength(I0) *
b_grid_desc_k0_n_k1.GetLength(I2));
}
else
{
return make_tuple(
b_grid_desc_k0_n_k1.GetLength(I1) * b_grid_desc_k0_n_k1.GetLength(I2) *
b_grid_desc_k0_n_k1.GetLength(I4),
b_grid_desc_k0_n_k1.GetLength(I0) * b_grid_desc_k0_n_k1.GetLength(I3) *
b_grid_desc_k0_n_k1.GetLength(I5));
}
};
const auto M = GetAProblemsizeMK()[I0];
const auto N = GetBProblemsizeNK()[I0];
const auto K = GetAProblemsizeMK()[I1];
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
K == GetBProblemsizeNK()[I1]))
{
printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
GetAProblemsizeMK()[I0],
GetAProblemsizeMK()[I1],
GetBProblemsizeNK()[I0],
GetBProblemsizeNK()[I1],
c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
printf("GridwiseOp err: ProblemSize check");
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
printf("GridwiseOp err: ProblemSize division");
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock;
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
printf("GridwiseOp err: Pipeline not support this k_loop");
return false;
}
......@@ -303,7 +410,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
......@@ -340,12 +447,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto max_lds_align = K1;
static constexpr auto a_block_space_size_aligned =
AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
max_lds_align) *
sizeof(FloatA)
: 0;
static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
max_lds_align) *
sizeof(FloatB)
: 0;
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned;
// LDS allocation for C shuffle
static constexpr auto c_shuffle_block_space_size =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
.GetElementSpaceSize() *
sizeof(FloatCShuffle);
static constexpr auto c_shuffle_block_space_offset = 0;
static constexpr auto lds_size = math::max(
c_shuffle_block_space_size, (a_block_space_size_aligned + b_block_space_size_aligned));
};
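With A bypassing LDS its share of lds_size is zero, so the block footprint is the larger of the B staging buffer and the C-shuffle buffer. A rough worked example for the fp16 instance in this commit, ignoring the optional BBlockLdsExtraN padding:
// Illustrative LDS budget only (fp16 B assumed, BBlockLdsExtraN padding ignored):
//   a_block_space_size_aligned = 0                                   // AEnableLds == false
//   b_block_space_size_aligned ~ (KPerBlock / K1) * NPerBlock * K1 * sizeof(half)
//                              = 8 * 128 * 8 * 2 bytes = 16 KiB
//   lds_size = max(c_shuffle_block_space_size, 0 + 16 KiB)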
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const AGridDesc& a_grid_desc,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -358,7 +498,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// Memory buffer zone.
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -378,14 +518,32 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
/*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destination of BlockWise_Copy
const auto K = [&](){
if constexpr(AEnableLds){
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
}
}();
// printf("---------------K = %d\n", K);
constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
auto a_block_trait = [&](){
// A matrix blockwise copy
if constexpr(AEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatA*>(p_shared),
a_block_desc.GetElementSpaceSize());
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
/* typename SrcElementwiseOperation, */ AElementwiseOperation,
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
......@@ -394,8 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA,
/* typename DstData, */ FloatA,
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
/* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
......@@ -406,14 +564,60 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
a_grid_desc_k0_m_k1,
a_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0perblock_mperblock_k1,
a_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
return make_tuple(a_block_buf, a_blockwise_copy);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto a_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatA,
FloatA,
decltype(a_grid_desc),
decltype(a_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{},
I1,
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc,
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(a_block_buf, a_blockwise_copy);
}
};
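// Reading of the bypass-path origin index above (descriptive only, not additional code):
// with a 32-lane wave the thread id decomposes as
//   tid / 32        -> MWave index,
//   (tid % 32) / 16 -> KRow (which 16-lane half of the wave),
//   tid % 16        -> lane within the WMMA row (MPerWmma dimension),
// while the block's M offset is divided by (MWaves * MPerWmma) to land in the MRepeat dimension.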
auto b_block_trait = [&](){
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
b_block_desc.GetElementSpaceSize());
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
......@@ -425,7 +629,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
FloatB,
FloatB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
......@@ -439,13 +643,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0perblock_nperblock_k1,
b_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
return make_tuple(b_block_buf, b_blockwise_copy);
}
else
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_block_desc.GetElementSpaceSize());
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc),
Sequence<Number<K0PerBlock>{},
Number<NRepeat>{},
Number<K1Value>{}>,
Sequence<0, 1, 2>,
2,
BBlockTransferSrcScalarPerVector,
1>(
make_multi_index(0, get_thread_local_1d_id()/32 * 16 + get_thread_local_1d_id() % 16, 0));
return make_tuple(b_block_buf, b_blockwise_copy);
}
};
auto a_block_buf = a_block_trait()[I0];
auto a_blockwise_copy = a_block_trait()[I1];
auto b_block_buf = b_block_trait()[I0];
auto b_blockwise_copy = b_block_trait()[I1];
/*******************************************************************************/
// GEMM
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
......@@ -453,11 +686,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
FloatA,
FloatB,
FloatAcc,
decltype(MakeABlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc_k0perblock_nperblock_k1)),
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc)),
MPerBlock,
NPerBlock,
K0PerBlock * K1,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
......@@ -468,25 +701,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
/*******************************************************************************/
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
// printf("a_block_slice_copy_step FirstKdim = %d\n", a_block_slice_copy_step[I0]);
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
a_block_desc_k0perblock_mperblock_k1,
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0perblock_nperblock_k1,
b_block_desc,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
......@@ -497,6 +726,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// write out to C, implement shuffle
{
#if 0
static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) {
printf("tid: %03d, c_thread_buf[%02d] val: %08x\n", get_thread_local_1d_id(), i.value,
*(reinterpret_cast<const uint32_t*>(&(c_thread_buf[i]))));
// c_thread_buf(i) = 32;
});
#endif
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
......@@ -515,8 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
static_cast<FloatCShuffle*>(p_shared), SharedMemTrait::c_shuffle_block_space_size);
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
......@@ -666,6 +901,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
......
......@@ -1324,15 +1324,14 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(
const ElementwiseOperation& element_op)
: element_op_{element_op}
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
ignore = src_idx;
}
template <typename SrcSliceOriginIdx,
......@@ -1344,7 +1343,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
......@@ -1383,7 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// idx_md err. as dst access 2 strided elements while src visit 1 per loop
// src_desc error, non constexpr?
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
......@@ -1396,16 +1395,22 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// apply element-wise operation
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// apply intra-row swizzle permute
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) ); apply intra-row swizzle permute
if constexpr(IntraRowSwizzlePerm)
{
// origin:
// 0xfedcba98,
// 0x76543210
temp = __builtin_amdgcn_permlane16(
temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
v_this_row = type_convert<float>(temp);
temp = __builtin_amdgcn_permlane16( // 0x76543210, 0xfedcba98
temp,
type_convert<int>(v_this_row),
0xb3a29180,
0xf7e6d5c4,
1,
0);
v_this_row = type_convert<SrcData>(temp);
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) );
}
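// Best-effort reading of the selectors above (hedged; see the ISA docs for the authoritative
// definition): each __builtin_amdgcn_permlane16 selector packs eight 4-bit source-lane
// indices for one half of the 16-lane row, least-significant nibble first, so
// 0x76543210 / 0xfedcba98 is the identity permutation. The pair 0xb3a29180 / 0xf7e6d5c4 used
// here interleaves the lower and upper eight lanes of the row (0,8,1,9,...,7,15), whereas the
// previous pair 0xeca86420 / 0xfdb97531 gathered the even lanes followed by the odd lanes.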
// apply inter-row permute.
......@@ -1415,8 +1420,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
HighEightRowLaneIdx,
1,
0);
v_theother_row = type_convert<float>(temp);
v_theother_row = type_convert<SrcData>(temp);
// printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_theother_row)) );
if(get_thread_local_1d_id() % 32 < 16)
{
// apply type convert
......@@ -1434,8 +1440,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
});
});
}
ElementwiseOperation element_op_;
ElementwiseOperation element_op_{};
};
} // namespace ck
......@@ -103,7 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, bool AssemblyBackend, class FloatA, class FloatB, class FloatC>
template <index_t MPerWmma,
index_t NPerWmma,
bool AssemblyBackend,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
......@@ -492,11 +497,13 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_a_wave, p_b_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_b_wave, p_a_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
p_b_wave, p_a_wave, p_c_thread);
}
}
......
......@@ -21,12 +21,15 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
if constexpr(AssemblyBackend){
if constexpr(AssemblyBackend)
{
amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
}
else{
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
else
{
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
}
......
......@@ -988,6 +988,30 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
return u.fp32;
}
template <>
inline __host__ __device__ constexpr int type_convert<int, half_t>(half_t x)
{
union
{
half_t fp16;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, int>(int x)
{
union
{
int int32;
half_t fp16;
} u = {x};
return u.fp16;
}
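These two specializations are bit reinterpretations rather than numeric conversions; they let a 16-bit value ride through the 32-bit integer cross-lane intrinsics used by the swizzle copy elsewhere in this commit. A hedged usage sketch:
// Illustrative round trip only: the fp16 payload sits in the low 16 bits of the int (little-endian).
ck::half_t h    = ck::type_convert<ck::half_t>(1.0f); // some fp16 value
int        bits = ck::type_convert<int>(h);           // reinterpret bits, not a numeric cast
// ... bits may now be passed through __builtin_amdgcn_permlane16 / permlanex16 ...
ck::half_t back = ck::type_convert<ck::half_t>(bits); // recover the original fp16 bits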
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
......