Unverified commit 06701e70 authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents 5800d24e da42a889
......@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
set(target 1)
endif()
endforeach()
\ No newline at end of file
endforeach()
......@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
2,
4,
4,
true,
false,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
true,
false,
1,
1,
S<1, 64, 1, 2>,
......
......@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
// kernel data types
using InKernelDataType = FP16;
......@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
}
......@@ -2,6 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
// kernel data types
using InKernelDataType = I8;
......@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
}
......@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}
......@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}
......@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
//#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
......@@ -277,10 +278,10 @@ using DeviceMHAFactory =
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
MaskingSpec>
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
......@@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_cross_attention_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}
......@@ -28,6 +28,7 @@ Example is GQA-4
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -299,4 +300,14 @@ using ReferenceGemm1Instance =
#include "run_grouped_query_attention_forward_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}
......@@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.”
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_
#include "run_multi_query_attention_forward_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}
......@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
//#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
......@@ -277,10 +278,10 @@ using DeviceMHAFactory =
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
MaskingSpec>
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
......@@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_self_attention_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}
......@@ -3,6 +3,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#include "common.hpp"
#include "ck/host_utility/device_prop.hpp"
using OutDataType = FP16;
using WeiDataType = FP16;
......@@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
#include "run_grouped_conv_bwd_data_example.inc"
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run_grouped_conv_bwd_data_example(argc, argv);
}
......@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach()
#Do not build any WMMA examples if no gfx11 or gfx12 targets are on the list
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
......@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach()
#Do not build any WMMA examples if no gfx11 or gfx12 targets are on the list
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
......@@ -181,7 +181,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
# add all example subdir
file(GLOB dir_list LIST_DIRECTORIES true *)
FOREACH(subdir ${dir_list})
IF(IS_DIRECTORY "${subdir}")
if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt")
add_subdirectory(${subdir})
ENDIF()
ENDFOREACH()
......@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# if the dispatch string is empty, add (void) casts to suppress unused-parameter warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
# GEMM0: Q@K=S^T
......
......@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# if the dispatch string is empty, add (void) casts to suppress unused-parameter warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass
......
......@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# if the dispatch string is empty, add (void) casts to suppress unused-parameter warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
@dataclass
......
......@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
......@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__)
#elif defined(__gfx11__) || defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
......@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx11__)
#elif defined(__gfx11__) || defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
......@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_MFMA_GFX940
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
// buffer load
#define CK_USE_AMD_BUFFER_LOAD 1
......
......@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
}
inline bool is_gfx12_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
} // namespace ck
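For orientation, the examples earlier in this diff gate themselves on ck::is_gfx11_supported(); with this new helper, a WMMA example that also targets gfx12 could guard on either platform. A minimal sketch of that combined guard (run_example is a placeholder name, not part of this change):

int main(int argc, char* argv[])
{
    // Hypothetical guard combining both helpers; the existing examples check gfx11 only.
    if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
    {
        std::cout << "WARNING: wmma example not supported on the platform "
                  << ck::get_device_name() << std::endl;
        return 0;
    }
    return !run_example(argc, argv); // run_example is a placeholder
}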
......@@ -13,6 +13,504 @@
namespace ck {
#ifdef __gfx12__
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename ABlockDesc,
typename BBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool AEnableLds = true,
bool BEnableLds = true,
bool TransposeC = false>
/* Option: read from LDS; a big buffer holds the data required by all threads
 * Source
 *   A: K0PerBlock x MPerBlock x K1
 *   B: K0PerBlock x NPerBlock x K1
 * Destination
 *   C, non-transposed
 *     thread level: MRepeat x NRepeat x MAccVgprs
 *     block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 *   KPack == WMMA_K = 16
 *
 * Option: read from VMEM; a small buffer holds only the data each thread needs (skip LDS)
 * Source:
 *   A (if skipping LDS): MRepeat x KPack
 *   B (if skipping LDS): NRepeat x KPack
 * Destination
 *   C, non-transposed
 *     block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 */
struct BlockwiseGemmWMMA
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hard-code WaveSize, since the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
// When LDS is used, each row (16 consecutive lanes) reads its whole data from the source buffer.
// When LDS is not used, each row reads half of the data from the source buffer and exchanges it
// via permutation.
static constexpr index_t A_KRow = 2;
static constexpr index_t B_KRow = 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
if constexpr(AEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
if constexpr(BEnableLds)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
}
else
{
return make_tuple(0, 0, 0, 0, 0, 0);
}
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
// Thread-level register descriptor. Vector write.
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
return make_naive_tensor_descriptor(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
Number<NRepeat>{} * MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
MAccVgprs * AccStride,
AccStride));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Describes how data is allocated in the thread-copy source buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
static_assert(KPack % (A_K1 * A_KRow) == 0, "");
static_assert(KPack % (B_K1 * B_KRow) == 0, "");
// simple heuristic to determine the loop-over direction
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / KPack, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected:
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
template <bool EnableLds>
struct AThreadCopySelector;
template <>
struct AThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
};
template <>
struct AThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
false>;
};
template <bool EnableLds>
struct BThreadCopySelector;
template <>
struct BThreadCopySelector<true>
{
using type =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
};
template <>
struct BThreadCopySelector<false>
{
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
false>;
};
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#else
template <index_t BlockSize,
typename FloatA,
typename FloatB,
......@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
};
#endif
} // namespace ck
......@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point.
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
{
#ifdef __gfx12__
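// gfx12 (RDNA4) drops the monolithic s_barrier; workgroup synchronization uses
// split signal/wait barrier instructions instead.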
asm volatile("\
s_barrier_signal -1 \n \
s_barrier_wait -1 \
" ::);
#else
asm volatile("s_barrier" ::);
#endif
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm_Streamk_V2 : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t Streamk_sel,
ck::index_t Grid_size,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
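For context, this new base class follows the usual CK device-op pattern: build an argument via MakeArgumentPointer, build an invoker via MakeInvokerPointer, then launch. A minimal usage sketch under that assumption; DeviceGemmStreamKInstance, the pointer and size variables, the Streamk_sel/Grid_size values, and the StreamConfig-based Run signature are placeholders for illustration and do not appear in this diff:

// Hypothetical sketch: DeviceGemmStreamKInstance stands in for a concrete
// DeviceGemm_Streamk_V2 implementation; Run()'s StreamConfig argument follows the
// conventional CK invoker interface and is assumed here.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

DeviceGemmStreamKInstance gemm;
auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c,
                                         M, N, K,
                                         StrideA, StrideB, StrideC,
                                         /*Streamk_sel=*/1, /*Grid_size=*/-1,
                                         PassThrough{}, PassThrough{}, PassThrough{});
auto invoker = gemm.MakeInvokerPointer();
if(gemm.IsSupportedArgument(argument.get()))
{
    invoker->Run(argument.get(), StreamConfig{nullptr, /*time_kernel=*/false});
}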