Commit d4adc71a authored by aska-0096's avatar aska-0096
Browse files

Mat-A LDS Bypass sanity pass

parent c811a0e9
......@@ -19,15 +19,49 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // M Repeat
8, // N-Repeat
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
4, // C shuffle (N Repeat) Per store
S<1, 64, 1, 4>,
8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -35,6 +35,18 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
......
......@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
// warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const int nrepeat = 10;
const int nrepeat = 100;
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
......
......@@ -62,12 +62,33 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
// FIX it, workaround
using ABlockDesc_temp = decltype(
make_naive_tensor_descriptor(make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
make_tuple(A_temp1* A_temp2* A_temp3* A_temp4,
A_temp2* A_temp3* A_temp4,
A_temp3* A_temp4,
A_temp4,
I1)));
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
static constexpr bool AEnableLds = NWaves == 1 ? false : true;
static constexpr bool BEnableLds = MWaves == 1 ? false : true;
// Read from Lds, duplicate Twice, Read from VGPR, no duplication.
static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
......@@ -92,24 +113,36 @@ struct BlockwiseGemmWMMA
// Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
......@@ -269,7 +302,7 @@ struct BlockwiseGemmWMMA
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
......@@ -285,21 +318,28 @@ struct BlockwiseGemmWMMA
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(
Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
n0,
I0,
I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
......@@ -324,6 +364,7 @@ struct BlockwiseGemmWMMA
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
......@@ -340,28 +381,78 @@ struct BlockwiseGemmWMMA
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
template <bool EnableLds>
struct AThreadCopySelector;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
template <>
struct AThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
};
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
0x76543210,
0xfedcba98,
true>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
0x76543210,
0xfedcba98,
false>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
// block wise level pipe designed for inline asm
......@@ -376,7 +467,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
bool TransposeC = false,
bool AssemblyBackend = true>
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
......@@ -407,8 +498,14 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC, AssemblyBackend>{};
static constexpr auto wmma_gemm = WmmaGemm<FloatA,
FloatB,
FloatAcc,
MPerWMMA,
NPerWMMA,
KPack,
TransposeC,
AssemblyBackend>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......
......@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
namespace tensor_operation {
......@@ -35,10 +36,10 @@ template <typename ALayout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
......@@ -75,19 +76,31 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;
const index_t K0 = K / K1;
static constexpr auto AEnableLds = NWaves == 1 ? false : true;
static constexpr auto BEnableLds = MWaves == 1 ? false : true;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Describe how data read from Global memory
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
const auto a_grid_desc_mraw_kraw =
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
#ifdef ENABLE_COLMAJOR
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
......@@ -97,104 +110,88 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
#endif
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(AEnableLds)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const index_t K0 = K / K1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);
const index_t K0 = K / K1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
// Gridwise descriptor, mapping to whole given provblem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
......@@ -207,7 +204,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CShuffleDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
AGridDesc,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
......@@ -215,9 +212,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
KPerBlock,
MPerWmma,
NPerWmma,
K1,
MRepeat,
NRepeat,
......@@ -228,6 +225,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -236,6 +234,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......@@ -265,7 +264,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{},
a_grid_desc_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
......@@ -276,8 +275,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ =
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
......@@ -285,10 +283,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
if(GridwiseGemm::CheckValidity(
a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -300,7 +296,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
AGridDesc a_grid_desc_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -322,9 +318,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.a_grid_desc_{" << arg.a_grid_desc_.GetLength(I0)
<< ", " << arg.a_grid_desc_.GetLength(I1) << ", "
<< arg.a_grid_desc_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
......@@ -336,7 +332,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
......@@ -348,8 +344,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const auto GetK = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
};
const auto K = GetK();
float ave_time = 0;
......@@ -360,7 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
......@@ -378,7 +384,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
......@@ -393,7 +399,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
......@@ -411,7 +417,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
......@@ -443,15 +449,17 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
printf("DeviceOp err: AccDataType");
return false;
}
}
else
{
printf("DeviceOp err: Arch");
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
......@@ -547,10 +555,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
......
......@@ -15,6 +15,8 @@ enum struct PipelineVersion
};
template <PipelineVersion PipelineVer,
bool AEnableLds = true,
bool BEnableLds = true,
index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default>
constexpr auto GridwiseGemmPipeline_Selector()
......@@ -23,7 +25,7 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......
......@@ -8,12 +8,12 @@
namespace ck {
template <index_t NumPrefetch>
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1;
// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<1>
struct GridwiseGemmPipeline_v1<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1>
// 2-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipeline_v1<2, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -253,6 +253,123 @@ struct GridwiseGemmPipeline_v1<2>
}
};
template <>
struct GridwiseGemmPipeline_v1<1, false, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
#if 0
constexpr auto a_block_origin_idx = generate_sequence_v2(
[]() constexpr {
return Number<0>{};
},
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
block_sync_lds();
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_block_buf = a_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
// placeholder
template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
};
template <>
struct GridwiseGemmPipeline_v1<1, false, false>
{
};
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
......@@ -348,7 +465,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
{
};
......@@ -358,7 +475,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
......
......@@ -1324,15 +1324,14 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(
const ElementwiseOperation& element_op)
: element_op_{element_op}
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
ignore = src_idx;
}
template <typename SrcSliceOriginIdx,
......@@ -1344,7 +1343,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
......@@ -1383,7 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// idx_md err. as dst access 2 strided elements while src visit 1 per loop
// src_desc error, non constexpr?
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
......@@ -1396,16 +1395,22 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// apply element-wise operation
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// apply intra-row swizzle permute
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) ); apply intra-row swizzle permute
if constexpr(IntraRowSwizzlePerm)
{
// origin:
// 0xfedcba98,
// 0x76543210
temp = __builtin_amdgcn_permlane16(
temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
v_this_row = type_convert<float>(temp);
temp = __builtin_amdgcn_permlane16( // 0x76543210, 0xfedcba98
temp,
type_convert<int>(v_this_row),
0xb3a29180,
0xf7e6d5c4,
1,
0);
v_this_row = type_convert<SrcData>(temp);
// if (get_thread_local_1d_id() < 16)
// printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_this_row)) );
}
// apply inter-row permute.
......@@ -1415,8 +1420,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
HighEightRowLaneIdx,
1,
0);
v_theother_row = type_convert<float>(temp);
v_theother_row = type_convert<SrcData>(temp);
// printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&v_theother_row)) );
if(get_thread_local_1d_id() % 32 < 16)
{
// apply type convert
......@@ -1434,8 +1440,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
});
});
}
ElementwiseOperation element_op_;
ElementwiseOperation element_op_{};
};
} // namespace ck
......@@ -103,7 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, bool AssemblyBackend, class FloatA, class FloatB, class FloatC>
template <index_t MPerWmma,
index_t NPerWmma,
bool AssemblyBackend,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
......@@ -358,7 +363,7 @@ template <typename src_type_a,
index_t MPerWmma,
index_t NPerWmma,
index_t KPack,
bool TransposeC = false,
bool TransposeC = false,
bool AssemblyBackend = false>
struct WmmaGemm
{
......@@ -492,11 +497,13 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_a_wave, p_b_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_b_wave, p_a_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(
p_b_wave, p_a_wave, p_c_thread);
}
}
......
......@@ -21,13 +21,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
if constexpr(AssemblyBackend){
if constexpr(AssemblyBackend)
{
amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
}
else{
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
else
{
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
}
};
......
......@@ -988,6 +988,30 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
return u.fp32;
}
template <>
inline __host__ __device__ constexpr int type_convert<int, half_t>(half_t x)
{
union
{
half_t fp16;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, int>(int x)
{
union
{
int int32;
half_t fp16;
} u = {x};
return u.fp16;
}
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment