Commit e43df26a authored by aska-0096

temp save, reproduce the v_bfi_b32 issue

parent 9739ede0
@@ -280,24 +280,24 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
-        // auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-        //     a_thread_desc_.GetElementSpaceSize());
-        // auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-        //     b_thread_desc_.GetElementSpaceSize());
-        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                                  FloatAB,
-                                  MRepeat,
-                                  WmmaK,
-                                  true>
-            a_thread_buf;
-        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                                  FloatAB,
-                                  NRepeat,
-                                  WmmaK,
-                                  true>
-            b_thread_buf;
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_thread_desc_.GetElementSpaceSize());
+        // StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+        //                           FloatAB,
+        //                           MRepeat,
+        //                           WmmaK,
+        //                           true>
+        //     a_thread_buf;
+        // StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+        //                           FloatAB,
+        //                           NRepeat,
+        //                           WmmaK,
+        //                           true>
+        //     b_thread_buf;
         static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
             static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -306,8 +306,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                     make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
-                    make_tuple(I0, I0, I0, I0, I0),
-                    a_thread_buf.GetVectorTypeReference(Number<m0*WmmaK>{}).template AsType<FloatAB>());
+                    make_tuple(I0, m0, I0, I0, I0),
+                    a_thread_buf);
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     // read B
@@ -315,28 +315,28 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                     make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
-                    make_tuple(I0, I0, I0, I0, I0),
-                    b_thread_buf.GetVectorTypeReference(Number<n0*WmmaK>{}).template AsType<FloatAB>());
-                    // vector_type<FloatAB, WmmaK> a_thread_vec;
-                    // vector_type<FloatAB, WmmaK> b_thread_vec;
-                    // static_for<0, WmmaK, 1>{}([&](auto i) {
-                    //     a_thread_vec.template AsType<FloatAB>()(i) =
-                    //         a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                    //             make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
-                    //     b_thread_vec.template AsType<FloatAB>()(i) =
-                    //         b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //             make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
-                    // });
-                    // using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
+                    make_tuple(I0, n0, I0, I0, I0),
+                    b_thread_buf);
+                    vector_type<FloatAB, WmmaK> a_thread_vec;
+                    vector_type<FloatAB, WmmaK> b_thread_vec;
+                    static_for<0, WmmaK, 1>{}([&](auto i) {
+                        a_thread_vec.template AsType<FloatAB>()(i) =
+                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(i) =
+                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                    });
+                    using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                     wmma_gemm.template Run(
-                        a_thread_buf.GetVectorTypeReference(Number<m0*WmmaK>{}),
-                        b_thread_buf.GetVectorTypeReference(Number<n0*WmmaK>{}),
+                        a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
+                        b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
@@ -346,11 +346,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
    protected:
    // A[M0, M1, M2, K0 = WmmaK]
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<WmmaK / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
+        make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
    // B[N0, N1, N2, K0 = WmmaK]
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<WmmaK / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
+        make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));
    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
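Both thread descriptors widen their repeat dimension from I1 to Number<MRepeat>{} (and Number<NRepeat>{} for B), so each thread now keeps every repeat slice of its A/B fragment resident at once. A standalone sketch of the size arithmetic this implies; the constant values are illustrative assumptions, not taken from any instantiation in this commit:

    #include <cstdio>

    // Illustrative values; real kernels fix these per instantiation.
    constexpr int WmmaK = 16, A_K1 = 8, MRepeat = 4;

    // A packed descriptor's element count is the product of its lengths.
    constexpr int old_elems = (WmmaK / A_K1) * 1 * 1 * 1 * A_K1;       // repeat dim = I1
    constexpr int new_elems = (WmmaK / A_K1) * MRepeat * 1 * 1 * A_K1; // repeat dim = MRepeat

    static_assert(new_elems == MRepeat * old_elems, "VGPR footprint grows by MRepeat");

    int main() { std::printf("%d -> %d elements per thread\n", old_elems, new_elems); }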
@@ -659,7 +659,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                     make_tuple(I0, m0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
-                    make_tuple(I0, I0, I0, I0, I0),
+                    make_tuple(I0, Number<m0>{}, I0, I0, I0),
                     a_thread_buf);
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -668,7 +668,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                     make_tuple(I0, n0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
-                    make_tuple(I0, I0, I0, I0, I0),
+                    make_tuple(I0, Number<n0>{}, I0, I0, I0),
                     b_thread_buf);
                 static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
@@ -678,10 +678,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                     static_for<0, WmmaK, 1>{}([&](auto i) {
                         a_thread_vec.template AsType<FloatAB>()(i) =
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                                make_tuple((k*WmmaK + i) / A_K1, 0, 0, 0, (k*WmmaK + i) % A_K1))>{}];
+                                make_tuple((k*WmmaK + i) / A_K1, m0, 0, 0, (k*WmmaK + i) % A_K1))>{}];
                         b_thread_vec.template AsType<FloatAB>()(i) =
                             b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                                make_tuple((k*WmmaK + i) / B_K1, 0, 0, 0, (k*WmmaK + i) % B_K1))>{}];
+                                make_tuple((k*WmmaK + i) / B_K1, n0, 0, 0, (k*WmmaK + i) % B_K1))>{}];
                     });
                     using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
@@ -701,11 +701,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
    protected:
    // A[M0, M1, M2, K0 = WmmaK]
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<KPerBlock / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
+        make_tuple(Number<KPerBlock / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
    // B[N0, N1, N2, K0 = WmmaK]
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<KPerBlock / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
+        make_tuple(Number<KPerBlock / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));
    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -716,7 +716,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
        decltype(a_block_desc_k0_m0_m1_m2_k1),
        decltype(a_thread_desc_),
        Sequence<KPerBlock / A_K1, 1, 1, 1, A_K1>,
-        Sequence<3, 0, 1, 2, 4>,
+        Sequence<0, 1, 2, 3, 4>,
        4,
        A_K1,
        A_K1>;
@@ -726,7 +726,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
        decltype(b_block_desc_k0_n0_n1_n2_k1),
        decltype(b_thread_desc_),
        Sequence<KPerBlock / B_K1, 1, 1, 1, B_K1>,
-        Sequence<3, 0, 1, 2, 4>,
+        Sequence<0, 1, 2, 3, 4>,
        4,
        B_K1,
        B_K1>;
@@ -1009,9 +1009,17 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
             b_thread_desc_.GetElementSpaceSize());
         constexpr auto RepeatDiff = MRepeat - NRepeat;
         static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
+            static_for<0, NRepeat, 1>{}([&](auto iN){
+                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                                   make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, Number<iN>{}, I0, I0, I0),
+                                   b_thread_buf);
+            });
             // Stage 1: cut the repeat rectangle down to a square; assumes MRepeat > NRepeat
             static_for<0, RepeatDiff, 1>{}([&](auto iCut){
                 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
@@ -1021,12 +1029,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                                    make_tuple(I0, Number<iCut>{}, I0, I0, I0),
                                    a_thread_buf);
                 static_for<0, NRepeat, 1>{}([&](auto iN){
-                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(I0, Number<iN>{}, I0, I0, I0),
-                                       b_thread_buf);
+                    // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                    //                    make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
+                    //                    b_block_buf,
+                    //                    b_thread_desc_,
+                    //                    make_tuple(I0, Number<iN>{}, I0, I0, I0),
+                    //                    b_thread_buf);
                     vector_type<FloatAB, WmmaK> a_thread_vec;
                     vector_type<FloatAB, WmmaK> b_thread_vec;
@@ -1042,30 +1050,34 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                     using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
                     constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
+                    s_nop();
                     wmma_gemm.template Run(
                         a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    s_nop();
                 });
             });
+            static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
+                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+                                   make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
+                                   a_thread_buf);
+            });
             // Stage 2: run a FIFO-fashion loop over the square
             static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
                 // row repetition
                 static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){
-                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                       make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
-                                       a_thread_buf);
-                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(I0, Number<iN>{}, I0, I0, I0),
-                                       b_thread_buf);
+                    // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                    //                    make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
+                    //                    b_block_buf,
+                    //                    b_thread_desc_,
+                    //                    make_tuple(I0, Number<iN>{}, I0, I0, I0),
+                    //                    b_thread_buf);
                     vector_type<FloatAB, WmmaK> a_thread_vec;
                     vector_type<FloatAB, WmmaK> b_thread_vec;
@@ -1081,27 +1093,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0));
+                    s_nop();
                     wmma_gemm.template Run(
                         a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    s_nop();
                 });
                 // WmmaInnerloop++
                 // column repetition
                 static_for<WmmaInnerloop+1+RepeatDiff, MRepeat, 1>{}([&](auto iM){
-                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                       make_tuple(Number<iWmmaK/A_K1>{}, Number<iM>{}, I0, I0, I0),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(I0, Number<iM>{}, I0, I0, I0),
-                                       a_thread_buf);
-                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
-                                       b_thread_buf);
+                    // a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+                    //                    make_tuple(Number<iWmmaK/A_K1>{}, Number<iM>{}, I0, I0, I0),
+                    //                    a_block_buf,
+                    //                    a_thread_desc_,
+                    //                    make_tuple(I0, Number<iM>{}, I0, I0, I0),
+                    //                    a_thread_buf);
+                    // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                    //                    make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
+                    //                    b_block_buf,
+                    //                    b_thread_desc_,
+                    //                    make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
+                    //                    b_thread_buf);
                     vector_type<FloatAB, WmmaK> a_thread_vec;
                     vector_type<FloatAB, WmmaK> b_thread_vec;
@@ -1117,10 +1131,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
+                    s_nop();
                     wmma_gemm.template Run(
                         a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    s_nop();
                 });
             });
         });
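Taken together, Stage 1 and Stage 2 issue every (m, n) repeat pair exactly once: first the RepeatDiff extra rows of the MRepeat x NRepeat rectangle, then the remaining square, alternating one row sweep with the matching column sweep. A standalone sketch of the resulting issue order; the MRepeat and NRepeat values are illustrative, and wmma() merely stands in for wmma_gemm.Run:

    #include <cstdio>

    constexpr int MRepeat = 4, NRepeat = 2; // assumes MRepeat > NRepeat, as the code does
    constexpr int RepeatDiff = MRepeat - NRepeat;

    void wmma(int m, int n) { std::printf("wmma(m=%d, n=%d)\n", m, n); }

    int main()
    {
        // Stage 1: cut the rectangle down to an NRepeat x NRepeat square
        for(int iCut = 0; iCut < RepeatDiff; ++iCut)
            for(int iN = 0; iN < NRepeat; ++iN)
                wmma(iCut, iN);
        // Stage 2: FIFO sweep over the square, one row then one column per step
        for(int i = 0; i < NRepeat; ++i)
        {
            for(int iN = i; iN < NRepeat; ++iN)
                wmma(i + RepeatDiff, iN); // row repetition
            for(int iM = i + 1 + RepeatDiff; iM < MRepeat; ++iM)
                wmma(iM, i); // column repetition
        }
        return 0;
    }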
@@ -1144,7 +1160,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
        decltype(a_block_desc_k0_m0_m1_m2_k1),
        decltype(a_thread_desc_),
        Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
-        Sequence<3, 0, 1, 2, 4>,
+        Sequence<0, 1, 2, 3, 4>,
        4,
        A_K1,
        A_K1>;
@@ -1154,7 +1170,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
        decltype(b_block_desc_k0_n0_n1_n2_k1),
        decltype(b_thread_desc_),
        Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
-        Sequence<3, 0, 1, 2, 4>,
+        Sequence<0, 1, 2, 3, 4>,
        4,
        B_K1,
        B_K1>;
...
@@ -310,7 +310,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
            BlockSize,
            FloatAB,
            FloatAcc,
@@ -367,7 +367,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
            BlockSize,
            FloatAB,
            FloatAcc,
@@ -540,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
        auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
                FloatAB,
                FloatAcc,
                decltype(a_block_desc_k0perblock_mperblock_k1),
...
@@ -360,9 +360,7 @@ __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a,
                                                        half16_t b,
                                                        float8_t& c)
{
-    asm volatile("\n \
-    v_wmma_f32_16x16x16_f16_w32 %0, %1, %2, %0\n \
-    "
+    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0"
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
}
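The constraint list carries the semantics here: "=v"(c) declares %0 as a VGPR output, and "0"(c) feeds c's old value into that same register, so the accumulator is updated in place, matching v_wmma's D = A*B + C with D aliased to C. A minimal standalone example of the same tied-operand pattern on a scalar instruction (nothing from this diff, just the idiom):

    // acc = a * b + acc, with %0 both written ("=v") and pre-loaded ("0"(acc)).
    __device__ void fmac_in_place(float& acc, float a, float b)
    {
        asm volatile("v_fmac_f32 %0, %1, %2"
                     : "=v"(acc)
                     : "v"(a), "v"(b), "0"(acc));
    }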
...
@@ -4,6 +4,7 @@
 #ifndef CK_AMD_WMMA_HPP
 #define CK_AMD_WMMA_HPP
+#include "ck/utility/amd_inline_asm.hpp"
 #include "data_type.hpp"
 // TODO: Add arch limitation
 namespace ck {
@@ -20,8 +21,10 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
-        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+        // * Inline assembly is needed to eliminate the duplicated data loads; the compiler won't delete them for you.
+        amd_assembly_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+        // reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+        //     reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};
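The swap trades the builtin's functional interface for the asm wrapper's in-place one, which is what the new comment is getting at: the builtin returns a fresh float8_t that must be written back, and the compiler can end up keeping duplicate accumulator copies and loads live. A standalone sketch of the two call shapes, with the vector aliases defined the way CK defines them (clang ext_vector_type); the trailing-underscore names are local to this sketch, and the builtin is gfx11-only:

    typedef _Float16 half16_t_ __attribute__((ext_vector_type(16)));
    typedef float    float8_t_ __attribute__((ext_vector_type(8)));

    // Functional form: a new accumulator value is produced, then copied back.
    __device__ void run_builtin(half16_t_ a, half16_t_ b, float8_t_& c)
    {
        c = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, c);
    }

    // In-place form: "0"(c) ties the destination registers to the accumulator.
    __device__ void run_asm(half16_t_ a, half16_t_ b, float8_t_& c)
    {
        asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0"
                     : "=v"(c)
                     : "v"(a), "v"(b), "0"(c));
    }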
...
@@ -97,6 +97,7 @@ builtin_wmma_naive_selector<int4x16_t,
 template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
 __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
 {
+    __shared__ src_t p_shared[16 * 16 * 2];
     const int lIdx = threadIdx.x;
     // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
     // b. a_frag will store one column of the 16x16 matrix tile; b_frag will store one row of the
@@ -104,6 +105,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
     using src_vec = typename vector_type<src_t, 16>::type;
     src_vec a_frag = {};
     src_vec b_frag = {};
+    src_vec a_temp = {};
+    src_vec b_temp = {};
     // initialize c fragment to 0
     using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
     acc_vec c_thread_buf_;
@@ -112,19 +116,52 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
     // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
     // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
     const int lane = lIdx % 16;
+    const int lane_lo = lIdx / 2;
+    const int lane_hi = lIdx % 2;
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        a_temp[ele] = a[8 * lane_hi + 16 * lane_lo + ele];
+    }
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        b_temp[ele] = b[8 * lane_hi + 16 * lane_lo + ele];
+    }
+    __syncthreads();
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        p_shared[8*16*lane_hi + 8 * lane_lo + ele] = a_temp[ele];
+    }
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        p_shared[8*16*lane_hi + 8 * lane_lo + ele + 16*16] = b_temp[ele];
+    }
+    asm volatile("\
+    s_waitcnt lgkmcnt(0) \n \
+    s_barrier \
+    " ::);
     for(int ele = 0; ele < 16; ++ele)
     {
-        b_frag[ele] = b[16 * lane + ele];
+        b_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8 + 16*16];
     }
     // follow the original design
     for(int ele = 0; ele < 16; ++ele)
     {
-        a_frag[ele] = a[16 * lane + ele];
+        a_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8];
     }
+    asm volatile("\
+    s_waitcnt lgkmcnt(0) \n \
+    s_barrier \
+    " ::);
     // sync threads, similar to mma_sync
-    __syncthreads();
+    // __syncthreads();
     builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
     __syncthreads();
     // wait for results, similar to mma_sync
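The index arithmetic in the LDS detour is easy to misread, so it is worth checking that it is layout-preserving: each lane's 16 elements are split into two 8-element halves (lane_hi selects the half, lane_lo the row), staged through p_shared, and reassembled so that a_frag ends up identical to the old direct load a[16 * lane + ele]. A host-side sketch that verifies the a-tile mapping (plain C++; the stored values are just the global indices they came from):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        int p_shared[16 * 16]; // one tile's worth of LDS slots
        // store phase: 32 threads each stage 8 elements
        for(int lIdx = 0; lIdx < 32; ++lIdx)
        {
            const int lane_lo = lIdx / 2, lane_hi = lIdx % 2;
            for(int ele = 0; ele < 8; ++ele)
                p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] =
                    8 * lane_hi + 16 * lane_lo + ele; // global index written here
        }
        // load phase: each lane reads back its 16-element fragment
        for(int lIdx = 0; lIdx < 32; ++lIdx)
        {
            const int lane = lIdx % 16;
            for(int ele = 0; ele < 16; ++ele)
                assert(p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8] ==
                       16 * lane + ele); // same element the direct load fetched
        }
        std::puts("LDS round trip preserves the original fragment layout");
        return 0;
    }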
...