Commit db843529 authored by coderfeli
Browse files

fix warnings and revert cmake and fix clang format

parent 5765ba51
...@@ -516,6 +516,10 @@ include_directories(BEFORE ...@@ -516,6 +516,10 @@ include_directories(BEFORE
) )
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
...@@ -66,6 +66,7 @@ else() ...@@ -66,6 +66,7 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
...@@ -59,25 +59,25 @@ template <index_t BlockSize, ...@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now... // ,bool TransposeC //disable transposec right now...
> >
struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intrawave, struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intrawave,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
ATileDesc, ATileDesc,
BTileDesc, BTileDesc,
AMmaTileDesc, AMmaTileDesc,
BMmaTileDesc, BMmaTileDesc,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize, : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
...@@ -137,7 +137,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -137,7 +137,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k; using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
...@@ -271,10 +270,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -271,10 +270,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
typename ABlockBuffer, typename ABlockBuffer,
typename ABlockTransferStep, typename ABlockTransferStep,
typename BGridDesc, typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer, typename BBlockTransfer,
typename BGridBuffer, typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep, typename BBlockTransferStep,
typename CThreadBuffer> typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc, __device__ void Run(const AGridDesc& a_grid_desc,
...@@ -285,10 +282,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -285,10 +282,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
ABlockBuffer& a_block_buf1, ABlockBuffer& a_block_buf1,
const ABlockTransferStep& a_block_copy_step, const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy, BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf, const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) const index_t num_loop) const
...@@ -296,8 +291,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -296,8 +291,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1 // Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
...@@ -348,14 +341,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -348,14 +341,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>(); b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>,
Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
...@@ -399,8 +393,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -399,8 +393,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>(); b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>,
Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
...@@ -449,25 +444,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -449,25 +444,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>(); b_blockwise_copy
.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run( xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
a_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
}); });
...@@ -477,11 +471,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -477,11 +471,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}), make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf1, a_block_buf1,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, k0, I0),
a_thread_buf); a_thread_buf);
}); });
}); });
...@@ -491,8 +485,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -491,8 +485,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>(); b_blockwise_copy
.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
......
...@@ -112,7 +112,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1 ...@@ -112,7 +112,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
template <typename SeqIdx, index_t ThreadScratchId = 0> template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ constexpr auto GetSrcThreadScratchIdx() __device__ constexpr auto GetSrcThreadScratchIdx()
{ {
return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>(); return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>();
} }
template <typename SrcBuffer, index_t ThreadScratchId = 0> template <typename SrcBuffer, index_t ThreadScratchId = 0>
......
...@@ -67,55 +67,57 @@ template <typename ALayout, ...@@ -67,55 +67,57 @@ template <typename ALayout,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA, typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB> typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xdl_CShuffle_V3<ALayout, struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
BLayout, : public DeviceGemmMultiD_Xdl_CShuffle_V3<
DsLayout, ALayout,
CLayout, BLayout,
ADataType, DsLayout,
BDataType, CLayout,
DsDataType, ADataType,
CDataType, BDataType,
GemmAccDataType, DsDataType,
CShuffleDataType, CDataType,
AElementwiseOperation, GemmAccDataType,
BElementwiseOperation, CShuffleDataType,
CElementwiseOperation, AElementwiseOperation,
GemmSpec, BElementwiseOperation,
BlockSize, CElementwiseOperation,
MPerBlock, GemmSpec,
NPerBlock, BlockSize,
KPerBlock, MPerBlock,
AK1, NPerBlock,
BK1, KPerBlock,
MPerXDL, AK1,
NPerXDL, BK1,
MXdlPerWave, MPerXDL,
NXdlPerWave, NPerXDL,
ABlockTransferThreadClusterLengths_AK0_M_AK1, MXdlPerWave,
ABlockTransferThreadClusterArrangeOrder, NXdlPerWave,
ABlockTransferSrcAccessOrder, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcVectorDim, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcAccessOrder,
ABlockTransferDstScalarPerVector_AK1, ABlockTransferSrcVectorDim,
ABlockLdsExtraM, ABlockTransferSrcScalarPerVector,
BBlockTransferThreadClusterLengths_BK0_N_BK1, ABlockTransferDstScalarPerVector_AK1,
BBlockTransferThreadClusterArrangeOrder, ABlockLdsExtraM,
BBlockTransferSrcAccessOrder, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcVectorDim, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcAccessOrder,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferSrcVectorDim,
BBlockLdsExtraN, BBlockTransferSrcScalarPerVector,
CShuffleMXdlPerWavePerShuffle, BBlockTransferDstScalarPerVector_BK1,
CShuffleNXdlPerWavePerShuffle, BBlockLdsExtraN,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleMXdlPerWavePerShuffle,
CDEShuffleBlockTransferScalarPerVectors, CShuffleNXdlPerWavePerShuffle,
BlkGemmPipeSched, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
BlkGemmPipelineVer, CDEShuffleBlockTransferScalarPerVectors,
ComputeTypeA, BlkGemmPipeSched,
ComputeTypeB, BlkGemmPipelineVer,
LDSTypeA, ComputeTypeA,
LDSTypeB> ComputeTypeB,
LDSTypeA,
LDSTypeB>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -172,7 +174,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd ...@@ -172,7 +174,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
LDSTypeA, LDSTypeA,
LDSTypeB>; LDSTypeB>;
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
// Invoker // Invoker
...@@ -267,7 +268,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd ...@@ -267,7 +268,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
constexpr index_t minimum_occupancy = constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right now"); // static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
// now");
if(has_main_k_block_loop) if(has_main_k_block_loop)
{ {
// Tail number always full // Tail number always full
...@@ -284,11 +287,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd ...@@ -284,11 +287,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
} }
else else
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy>;
Run(kernel); Run(kernel);
} }
} }
...@@ -298,7 +301,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd ...@@ -298,7 +301,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
} }
} }
else else
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
...@@ -310,11 +313,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd ...@@ -310,11 +313,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
} }
else else
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<GridwiseGemm, GridwiseGemm,
false, false,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy>;
Run(kernel); Run(kernel);
} }
} }
...@@ -437,4 +440,4 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd ...@@ -437,4 +440,4 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
\ No newline at end of file
...@@ -125,23 +125,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -125,23 +125,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
CDEShuffleBlockTransferScalarPerVectors{}[I0]; CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{}; static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{}; static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{}; static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto BlockSizeNumber = Number<BlockSize>{}; static constexpr auto BlockSizeNumber = Number<BlockSize>{};
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>; using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); static constexpr index_t KPack =
static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack; static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl; static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NLane * NWave * KLane == BlockSize); static_assert(NLane * NWave * KLane == BlockSize);
static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week"); static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static constexpr auto MakeDsGridPointer() static constexpr auto MakeDsGridPointer()
{ {
return generate_tuple( return generate_tuple(
...@@ -320,12 +322,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -320,12 +322,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{ {
constexpr index_t NkSwizzle = BlockSize * KPack; constexpr index_t NkSwizzle = BlockSize * KPack;
constexpr index_t NkSwizzleNumber = Number<NkSwizzle>{}; constexpr index_t NkSwizzleNumber = Number<NkSwizzle>{};
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(make_tuple(N0, K0, NkSwizzleNumber),
make_tuple(N0, K0, NkSwizzleNumber), make_tuple(K0 * NkSwizzle, NkSwizzleNumber, I1));
make_tuple(K0 * NkSwizzle, NkSwizzleNumber, I1)
);
} }
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
...@@ -423,9 +423,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -423,9 +423,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{ {
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
} }
template <typename ELayout> template <typename ELayout>
...@@ -943,7 +941,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -943,7 +941,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{ {
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed( make_naive_tensor_descriptor_packed(
...@@ -955,44 +952,40 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -955,44 +952,40 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
} }
using BlockwiseGemmPipe = using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockwiseGemmXdlops_pipeline_bpreshuffle<BlkGemmPipeSched, remove_cvref_t<decltype(BlockwiseGemmXdlops_pipeline_bpreshuffle<
BlockSize, BlkGemmPipeSched,
LDSTypeA, BlockSize,
LDSTypeB, LDSTypeA,
ComputeTypeA, LDSTypeB,
AccDataType, ComputeTypeA,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), AccDataType,
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
ABlockTransferSrcScalarPerVector, GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
BBlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
MPerBlock, BBlockTransferSrcScalarPerVector,
NPerBlock, MPerBlock,
KPerBlock, NPerBlock,
MPerXdl, KPerBlock,
NPerXdl, MPerXdl,
MXdlPerWave, NPerXdl,
NXdlPerWave, MXdlPerWave,
KPack>{})>; NXdlPerWave,
KPack>{})>;
__device__ static constexpr index_t GetSharedMemoryNumberOfByte() __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
// b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS // LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
...@@ -1259,8 +1252,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1259,8 +1252,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{ {
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled( const auto b_grid_desc_bpreshuffled =
problem.BN0Shuffled, problem.BK0Shuffled); MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>( const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
...@@ -1294,10 +1287,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1294,10 +1287,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave)) ; __builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave));
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -1339,51 +1329,42 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1339,51 +1329,42 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
// using BThreadClusterLengths = Sequence<1, 1, BlockSize>; // using BThreadClusterLengths = Sequence<1, 1, BlockSize>;
// using BBlockTransferClusterArrangeOrder = Sequence<0, 1, 2>; // using BBlockTransferClusterArrangeOrder = Sequence<0, 1, 2>;
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThisThreadBlock,
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<1, KRepeat, KPack * BlockSize>, Sequence<1, KRepeat, KPack * BlockSize>,
Sequence<1, 1, BlockSize>, //BThreadClusterLengths, Sequence<1, 1, BlockSize>, // BThreadClusterLengths,
Sequence<0, 1, 2>, //BBlockTransferClusterArrangeOrder, Sequence<0, 1, 2>, // BBlockTransferClusterArrangeOrder,
BDataType, BDataType,
LDSTypeB, LDSTypeB,
decltype(b_grid_desc_bpreshuffled), decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
Sequence<0, 1, 2>,//BBlockTransferSrcAccessOrder, Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
2, 2,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true, true,
2>( 2>(b_grid_desc_bpreshuffled,
b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, 0, 0),
make_multi_index(n_block_data_idx_on_grid, 0, 0), b_element_op,
b_element_op, b_block_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0),
make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{});
ck::tensor_operation::element_wise::PassThrough{});
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// Cast after lds // Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0);
...@@ -1404,10 +1385,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1404,10 +1385,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_block_buf1, a_block_buf1,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_bpreshuffled, b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
...@@ -1419,7 +1398,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1419,7 +1398,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
"wrong!"); "wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it! // TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
...@@ -1672,7 +1650,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1672,7 +1650,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}); });
} }
} }
}; };
} // namespace ck } // namespace ck
...@@ -268,12 +268,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -268,12 +268,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
} }
template <typename SeqIdx, index_t ThreadScratchId = 0> template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ constexpr auto GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{}) __device__ constexpr auto
GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type; using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{}); return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
} }
template <index_t ThreadScratchId> template <index_t ThreadScratchId>
__device__ void __device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id) TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment