Commit 5d5891b0 authored by aska-0096's avatar aska-0096
Browse files

clang format

parent cfb397b1
......@@ -30,7 +30,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
......@@ -52,7 +52,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......@@ -141,7 +142,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
......@@ -279,7 +280,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
// block wise level pipe designed for inline asm
template <index_t BlockSize,
typename FloatA,
......@@ -321,7 +321,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......@@ -512,7 +513,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr auto RepeatDiff = MRepeat - NRepeat;
// Read all Mrepeat, Nrepeat
static_for<0, NRepeat, 1>{}([&](auto iN){
static_for<0, NRepeat, 1>{}([&](auto iN) {
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf,
......@@ -521,7 +522,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
b_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto iM){
static_for<0, MRepeat, 1>{}([&](auto iM) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, Number<iM>{}, I0, I0, I0),
a_block_buf,
......@@ -531,35 +532,36 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
});
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){
static_for<0, NRepeat, 1>{}([&](auto iN){
static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
static_for<0, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iCut, 0, 0, iK%A_K1))>{}];
make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
s_nop();
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop();
});
if constexpr( KPerBlock > WmmaK ){
if constexpr(KPerBlock > WmmaK)
{
// Read Consumed Next inner loop A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<WmmaK/A_K1>{}, Number<iCut>{}, I0, I0, I0),
make_tuple(Number<WmmaK / A_K1>{}, Number<iCut>{}, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
......@@ -567,55 +569,57 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
}
});
static_for<WmmaK, KPerBlock, WmmaK>{}([&](auto iWmmaK){
static_for<WmmaK, KPerBlock, WmmaK>{}([&](auto iWmmaK) {
// Stage 2: Run FIFO fashion loopover in Square
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
// Row Repeatation
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, WmmaInnerloop+RepeatDiff, 0, 0, iK%A_K1))>{}];
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0));
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
s_nop();
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop();
});
// Read Consumed Next inner loop A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(
Number<iWmmaK / A_K1>{}, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
a_thread_buf);
// Col Repeatation
static_for<WmmaInnerloop+1+RepeatDiff, MRepeat, 1>{}([&](auto iM){
static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iM, 0, 0, iK%A_K1))>{}];
make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, WmmaInnerloop, 0, 0, iK%B_K1))>{}];
make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
......@@ -624,96 +628,100 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
s_nop();
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop();
});
// Read Consumed Next inner loop B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<iWmmaK / B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
b_thread_buf);
});
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){
static_for<0, NRepeat, 1>{}([&](auto iN){
static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
static_for<0, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iCut, 0, 0, iK%A_K1))>{}];
make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
s_nop();
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop();
});
if constexpr( KPerBlock > WmmaK ){
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<(iWmmaK+WmmaK)/A_K1>{}, Number<iCut>{}, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
a_thread_buf);
if constexpr(KPerBlock > WmmaK)
{
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number<iCut>{}, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
a_thread_buf);
}
});
});
// Stage 2: Run FIFO fashion loopover in Square
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
// Row Repeatation
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, WmmaInnerloop+RepeatDiff, 0, 0, iK%A_K1))>{}];
make_tuple(iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0));
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
s_nop();
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop();
});
// Col Repeatation
static_for<WmmaInnerloop+1+RepeatDiff, MRepeat, 1>{}([&](auto iM){
static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iM, 0, 0, iK%A_K1))>{}];
make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, WmmaInnerloop, 0, 0, iK%B_K1))>{}];
make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
......@@ -722,9 +730,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
s_nop();
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop();
});
});
......
......@@ -196,7 +196,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// Gridwise descriptor, mapping to whole given provblem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
......@@ -276,8 +276,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
a_grid_desc_k0_m_k1_ =
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
......@@ -289,7 +291,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
}
......@@ -359,13 +362,14 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; // Last Option is W/O
true>; // Last Option is W/O
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
......@@ -391,7 +395,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......
......@@ -218,7 +218,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned * sizeof(FloatA) + b_block_space_size_aligned * sizeof(FloatB));
return (a_block_space_size_aligned * sizeof(FloatA) +
b_block_space_size_aligned * sizeof(FloatB));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -305,19 +306,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
// clang-format off
/*******************************************************************************/
......
......@@ -283,10 +283,18 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
}
};
template <typename src_type_a, typename src_type_b, typename dst_type, index_t MPerWmma, index_t NPerWmma>
template <typename src_type_a,
typename src_type_b,
typename dst_type,
index_t MPerWmma,
index_t NPerWmma>
struct WmmaSelector
{
template <typename src_type_a_, typename src_type_b_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
template <typename src_type_a_,
typename src_type_b_,
typename dst_type_,
index_t MPerWmma_,
index_t NPerWmma_>
static constexpr auto GetWmma();
template <>
......@@ -424,13 +432,19 @@ struct WmmaGemm
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value && is_same<dst_type, float>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value && is_same<dst_type, float>::value) ||
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value && is_same<dst_type, half_t>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value && is_same<dst_type, bhalf_t>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value && is_same<dst_type, int32_t>::value)
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, float>::value) ||
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, half_t>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, bhalf_t>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
is_same<dst_type, int32_t>::value)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value && is_same<dst_type, int32_t>::value)
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
is_same<dst_type, int32_t>::value)
#endif
,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
......@@ -479,7 +493,8 @@ struct WmmaGemm
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
}
static constexpr auto wmma = WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma =
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma;
__host__ __device__ static constexpr auto
......
......@@ -356,13 +356,9 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
}
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a,
half16_t b,
float8_t& c)
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
} // namespace ck
......
......@@ -21,10 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you delete them.
amd_assembly_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_c.template AsType<float8_t>()(Number<0>{}) =
// __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
// AsType<float8_t>()[Number<0>{}]);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment