Commit 5cfe67b1 authored by aska-0096

NewBlkGEMM

parent cc0ffeb7

@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"

@@ -83,30 +84,30 @@ using DeviceOpInstance =
1,
128,
64,
-64,
+128,
64,
-4,
+8,
16,
16,
-1,
+2,
4,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
-4,
-4,
-true,
+8,
+8,
+false,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
-4,
-4,
-true,
+8,
+8,
+false,
1,
1,
-S<1, 64, 1, 2>,
+S<1, 32, 1, 4>,
8>;
int main(int argc, char* argv[])

@@ -208,12 +209,29 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
+case 2:
+ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
+ck::utils::FillUniformDistributionIntegerValue<DDataType>{1.f, 1.f}(d_m_n);
+break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
+#if 0
+for(int im = 0; im<M; im++)
+{
+for(int ik = 0; ik<K; ik++)
+{
+if(ik%8==0) printf("|");
+printf("%4x ", *(reinterpret_cast<uint16_t*>(&(a_m_k(im,ik)))));
+}
+printf("\n");
+}
+#endif
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());

...

@@ -62,12 +62,13 @@ struct BlockwiseGemmWMMA
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static constexpr index_t WaveSize = 32;
+static constexpr index_t RowSize = 16;
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
-static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
-static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
+static constexpr index_t A_KRow = AEnableLds ? 2 : 2;
+static constexpr index_t B_KRow = BEnableLds ? 2 : 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
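
A minimal host-side sketch (not part of this commit) of the lane layout these constants describe: the wave32 is viewed as two rows of 16 lanes, and with KRow = 2 each row starts reading from a different half of the WMMA K slice. WaveSize, RowSize, WmmaK = 16 and K1 = 8 are assumptions matching the fp16 instance in the example above.

    // Sketch only: wave32 split into two K-rows of 16 lanes each.
    // Assumed: WaveSize = 32, RowSize = 16, WmmaK = 16, KRow = 2 (so K1 = WmmaK / KRow = 8).
    #include <cstdio>

    int main()
    {
        constexpr int WaveSize = 32;
        constexpr int RowSize  = 16;
        constexpr int WmmaK    = 16;
        constexpr int KRow     = 2;
        constexpr int KPerRow  = WmmaK / KRow; // each row owns half of the K slice

        for(int lane = 0; lane < WaveSize; lane += RowSize)
        {
            const int rowId   = (lane % WaveSize) / RowSize; // same formula as the rowId added below
            const int kOrigin = rowId * KPerRow;             // K offset this row reads first
            printf("lanes %2d..%2d: rowId = %d, K origin = %d\n",
                   lane, lane + RowSize - 1, rowId, kOrigin);
        }
        return 0;
    }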

@@ -105,10 +106,11 @@ struct BlockwiseGemmWMMA
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
+const auto rowId = (get_thread_local_1d_id() % WaveSize) / RowSize;
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
-return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
+return make_tuple(0, 0, waveId_m, rowId, WMMA_a_idx, 0);
}
else
{

@@ -122,10 +124,11 @@ struct BlockwiseGemmWMMA
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
+const auto rowId = (get_thread_local_1d_id() % WaveSize) / RowSize;
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
-return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
+return make_tuple(0, 0, waveId_n, rowId, WMMA_b_idx, 0);
}
else
{

@@ -286,8 +289,8 @@ struct BlockwiseGemmWMMA
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
-static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
-static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
+static constexpr ABlockDesc a_block_desc_;
+static constexpr BBlockDesc b_block_desc_;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,

@@ -299,6 +302,10 @@ struct BlockwiseGemmWMMA
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
+using RowwisePermuter =
+tensor_operation::element_wise::InterRowPermuter<FloatA, WmmaK / A_KRow>;
+RowwisePermuter interrow_permuter;
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{

@@ -307,7 +314,7 @@ struct BlockwiseGemmWMMA
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
-a_block_desc_k0_m0_m1_m2_k1,
+a_block_desc_,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,

@@ -317,35 +324,44 @@ struct BlockwiseGemmWMMA
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
-b_block_desc_k0_n0_n1_n2_k1,
+b_block_desc_,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
+vector_type<FloatA, WmmaK / A_KRow> a_this_thread_vec;
+vector_type<FloatB, WmmaK / B_KRow> b_this_thread_vec;
+vector_type<FloatA, WmmaK / A_KRow> a_other_thread_vec;
+vector_type<FloatB, WmmaK / B_KRow> b_other_thread_vec;
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
-static_for<0, WmmaK, 1>{}([&](auto i) {
-a_thread_vec.template AsType<FloatA>()(i) =
+static_for<0, WmmaK / A_KRow, 1>{}([&](auto i) {
+a_this_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-make_tuple(i / A_K1 / A_KRow,
-m0,
-0,
-(i / A_K1) % A_KRow,
-0,
-i % A_K1))>{}];
-b_thread_vec.template AsType<FloatB>()(i) =
+make_tuple(i / B_K1 / B_KRow, m0, 0, 0, 0, i % A_K1))>{}];
+b_this_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-make_tuple(i / B_K1 / B_KRow,
-n0,
-0,
-(i / B_K1) % B_KRow,
-0,
-i % B_K1))>{}];
+make_tuple(i / A_K1 / A_KRow, n0, 0, 0, 0, i % B_K1))>{}];
});
+a_other_thread_vec = interrow_permuter(a_this_thread_vec);
+b_other_thread_vec = interrow_permuter(b_this_thread_vec);
+using vec_row_type_a = typename vector_type<FloatA, WmmaK / 2>::type;
+using vec_row_type_b = typename vector_type<FloatB, WmmaK / 2>::type;
+a_thread_vec.template AsType<vec_row_type_a>()(Number<0>{}) =
+a_this_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
+a_thread_vec.template AsType<vec_row_type_a>()(Number<1>{}) =
+a_other_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
+b_thread_vec.template AsType<vec_row_type_b>()(Number<0>{}) =
+b_this_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
+b_thread_vec.template AsType<vec_row_type_b>()(Number<1>{}) =
+b_other_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

@@ -362,49 +378,101 @@ struct BlockwiseGemmWMMA
}
else
{
-static_for<0, NRepeat, 1>{}([&](auto n0) {
-static_for<0, MRepeat, 1>{}([&](auto m0) {
-static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
-// k=0,kpack*1, ..
-// read B
-b_thread_copy_.Run(
-b_block_desc_k0_n0_n1_n2_k1,
-make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
-b_block_buf,
-b_thread_desc_,
-make_tuple(I0, n0, I0, I0, I0, I0),
-b_thread_buf);
+static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
+// k=0,kpack*1, ..
+static_for<0, NRepeat, 1>{}([&](auto n0) {
+// read B
+b_thread_copy_.Run(
+b_block_desc_,
+make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+b_block_buf,
+b_thread_desc_,
+make_tuple(I0, n0, I0, I0, I0, I0),
+b_thread_buf);
+static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
-a_block_desc_k0_m0_m1_m2_k1,
+a_block_desc_,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
+vector_type<FloatA, WmmaK / A_KRow> a_this_thread_vec;
+vector_type<FloatB, WmmaK / B_KRow> b_this_thread_vec;
+vector_type<FloatA, WmmaK / A_KRow> a_other_thread_vec;
+vector_type<FloatB, WmmaK / B_KRow> b_other_thread_vec;
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
-static_for<0, WmmaK, 1>{}([&](auto i) {
-b_thread_vec.template AsType<FloatB>()(i) =
+static_for<0, WmmaK / A_KRow, 1>{}([&](auto i) {
+b_this_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-make_tuple(i / B_K1 / B_KRow,
-n0,
-0,
-(i / B_K1) % B_KRow,
-0,
-i % B_K1))>{}];
-a_thread_vec.template AsType<FloatA>()(i) =
+make_tuple(i / A_K1 / A_KRow, n0, 0, 0, 0, i % B_K1))>{}];
+a_this_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-make_tuple(i / A_K1 / A_KRow,
-m0,
-0,
-(i / A_K1) % A_KRow,
-0,
-i % A_K1))>{}];
+make_tuple(i / B_K1 / B_KRow, m0, 0, 0, 0, i % A_K1))>{}];
});
+a_other_thread_vec = interrow_permuter(a_this_thread_vec);
+b_other_thread_vec = interrow_permuter(b_this_thread_vec);
+using vec_row_type_a = typename vector_type<FloatA, WmmaK / 2>::type;
+using vec_row_type_b = typename vector_type<FloatB, WmmaK / 2>::type;
+b_thread_vec.template AsType<vec_row_type_b>()(Number<0>{}) =
+b_this_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
+b_thread_vec.template AsType<vec_row_type_b>()(Number<1>{}) =
+b_other_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
+a_thread_vec.template AsType<vec_row_type_a>()(Number<0>{}) =
+a_this_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
+a_thread_vec.template AsType<vec_row_type_a>()(Number<1>{}) =
+a_other_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
+#if 0
+if(get_thread_local_1d_id() == 0 || get_thread_local_1d_id() == 16)
+{
+// printf("a_this_thread_vec: %04x %04x %04x %04x %04x %04x %04x %04x\n",
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<0>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<1>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<2>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<3>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<4>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<5>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<6>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<7>{})))));
+// printf("a_other_thread_vec: %04x %04x %04x %04x %04x %04x %04x %04x\n",
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<0>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<1>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<2>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<3>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<4>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<5>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<6>{})))),
+// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<7>{})))));
+printf("a_thread_vec: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x\n",
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<0>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<1>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<2>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<3>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<4>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<5>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<6>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<7>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<8>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<9>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<10>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<11>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<12>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<13>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<14>{})))),
+*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<15>{})))));
+}
+#endif
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

@@ -422,33 +490,23 @@ struct BlockwiseGemmWMMA
}
protected:
-static constexpr auto a_thread_desc_ =
-make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
-Number<MRepeat>{},
-I1,
-Number<A_KRow>{},
-I1,
-Number<A_K1>{}),
-make_tuple(Number<A_K1 * A_KRow>{},
-Number<WmmaK>{},
-Number<A_K1 * A_KRow>{},
-Number<A_K1>{},
-Number<A_K1>{},
-Number<1>{}));
-static constexpr auto b_thread_desc_ =
-make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
-Number<NRepeat>{},
-I1,
-Number<B_KRow>{},
-I1,
-Number<B_K1>{}),
-make_tuple(Number<B_K1 * B_KRow>{},
-Number<WmmaK>{},
-Number<B_K1 * B_KRow>{},
-Number<B_K1>{},
-Number<B_K1>{},
-Number<1>{}));
+static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
+make_tuple(Number<WmmaK / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
+make_tuple(Number<A_K1>{},
+Number<WmmaK / A_KRow>{},
+Number<A_K1>{},
+Number<A_K1>{},
+Number<A_K1>{},
+I1));
+static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
+make_tuple(Number<WmmaK / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
+make_tuple(Number<B_K1>{},
+Number<WmmaK / B_KRow>{},
+Number<B_K1>{},
+Number<B_K1>{},
+Number<B_K1>{},
+I1));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
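
For reference, a small sketch (not part of this commit) of the per-thread offsets implied by the new a_thread_desc_ strides above; WmmaK = 16, A_KRow = 2, A_K1 = 8 and MRepeat = 2 are assumed values matching the example instance. Each m0 repeat now owns WmmaK / A_KRow contiguous elements of the thread buffer.

    // Sketch only: linear offset of the new a_thread_desc_,
    //   lengths (WmmaK/A_K1/A_KRow, MRepeat, 1, 1, 1, A_K1)
    //   strides (A_K1, WmmaK/A_KRow, A_K1, A_K1, A_K1, 1)
    // Assumed: WmmaK = 16, A_KRow = 2, A_K1 = 8, MRepeat = 2.
    #include <cstdio>

    constexpr int WmmaK = 16, A_KRow = 2, A_K1 = 8, MRepeat = 2;

    constexpr int offset(int k0, int m0, int k1)
    {
        // only the three non-unit dimensions contribute
        return k0 * A_K1 + m0 * (WmmaK / A_KRow) + k1;
    }

    int main()
    {
        for(int m0 = 0; m0 < MRepeat; ++m0)
            printf("m0 = %d covers thread-buffer elements %d..%d\n",
                   m0, offset(0, m0, 0), offset(0, m0, A_K1 - 1));
        return 0;
    }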

@@ -463,9 +521,9 @@ struct BlockwiseGemmWMMA
using type =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
-decltype(a_block_desc_k0_m0_m1_m2_k1),
+decltype(a_block_desc_),
decltype(a_thread_desc_),
-Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
+Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

@@ -478,7 +536,7 @@ struct BlockwiseGemmWMMA
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatA,
FloatA,
-decltype(a_block_desc_k0_m0_m1_m2_k1),
+decltype(a_block_desc_),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,

@@ -499,9 +557,9 @@ struct BlockwiseGemmWMMA
using type =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
-decltype(b_block_desc_k0_n0_n1_n2_k1),
+decltype(b_block_desc_),
decltype(b_thread_desc_),
-Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,

@@ -514,7 +572,7 @@ struct BlockwiseGemmWMMA
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatB,
FloatB,
-decltype(b_block_desc_k0_n0_n1_n2_k1),
+decltype(b_block_desc_),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,

...

@@ -448,6 +448,68 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
+template <typename InputDataType, index_t RegPackNumber>
+struct InterRowPermuter
+{
+};
+template <>
+struct InterRowPermuter<ck::half_t, 2>
+{
+using InputArray = vector_type<ck::half_t, 2>;
+using OutputArray = vector_type<ck::half_t, 2>;
+__device__ static OutputArray convert(InputArray const& Input)
+{
+OutputArray Output;
+uint32_t* output_half_2 = reinterpret_cast<uint32_t*>(&Output);
+uint32_t const input_half_2 = reinterpret_cast<uint32_t const&>(Input);
+output_half_2[0] = __builtin_amdgcn_permlanex16(
+output_half_2[0], input_half_2, 0x76543210, 0xfedcba98, 1, 0);
+#if 0
+if(get_thread_local_1d_id() == 0)
+{
+printf("After permlanex, input: %04x, output: %04x\n", input_half_2, output_half_2);
+}
+#endif
+return Output;
+}
+__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
+};
+template <index_t N>
+struct InterRowPermuter<ck::half_t, N>
+{
+static constexpr int VEC_WIDTH = 2;
+static_assert(!(N % VEC_WIDTH), "N must be multiple of 2.");
+using InputArray = vector_type<ck::half_t, N>;
+using OutputArray = vector_type<ck::half_t, N>;
+__device__ static OutputArray convert(InputArray const& Input)
+{
+InterRowPermuter<ck::half_t, 2> converter;
+OutputArray Output;
+using Vec_InputArray = vector_type<ck::half_t, 2>;
+using Vec_OutputArray = vector_type<ck::half_t, 2>;
+Vec_OutputArray* output_half_2_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
+Vec_InputArray const* input_half_2_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
+static_for<0, N / VEC_WIDTH, 1>{}(
+[&](auto i) { output_half_2_ptr[i] = converter(input_half_2_ptr[i]); });
+return Output;
+}
+__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
+};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
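
The exchange above hinges on v_permlanex16 crossing the two 16-lane rows of a wave. Below is a standalone HIP sketch (not part of this commit) that checks that behaviour; it assumes a wave32 RDNA3 (gfx11) GPU and reuses the identity selects 0x76543210 / 0xfedcba98 from InterRowPermuter, so every lane should receive the value held by lane ^ 16.

    // Sketch only: verify that permlanex16 with identity selects swaps the two
    // 16-lane rows of a wave32. Assumed: gfx11 (wave32) device, built with hipcc.
    #include <hip/hip_runtime.h>
    #include <cstdint>
    #include <cstdio>

    __global__ void swap_rows(uint32_t* out)
    {
        const uint32_t lane = threadIdx.x; // 0..31, a single wave
        const uint32_t mine = lane;        // payload: the lane's own id
        // identity lane selects; the row crossing comes from permlanex16 itself
        const uint32_t other =
            __builtin_amdgcn_permlanex16(mine, mine, 0x76543210, 0xfedcba98, 1, 0);
        out[lane] = other; // expected: lane ^ 16
    }

    int main()
    {
        uint32_t* out = nullptr;
        hipMalloc(&out, 32 * sizeof(uint32_t));
        hipLaunchKernelGGL(swap_rows, dim3(1), dim3(32), 0, 0, out);
        uint32_t host[32];
        hipMemcpy(host, out, sizeof(host), hipMemcpyDeviceToHost);
        for(int i = 0; i < 32; ++i)
            printf("lane %2d received the value of lane %u\n", i, host[i]);
        hipFree(out);
        return 0;
    }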

@@ -498,15 +498,23 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
-constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
-constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
-constexpr auto A_KRow = I1;
+// Debug this part
+constexpr auto A_KRow = 2;
+constexpr auto A_K0PerRow = ABlockDesc_{}.GetLength(I0) / A_KRow;
+constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+// return make_naive_tensor_descriptor_packed(make_tuple(Number<A_K0PerRow>{},
+// Number<MRepeat>{},
+// Number<MWaves>{},
+// Number<A_KRow>{},
+// Number<MPerWmma>{},
+// Number<A_K1>{}));
return transform_tensor_descriptor(
ABlockDesc_{},
-make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
-make_unmerge_transform(make_tuple(
-Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
-make_pass_through_transform(Number<A_K1>{})),
+make_tuple(
+make_unmerge_transform(make_tuple(Number<A_K0PerRow>{}, Number<A_KRow>{})),
+make_unmerge_transform(
+make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}

@@ -537,17 +545,31 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
-constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
-constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
-constexpr auto B_KRow = I1;
+#if 1
+constexpr auto B_KRow = 2;
+constexpr auto B_K0PerRow = BBlockDesc_{}.GetLength(I0) / B_KRow;
+constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
return transform_tensor_descriptor(
BBlockDesc_{},
-make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
-make_unmerge_transform(make_tuple(
-Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
-make_pass_through_transform(Number<B_K1>{})),
+make_tuple(
+make_unmerge_transform(make_tuple(Number<B_K0PerRow>{}, Number<B_KRow>{})),
+make_unmerge_transform(
+make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
+#endif
+#if 0
+constexpr auto B_KRow = 2;
+constexpr auto B_K0PerRow = BBlockDesc_{}.GetLength(I0) / B_KRow;
+constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+return make_naive_tensor_descriptor_packed(make_tuple(Number<B_K0PerRow>{},
+Number<NRepeat>{},
+Number<NWaves>{},
+Number<B_KRow>{},
+Number<NPerWmma>{},
+Number<B_K1>{}));
+#endif
}
else
{

...
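
A small sketch (not part of this commit) of the index mapping the unmerge above produces when AK0 is split into (A_K0PerRow, A_KRow); AK0 = 8, A_KRow = 2 and A_K1 = 8 (KPerBlock = 64) are assumed values matching the example instance, and the KRow component is taken to vary fastest, as CK's unmerge transform does elsewhere.

    // Sketch only: how (k0_per_row, k_row) recombine into the original AK0 index.
    // Assumed: AK0 = 8, A_KRow = 2, A_K1 = 8, last unmerge component fastest.
    #include <cstdio>

    int main()
    {
        constexpr int AK0 = 8, A_KRow = 2, A_K1 = 8;
        constexpr int A_K0PerRow = AK0 / A_KRow;

        for(int k0_per_row = 0; k0_per_row < A_K0PerRow; ++k0_per_row)
            for(int k_row = 0; k_row < A_KRow; ++k_row)
            {
                const int ak0     = k0_per_row * A_KRow + k_row; // merged AK0 index
                const int k_first = ak0 * A_K1;                  // first K element in this AK0 slot
                printf("(k0_per_row=%d, k_row=%d) -> ak0=%d, K elements [%d, %d)\n",
                       k0_per_row, k_row, ak0, k_first, k_first + A_K1);
            }
        return 0;
    }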

@@ -1136,7 +1136,17 @@ struct ThreadwiseTensorSliceTransfer_v4
auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
+#if 0
+printf("Tid: %03d, Inele_Offset: %d, Coord: (%d, %d, %d, %d, %d, %d)\n",
+get_thread_local_1d_id(),
+src_data_coord.GetOffset(),
+src_data_coord.GetIndex().At(Number<0>{}),
+src_data_coord.GetIndex().At(Number<1>{}),
+src_data_coord.GetIndex().At(Number<2>{}),
+src_data_coord.GetIndex().At(Number<3>{}),
+src_data_coord.GetIndex().At(Number<4>{}),
+src_data_coord.GetIndex().At(Number<5>{}));
+#endif
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;

@@ -1178,6 +1188,12 @@ struct ThreadwiseTensorSliceTransfer_v4
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
+#if 0
+printf("Tid: %03d, Inele_Offset: %d\n",
+get_thread_local_1d_id(),
+dst_desc.CalculateOffset(
+dst_origin_idx + data_to_origin_disp_idx));
+#endif
});
}

...