Commit d4adc71a authored by aska-0096's avatar aska-0096
Browse files

Mat-A LDS Bypass sanity pass

parent c811a0e9
...@@ -19,15 +19,49 @@ using AElementOp = PassThrough; ...@@ -19,15 +19,49 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| < ALayout,
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| BLayout,
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| CLayout,
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ADataType,
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>; BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // M Repeat
8, // N-Repeat
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
4, // C shuffle (N Repeat) Per store
S<1, 64, 1, 4>,
8>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -35,6 +35,18 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -35,6 +35,18 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break; break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
default: default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k); ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n); ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
......
...@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
// warm up // warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const int nrepeat = 10; const int nrepeat = 100;
#if DEBUG_LOG #if DEBUG_LOG
printf("Start running %d times...\n", nrepeat); printf("Start running %d times...\n", nrepeat);
#endif #endif
......
...@@ -62,12 +62,33 @@ struct BlockwiseGemmWMMA ...@@ -62,12 +62,33 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4); static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4); static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
// FIX it, workaround
using ABlockDesc_temp = decltype(
make_naive_tensor_descriptor(make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
make_tuple(A_temp1* A_temp2* A_temp3* A_temp4,
A_temp2* A_temp3* A_temp4,
A_temp3* A_temp4,
A_temp4,
I1)));
static constexpr auto wmma_gemm = static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{}; WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
static constexpr bool AEnableLds = NWaves == 1 ? false : true;
static constexpr bool BEnableLds = MWaves == 1 ? false : true;
// Read from Lds, duplicate Twice, Read from VGPR, no duplication.
static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
MRepeat * NRepeat, MRepeat * NRepeat,
...@@ -92,24 +113,36 @@ struct BlockwiseGemmWMMA ...@@ -92,24 +113,36 @@ struct BlockwiseGemmWMMA
// Default, Block buffer in LDS, thread level offset enabled // Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex() __device__ static auto CalculateAThreadOriginDataIndex()
{ {
const auto wave_idx = GetWaveIdx(); if constexpr(AEnableLds)
{
const auto waveId_m = wave_idx[I0]; const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); // |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
}
} }
__device__ static auto CalculateBThreadOriginDataIndex() __device__ static auto CalculateBThreadOriginDataIndex()
{ {
const auto wave_idx = GetWaveIdx(); if constexpr(BEnableLds)
{
const auto waveId_n = wave_idx[I1]; const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); // |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
}
} }
template <index_t m0, index_t n0> template <index_t m0, index_t n0>
...@@ -269,7 +302,7 @@ struct BlockwiseGemmWMMA ...@@ -269,7 +302,7 @@ struct BlockwiseGemmWMMA
// Describe how data allocated in thread copy src buffer // Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
...@@ -285,21 +318,28 @@ struct BlockwiseGemmWMMA ...@@ -285,21 +318,28 @@ struct BlockwiseGemmWMMA
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A // read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, a_thread_copy_.Run(
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0), a_block_desc_k0_m0_m1_m2_k1,
a_block_buf, make_tuple(
a_thread_desc_, Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{}, m0, I0, I0, I0),
make_tuple(I0, m0, I0, I0, I0), a_block_buf,
a_thread_buf); a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B // read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, b_thread_copy_.Run(
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0), b_block_desc_k0_n0_n1_n2_k1,
b_block_buf, make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
b_thread_desc_, n0,
make_tuple(I0, n0, I0, I0, I0), I0,
b_thread_buf); I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
...@@ -324,6 +364,7 @@ struct BlockwiseGemmWMMA ...@@ -324,6 +364,7 @@ struct BlockwiseGemmWMMA
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
}); });
} }
...@@ -340,28 +381,78 @@ struct BlockwiseGemmWMMA ...@@ -340,28 +381,78 @@ struct BlockwiseGemmWMMA
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma())); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA, template <bool EnableLds>
FloatA, struct AThreadCopySelector;
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB, template <>
FloatB, struct AThreadCopySelector<true>
decltype(b_block_desc_k0_n0_n1_n2_k1), {
decltype(b_thread_desc_), using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>, FloatA,
Sequence<0, 1, 2, 3, 4>, decltype(a_block_desc_k0_m0_m1_m2_k1),
4, decltype(a_thread_desc_),
B_K1, Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
B_K1>; Sequence<0, 1, 2, 3, 4>,
4,
AThreadCopy a_thread_copy_; A_K1,
BThreadCopy b_thread_copy_; A_K1>;
};
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
0x76543210,
0xfedcba98,
true>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
0x76543210,
0xfedcba98,
false>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
}; };
// block wise level pipe designed for inline asm // block wise level pipe designed for inline asm
...@@ -376,7 +467,7 @@ template <index_t BlockSize, ...@@ -376,7 +467,7 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack, index_t KPack,
bool TransposeC = false, bool TransposeC = false,
bool AssemblyBackend = true> bool AssemblyBackend = true>
/* A: K0PerBlock x MPerBlock x K1 /* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1 * B: K0PerBlock x NPerBlock x K1
...@@ -407,8 +498,14 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -407,8 +498,14 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = static constexpr auto wmma_gemm = WmmaGemm<FloatA,
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC, AssemblyBackend>{}; FloatB,
FloatAcc,
MPerWMMA,
NPerWMMA,
KPack,
TransposeC,
AssemblyBackend>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -35,10 +36,10 @@ template <typename ALayout, ...@@ -35,10 +36,10 @@ template <typename ALayout,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t K0PerBlock, ck::index_t KPerBlock,
ck::index_t K1, ck::index_t K1,
ck::index_t MPerWMMA, ck::index_t MPerWmma,
ck::index_t NPerWMMA, ck::index_t NPerWmma,
ck::index_t MRepeat, ck::index_t MRepeat,
ck::index_t NRepeat, ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
...@@ -75,19 +76,31 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -75,19 +76,31 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// K1 = Max Vector Access Pixels // K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
{ static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
assert(K % K1 == 0); static constexpr auto WmmaK = 16;
const index_t K0 = K / K1; static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data read from Global memory
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() { const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); const auto a_grid_desc_mraw_kraw =
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
} }
#ifdef ENABLE_COLMAJOR #ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
...@@ -97,104 +110,88 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -97,104 +110,88 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
#endif #endif
}(); }();
if constexpr(GemmSpec == GemmSpecialization::MNPadding) const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(AEnableLds)
{ {
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; const index_t K0 = K / K1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)), make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
else else
{ {
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_pass_through_transform(M)), make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
} }
} }
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
{ {
assert(K % K1 == 0); const auto b_grid_desc_nraw_kraw = [&]() {
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
} }
}(); }();
if constexpr(GemmSpec == GemmSpecialization::MNPadding) const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor( const auto N = b_grid_desc_n_k.GetLength(I0);
b_grid_desc_k_n, const auto K = b_grid_desc_n_k.GetLength(I1);
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), assert(K % K1 == 0);
make_right_pad_transform(N, PadN)), const index_t K0 = K / K1;
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); return transform_tensor_descriptor(
} b_grid_desc_n_k,
else make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
{ make_pass_through_transform(N)),
return transform_tensor_descriptor( make_tuple(Sequence<1>{}, Sequence<0>{}),
b_grid_desc_k_n, make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
} }
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
{ {
const auto c_grid_desc_m_n = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideC, I1));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideC));
} }
}(); }();
if constexpr(GemmSpec == GemmSpecialization::MNPadding) return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
} }
// Gridwise descriptor, mapping to whole given provblem. // Gridwise descriptor, mapping to whole given provblem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
...@@ -207,7 +204,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -207,7 +204,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, AGridDesc,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
...@@ -215,9 +212,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -215,9 +212,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation, CElementwiseOperation,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, KPerBlock,
MPerWMMA, MPerWmma,
NPerWMMA, NPerWmma,
K1, K1,
MRepeat, MRepeat,
NRepeat, NRepeat,
...@@ -228,6 +225,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -228,6 +225,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun, false, // AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds,
ABlockLdsAddExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
...@@ -236,6 +234,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -236,6 +234,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle, CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle, CShuffleNRepeatPerShuffle,
...@@ -265,7 +264,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -265,7 +264,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock{}, c_grid_desc_mblock_mperblock_nblock_nperblock{},
...@@ -276,8 +275,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -276,8 +275,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_k0_m_k1_ = a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = b_grid_desc_k0_n_k1_ =
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
...@@ -285,10 +283,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -285,10 +283,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, if(GridwiseGemm::CheckValidity(
b_grid_desc_k0_n_k1_, a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
c_grid_desc_m_n_,
block_2_ctile_map_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock = c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -300,7 +296,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -300,7 +296,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc a_grid_desc_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -322,9 +318,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -322,9 +318,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{ {
#if 0 #if 0
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_{" << arg.a_grid_desc_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
...@@ -336,7 +332,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -336,7 +332,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
} }
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_)) arg.block_2_ctile_map_))
...@@ -348,8 +344,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -348,8 +344,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K = const auto GetK = [&]() {
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
};
const auto K = GetK();
float ave_time = 0; float ave_time = 0;
...@@ -360,7 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -360,7 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ADataType, ADataType,
BDataType, BDataType,
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
...@@ -378,7 +384,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -378,7 +384,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock, arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_, arg.a_element_op_,
...@@ -393,7 +399,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -393,7 +399,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ADataType, ADataType,
BDataType, BDataType,
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
...@@ -411,7 +417,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -411,7 +417,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock, arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_, arg.a_element_op_,
...@@ -443,15 +449,17 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -443,15 +449,17 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
printf("DeviceOp err: AccDataType");
return false; return false;
} }
} }
else else
{ {
printf("DeviceOp err: Arch");
return false; return false;
} }
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
...@@ -547,10 +555,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -547,10 +555,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << KPerBlock << ", "
<< K1 << ", " << K1 << ", "
<< MPerWMMA << ", " << MPerWmma << ", "
<< NPerWMMA << ", " << NPerWmma << ", "
<< MRepeat << ", " << MRepeat << ", "
<< NRepeat << NRepeat
<< ">" << ">"
......
...@@ -15,6 +15,8 @@ enum struct PipelineVersion ...@@ -15,6 +15,8 @@ enum struct PipelineVersion
}; };
template <PipelineVersion PipelineVer, template <PipelineVersion PipelineVer,
bool AEnableLds = true,
bool BEnableLds = true,
index_t NumPrefetch = 1, index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
constexpr auto GridwiseGemmPipeline_Selector() constexpr auto GridwiseGemmPipeline_Selector()
...@@ -23,7 +25,7 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -23,7 +25,7 @@ constexpr auto GridwiseGemmPipeline_Selector()
{ {
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
{ {
return GridwiseGemmPipeline_v1<NumPrefetch>{}; return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
} }
else if constexpr(LoopSched == LoopScheduler::Interwave) else if constexpr(LoopSched == LoopScheduler::Interwave)
{ {
......
...@@ -8,12 +8,12 @@ ...@@ -8,12 +8,12 @@
namespace ck { namespace ck {
template <index_t NumPrefetch> template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1; struct GridwiseGemmPipeline_v1;
// 1-stage prefetch // 1-stage prefetch
template <> template <>
struct GridwiseGemmPipeline_v1<1> struct GridwiseGemmPipeline_v1<1, true, true>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1> ...@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1>
// 2-stage prefetch // 2-stage prefetch
template <> template <>
struct GridwiseGemmPipeline_v1<2> struct GridwiseGemmPipeline_v1<2, true, true>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -253,6 +253,123 @@ struct GridwiseGemmPipeline_v1<2> ...@@ -253,6 +253,123 @@ struct GridwiseGemmPipeline_v1<2>
} }
}; };
// Single-prefetch (NumPrefetch = 1) GEMM pipeline specialization in which the
// A operand bypasses LDS (AEnableLds = false) while the B operand is still
// staged through LDS (BEnableLds = true).
// NOTE(review): since AEnableLds is false here, a_block_buf presumably lives
// in registers rather than LDS — confirm against the gridwise GEMM that
// instantiates this pipeline.
template <>
struct GridwiseGemmPipeline_v1<1, false, true>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // A single prefetch stage imposes no restriction on the loop count.
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    // The main loop runs only when more than one K-iteration remains; the
    // final iteration is always handled by the tail section of Run().
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    // Runs the full K-loop: prefetch the next A tile while computing on the
    // current A/B tiles, with B double-staged through LDS via
    // RunRead (global -> thread) and RunWrite (thread -> LDS).
    // num_loop is the total number of K-block iterations.
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
#if 0
        // Disabled alternative: derive the all-zero origin index generically
        // from the descriptor rank instead of hard-coding six dimensions.
        constexpr auto a_block_origin_idx = generate_sequence_v2(
            []() constexpr {
                return Number<0>{};
            },
            Number<a_block_desc.GetLengths().GetSize()>{});
#endif
        // Hard-coded all-zero origin for the 6-D A block descriptor.
        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);

        // Second A tile buffer used for double-buffering: the main loop
        // prefetches into this one while computing on a_block_buf.
        auto a_block_buf_switch = a_block_buf;

        // Prologue: load the first A tile directly into its block buffer and
        // read the first B tile from global memory.
        a_blockwise_copy.Run(
            a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // Advance both source windows to the next K block.
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize the C accumulator.
        c_thread_buf.Clear();

        // Commit the prefetched B tile to LDS.
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Main body: each iteration prefetches the next A tile into the spare
        // buffer, computes on the current tiles, then swaps buffers. The LDS
        // barriers bracket the B tile's write/read hand-off.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // Prefetch the next A tile while the current one is consumed.
                a_blockwise_copy.Run(
                    a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);

                block_sync_lds();
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
                block_sync_lds();

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                // Swap in the freshly prefetched A tile.
                a_block_buf = a_block_buf_switch;

                ++i;
            } while(i < (num_loop - 1));
        }

        // Tail: compute on the final prefetched tiles.
        {
            block_sync_lds();
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
            block_sync_lds();
        }
    }
};
// Placeholder: pipeline variant with B bypassing LDS (AEnableLds = true,
// BEnableLds = false) is not implemented yet. The body is intentionally
// empty, so any attempt to use this configuration fails at compile time.
template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
};
// Placeholder: pipeline variant with both A and B bypassing LDS
// (AEnableLds = false, BEnableLds = false) is not implemented yet; the empty
// body makes instantiation a compile-time error.
template <>
struct GridwiseGemmPipeline_v1<1, false, false>
{
};
template <index_t NumPrefetch> template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1; struct GridwiseGemmPipelineInterwave_v1;
...@@ -348,7 +465,7 @@ struct GridwiseGemmPipelineInterwave_v1<1> ...@@ -348,7 +465,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler // Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <> template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
{ {
}; };
...@@ -358,7 +475,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector() ...@@ -358,7 +475,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
{ {
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
{ {
return GridwiseGemmPipeline_v1<NumPrefetch>{}; return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
} }
else if constexpr(LoopSched == LoopScheduler::Interwave) else if constexpr(LoopSched == LoopScheduler::Interwave)
{ {
......
...@@ -1324,15 +1324,14 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1324,15 +1324,14 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow( __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
const ElementwiseOperation& element_op)
: element_op_{element_op}
{ {
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time"); "wrong! Desc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0, static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible"); "wrong! Not divisible");
ignore = src_idx;
} }
template <typename SrcSliceOriginIdx, template <typename SrcSliceOriginIdx,
...@@ -1344,7 +1343,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1344,7 +1343,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc&, const DstDesc&,
const DstSliceOriginIdx&, const DstSliceOriginIdx&,
DstBuffer& dst_buf) DstBuffer& dst_buf) const
{ {
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time"); "wrong! Desc need to known at compile-time");
...@@ -1383,7 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1383,7 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector // copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// idx_md err. as dst access 2 strided elements while src visit 1 per loop // src_desc error, non constexpr?
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
...@@ -1396,16 +1395,22 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1396,16 +1395,22 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// apply element-wise operation // apply element-wise operation
element_op_(v_this_row, src_buf[Number<src_offset>{}]); element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// if (get_thread_local_1d_id() < 16)
// apply intra-row swizzle permute // printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) ); apply intra-row swizzle permute
if constexpr(IntraRowSwizzlePerm) if constexpr(IntraRowSwizzlePerm)
{ {
// origin: temp = __builtin_amdgcn_permlane16( // 0x76543210, 0xfedcba98
// 0xfedcba98, temp,
// 0x76543210 type_convert<int>(v_this_row),
temp = __builtin_amdgcn_permlane16( 0xb3a29180,
temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0); 0xf7e6d5c4,
v_this_row = type_convert<float>(temp); 1,
0);
v_this_row = type_convert<SrcData>(temp);
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) );
} }
// apply inter-row permute. // apply inter-row permute.
...@@ -1415,8 +1420,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1415,8 +1420,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
HighEightRowLaneIdx, HighEightRowLaneIdx,
1, 1,
0); 0);
v_theother_row = type_convert<float>(temp); v_theother_row = type_convert<SrcData>(temp);
// printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_theother_row)) );
if(get_thread_local_1d_id() % 32 < 16) if(get_thread_local_1d_id() % 32 < 16)
{ {
// apply type convert // apply type convert
...@@ -1434,8 +1440,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1434,8 +1440,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
}); });
}); });
} }
ElementwiseOperation element_op_{};
ElementwiseOperation element_op_;
}; };
} // namespace ck } // namespace ck
...@@ -103,7 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16, ...@@ -103,7 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, bool AssemblyBackend, class FloatA, class FloatB, class FloatC> template <index_t MPerWmma,
index_t NPerWmma,
bool AssemblyBackend,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
if constexpr(wave_size == 32) if constexpr(wave_size == 32)
...@@ -358,7 +363,7 @@ template <typename src_type_a, ...@@ -358,7 +363,7 @@ template <typename src_type_a,
index_t MPerWmma, index_t MPerWmma,
index_t NPerWmma, index_t NPerWmma,
index_t KPack, index_t KPack,
bool TransposeC = false, bool TransposeC = false,
bool AssemblyBackend = false> bool AssemblyBackend = false>
struct WmmaGemm struct WmmaGemm
{ {
...@@ -492,11 +497,13 @@ struct WmmaGemm ...@@ -492,11 +497,13 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!"); "(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC) if constexpr(!TransposeC)
{ {
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_a_wave, p_b_wave, p_c_thread); wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
p_a_wave, p_b_wave, p_c_thread);
} }
else else
{ {
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_b_wave, p_a_wave, p_c_thread); wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
p_b_wave, p_a_wave, p_c_thread);
} }
} }
......
...@@ -21,13 +21,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend> ...@@ -21,13 +21,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{ {
if constexpr(AssemblyBackend){ if constexpr(AssemblyBackend)
{
amd_assembly_wmma_f32_16x16x16_f16_w32( amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{})); reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
} }
else{ else
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( {
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]); reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
} }
} }
}; };
......
...@@ -988,6 +988,30 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x) ...@@ -988,6 +988,30 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
return u.fp32; return u.fp32;
} }
// Reinterpret the bits of a half_t as an int via union type punning — no
// numeric conversion is performed.
// NOTE(review): half_t is presumably 16-bit while int is 32-bit, so only the
// low bytes of the result carry the fp16 bit pattern; confirm the callers
// (e.g. the permlane swizzle path) do not depend on the upper bytes.
// NOTE(review): reading the non-active union member is technically UB in
// standard C++, though the file already relies on this pattern elsewhere.
template <>
inline __host__ __device__ constexpr int type_convert<int, half_t>(half_t x)
{
    union
    {
        half_t fp16;
        int int32;
    } u = {x};
    return u.int32;
}
// Reinterpret the low bits of an int as a half_t via union type punning — no
// numeric conversion is performed. Inverse of type_convert<int, half_t>.
// NOTE(review): only the low 16 bits of the int are meaningful for the
// resulting half_t (assuming half_t is 16-bit) — confirm callers pass values
// produced by the matching int pun (e.g. results of lane-permute intrinsics).
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, int>(int x)
{
    union
    {
        int int32;
        half_t fp16;
    } u = {x};
    return u.fp16;
}
// convert fp32 to bfp16 // convert fp32 to bfp16
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment