Commit 5d5891b0 authored by aska-0096's avatar aska-0096
Browse files

clang format

parent cfb397b1
...@@ -30,7 +30,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -30,7 +30,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
...@@ -52,7 +52,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -52,7 +52,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{}; static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
...@@ -141,7 +142,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -141,7 +142,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
NPerBlock % (NPerWMMA * NRepeat) == 0, NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!"); "wrong!");
} }
// Provide dimension size // Provide dimension size
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
...@@ -279,7 +280,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -279,7 +280,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
}; };
// block wise level pipe designed for inline asm // block wise level pipe designed for inline asm
template <index_t BlockSize, template <index_t BlockSize,
typename FloatA, typename FloatA,
...@@ -321,7 +321,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -321,7 +321,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{}; static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
...@@ -512,7 +513,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -512,7 +513,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
constexpr auto RepeatDiff = MRepeat - NRepeat; constexpr auto RepeatDiff = MRepeat - NRepeat;
// Read all Mrepeat, Nrepeat // Read all Mrepeat, Nrepeat
static_for<0, NRepeat, 1>{}([&](auto iN){ static_for<0, NRepeat, 1>{}([&](auto iN) {
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, Number<iN>{}, I0, I0, I0), make_tuple(I0, Number<iN>{}, I0, I0, I0),
b_block_buf, b_block_buf,
...@@ -521,7 +522,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -521,7 +522,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
b_thread_buf); b_thread_buf);
}); });
static_for<0, MRepeat, 1>{}([&](auto iM){ static_for<0, MRepeat, 1>{}([&](auto iM) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, Number<iM>{}, I0, I0, I0), make_tuple(I0, Number<iM>{}, I0, I0, I0),
a_block_buf, a_block_buf,
...@@ -531,35 +532,36 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -531,35 +532,36 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
}); });
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){ static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
static_for<0, NRepeat, 1>{}([&](auto iN){ static_for<0, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) = a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iCut, 0, 0, iK%A_K1))>{}]; make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) = b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}]; make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
}); });
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
s_nop(); s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); s_nop();
}); });
if constexpr( KPerBlock > WmmaK ){ if constexpr(KPerBlock > WmmaK)
{
// Read Consumed Next inner loop A // Read Consumed Next inner loop A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<WmmaK/A_K1>{}, Number<iCut>{}, I0, I0, I0), make_tuple(Number<WmmaK / A_K1>{}, Number<iCut>{}, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, Number<iCut>{}, I0, I0, I0), make_tuple(I0, Number<iCut>{}, I0, I0, I0),
...@@ -567,55 +569,57 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -567,55 +569,57 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
} }
}); });
static_for<WmmaK, KPerBlock, WmmaK>{}([&](auto iWmmaK){ static_for<WmmaK, KPerBlock, WmmaK>{}([&](auto iWmmaK) {
// Stage 2: Run FIFO fashion loopover in Square // Stage 2: Run FIFO fashion loopover in Square
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
// Row Repeatation // Row Repeatation
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){ static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) = a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
make_tuple(iK/A_K1, WmmaInnerloop+RepeatDiff, 0, 0, iK%A_K1))>{}]; iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) = b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}]; make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
}); });
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
s_nop(); s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); s_nop();
}); });
// Read Consumed Next inner loop A // Read Consumed Next inner loop A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, a_thread_copy_.Run(
make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0), a_block_desc_k0_m0_m1_m2_k1,
a_block_buf, make_tuple(
a_thread_desc_, Number<iWmmaK / A_K1>{}, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0), a_block_buf,
a_thread_buf); a_thread_desc_,
make_tuple(I0, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
a_thread_buf);
// Col Repeatation // Col Repeatation
static_for<WmmaInnerloop+1+RepeatDiff, MRepeat, 1>{}([&](auto iM){ static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) = a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iM, 0, 0, iK%A_K1))>{}]; make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) = b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, WmmaInnerloop, 0, 0, iK%B_K1))>{}]; make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
}); });
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
...@@ -624,96 +628,100 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -624,96 +628,100 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
s_nop(); s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); s_nop();
}); });
// Read Consumed Next inner loop B // Read Consumed Next inner loop B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, b_thread_copy_.Run(
make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0), b_block_desc_k0_n0_n1_n2_k1,
b_block_buf, make_tuple(Number<iWmmaK / B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
b_thread_desc_, b_block_buf,
make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0), b_thread_desc_,
b_thread_buf); make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
b_thread_buf);
}); });
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){ static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
static_for<0, NRepeat, 1>{}([&](auto iN){ static_for<0, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) = a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iCut, 0, 0, iK%A_K1))>{}]; make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) = b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}]; make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
}); });
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
s_nop(); s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); s_nop();
}); });
if constexpr( KPerBlock > WmmaK ){ if constexpr(KPerBlock > WmmaK)
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, {
make_tuple(Number<(iWmmaK+WmmaK)/A_K1>{}, Number<iCut>{}, I0, I0, I0), a_thread_copy_.Run(
a_block_buf, a_block_desc_k0_m0_m1_m2_k1,
a_thread_desc_, make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number<iCut>{}, I0, I0, I0),
make_tuple(I0, Number<iCut>{}, I0, I0, I0), a_block_buf,
a_thread_buf); a_thread_desc_,
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
a_thread_buf);
} }
}); });
}); });
// Stage 2: Run FIFO fashion loopover in Square // Stage 2: Run FIFO fashion loopover in Square
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
// Row Repeatation // Row Repeatation
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){ static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) = a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, WmmaInnerloop+RepeatDiff, 0, 0, iK%A_K1))>{}]; make_tuple(iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) = b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}]; make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
}); });
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
s_nop(); s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); s_nop();
}); });
// Col Repeatation // Col Repeatation
static_for<WmmaInnerloop+1+RepeatDiff, MRepeat, 1>{}([&](auto iM){ static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatA>()(iK) = a_thread_vec.template AsType<FloatA>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iK/A_K1, iM, 0, 0, iK%A_K1))>{}]; make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(iK) = b_thread_vec.template AsType<FloatB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iK/B_K1, WmmaInnerloop, 0, 0, iK%B_K1))>{}]; make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
}); });
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type; using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type; using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
...@@ -722,9 +730,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -722,9 +730,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
s_nop(); s_nop();
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}), a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}), b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
s_nop(); s_nop();
}); });
}); });
......
...@@ -196,7 +196,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -196,7 +196,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// Gridwise descriptor, mapping to whole given provblem. // Gridwise descriptor, mapping to whole given provblem.
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
...@@ -276,8 +276,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -276,8 +276,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ =
b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ = block_2_ctile_map_ =
...@@ -289,7 +291,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -289,7 +291,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_)) block_2_ctile_map_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock = c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
} }
} }
...@@ -359,13 +362,14 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -359,13 +362,14 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; // Last Option is W/O true>; // Last Option is W/O
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(grid_size),
...@@ -391,7 +395,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -391,7 +395,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
......
...@@ -218,7 +218,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -218,7 +218,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto b_block_space_size_aligned = math::integer_least_multiple( constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned * sizeof(FloatA) + b_block_space_size_aligned * sizeof(FloatB)); return (a_block_space_size_aligned * sizeof(FloatA) +
b_block_space_size_aligned * sizeof(FloatB));
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...@@ -305,19 +306,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ...@@ -305,19 +306,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap> template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void __device__ static void Run(const FloatA* __restrict__ p_a_grid,
Run(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid,
const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared,
void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation& a_element_op,
const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op,
const BElementwiseOperation& b_element_op, const CElementwiseOperation& c_element_op,
const CElementwiseOperation& c_element_op, const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map)
{ {
// clang-format off // clang-format off
/*******************************************************************************/ /*******************************************************************************/
......
...@@ -283,10 +283,18 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8, ...@@ -283,10 +283,18 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
} }
}; };
template <typename src_type_a, typename src_type_b, typename dst_type, index_t MPerWmma, index_t NPerWmma> template <typename src_type_a,
typename src_type_b,
typename dst_type,
index_t MPerWmma,
index_t NPerWmma>
struct WmmaSelector struct WmmaSelector
{ {
template <typename src_type_a_, typename src_type_b_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_> template <typename src_type_a_,
typename src_type_b_,
typename dst_type_,
index_t MPerWmma_,
index_t NPerWmma_>
static constexpr auto GetWmma(); static constexpr auto GetWmma();
template <> template <>
...@@ -424,13 +432,19 @@ struct WmmaGemm ...@@ -424,13 +432,19 @@ struct WmmaGemm
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert( static_assert(
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value && is_same<dst_type, float>::value) || (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value && is_same<dst_type, float>::value) || is_same<dst_type, float>::value) ||
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value && is_same<dst_type, half_t>::value) || (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value && is_same<dst_type, bhalf_t>::value) || is_same<dst_type, float>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value && is_same<dst_type, int32_t>::value) (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
is_same<dst_type, half_t>::value) ||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, bhalf_t>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
is_same<dst_type, int32_t>::value)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value && is_same<dst_type, int32_t>::value) || (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
is_same<dst_type, int32_t>::value)
#endif #endif
, ,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
...@@ -479,7 +493,8 @@ struct WmmaGemm ...@@ -479,7 +493,8 @@ struct WmmaGemm
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
} }
static constexpr auto wmma = WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{}; static constexpr auto wmma =
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = wmma.selected_wmma; static constexpr auto wmma_instr = wmma.selected_wmma;
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
......
...@@ -356,13 +356,9 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, ...@@ -356,13 +356,9 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
} }
// Ranged input operand // Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
half16_t b,
float8_t& c)
{ {
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
} }
} // namespace ck } // namespace ck
......
...@@ -21,10 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> ...@@ -21,10 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{ {
// * Inline assembly need to elimate the duplicated data load, compiler won't help you delete them. // * Inline assembly need to elimate the duplicated data load, compiler won't help you
amd_assembly_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{})); // delete them.
// reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_c.template AsType<float8_t>()(Number<0>{}) =
// __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
// AsType<float8_t>()[Number<0>{}]);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment