Commit 5cfe67b1 authored by aska-0096

NewBlkGEMM

parent cc0ffeb7
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
@@ -83,30 +84,30 @@ using DeviceOpInstance =
1,
128,
64,
128,
64,
64,
4,
8,
16,
16,
1,
2,
4,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
true,
8,
8,
false,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
true,
8,
8,
false,
1,
1,
S<1, 64, 1, 2>,
S<1, 32, 1, 4>,
8>;
int main(int argc, char* argv[])
@@ -208,12 +209,29 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
case 2:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<DDataType>{1.f, 1.f}(d_m_n);
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
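Note that the new init case 2 fills every element of A, B, and D with the constant 1, which makes the expected output trivial to verify by eye: each element of the bare GEMM reduces to sum_k A(m,k) * B(k,n) = sum_k 1 * 1 = K (before the D elementwise op is applied), so any mismatch points at data movement through the new inter-row permutation path rather than at arithmetic.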
#if 0
for(int im = 0; im<M; im++)
{
for(int ik = 0; ik<K; ik++)
{
if(ik%8==0) printf("|");
printf("%4x ", *(reinterpret_cast<uint16_t*>(&(a_m_k(im,ik)))));
}
printf("\n");
}
#endif
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
@@ -62,12 +62,13 @@ struct BlockwiseGemmWMMA
// Hard-code WaveSize, since the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
static constexpr index_t RowSize = 16;
// When LDS is used, each row (16 consecutive lanes) reads the whole data from the source buffer.
// When LDS is not used, each row reads half of the data from the source buffer and exchanges it
// with the other row via permutation.
static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
static constexpr index_t A_KRow = AEnableLds ? 2 : 2;
static constexpr index_t B_KRow = BEnableLds ? 2 : 2;
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
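Since the wave is split into two 16-lane rows, the new thread-origin logic below hinges on a simple lane-to-row mapping. A minimal host-side sketch of that mapping (not part of the commit), assuming the WaveSize = 32 and RowSize = 16 hard-coded above:

// Sketch: reproduce the rowId computation used in
// CalculateAThreadOriginDataIndex / CalculateBThreadOriginDataIndex below.
constexpr int wave_size = 32;
constexpr int row_size  = 16;
constexpr int row_of_lane(int lane) { return (lane % wave_size) / row_size; }
static_assert(row_of_lane(0) == 0 && row_of_lane(15) == 0, "lanes 0..15 form row 0");
static_assert(row_of_lane(16) == 1 && row_of_lane(31) == 1, "lanes 16..31 form row 1");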
@@ -105,10 +106,11 @@ struct BlockwiseGemmWMMA
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto rowId = (get_thread_local_1d_id() % WaveSize) / RowSize;
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
return make_tuple(0, 0, waveId_m, rowId, WMMA_a_idx, 0);
}
else
{
@@ -122,10 +124,11 @@ struct BlockwiseGemmWMMA
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto rowId = (get_thread_local_1d_id() % WaveSize) / RowSize;
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
return make_tuple(0, 0, waveId_n, rowId, WMMA_b_idx, 0);
}
else
{
@@ -286,8 +289,8 @@ struct BlockwiseGemmWMMA
// Describes how data is laid out in the thread-copy source buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
static constexpr ABlockDesc a_block_desc_;
static constexpr BBlockDesc b_block_desc_;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
@@ -299,6 +302,10 @@ struct BlockwiseGemmWMMA
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
using RowwisePermuter =
tensor_operation::element_wise::InterRowPermuter<FloatA, WmmaK / A_KRow>;
RowwisePermuter interrow_permuter;
// basic heuristic to determine the loop-over direction
if constexpr(MRepeat < NRepeat)
{
@@ -307,7 +314,7 @@ struct BlockwiseGemmWMMA
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
a_block_desc_,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
@@ -317,35 +324,44 @@ struct BlockwiseGemmWMMA
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
b_block_desc_,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK / A_KRow> a_this_thread_vec;
vector_type<FloatB, WmmaK / B_KRow> b_this_thread_vec;
vector_type<FloatA, WmmaK / A_KRow> a_other_thread_vec;
vector_type<FloatB, WmmaK / B_KRow> b_other_thread_vec;
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
static_for<0, WmmaK / A_KRow, 1>{}([&](auto i) {
a_this_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
make_tuple(i / A_K1 / A_KRow, m0, 0, 0, 0, i % A_K1))>{}];
b_this_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
make_tuple(i / B_K1 / B_KRow, n0, 0, 0, 0, i % B_K1))>{}];
});
a_other_thread_vec = interrow_permuter(a_this_thread_vec);
b_other_thread_vec = interrow_permuter(b_this_thread_vec);
using vec_row_type_a = typename vector_type<FloatA, WmmaK / 2>::type;
using vec_row_type_b = typename vector_type<FloatB, WmmaK / 2>::type;
a_thread_vec.template AsType<vec_row_type_a>()(Number<0>{}) =
a_this_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
a_thread_vec.template AsType<vec_row_type_a>()(Number<1>{}) =
a_other_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
b_thread_vec.template AsType<vec_row_type_b>()(Number<0>{}) =
b_this_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
b_thread_vec.template AsType<vec_row_type_b>()(Number<1>{}) =
b_other_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
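The splice above packs the thread's own half-row into the low half of the WMMA input vector and the permuted partner half into the high half. A minimal standalone sketch of the resulting layout (assumed shapes: WmmaK = 16, so each half holds WmmaK / 2 = 8 elements):

#include <algorithm>
#include <array>

// Sketch: mirror the AsType<vec_row_type_a> splice with plain arrays.
std::array<float, 16> splice_halves(const std::array<float, 8>& this_half,
                                    const std::array<float, 8>& other_half)
{
    std::array<float, 16> wmma_input{};
    std::copy(this_half.begin(), this_half.end(), wmma_input.begin());       // low half: own row
    std::copy(other_half.begin(), other_half.end(), wmma_input.begin() + 8); // high half: partner row
    return wmma_input;
}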
@@ -362,49 +378,101 @@ struct BlockwiseGemmWMMA
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_,
make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
a_block_desc_,
make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, WmmaK / A_KRow> a_this_thread_vec;
vector_type<FloatB, WmmaK / B_KRow> b_this_thread_vec;
vector_type<FloatA, WmmaK / A_KRow> a_other_thread_vec;
vector_type<FloatB, WmmaK / B_KRow> b_other_thread_vec;
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
static_for<0, WmmaK / A_KRow, 1>{}([&](auto i) {
b_this_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
a_thread_vec.template AsType<FloatA>()(i) =
make_tuple(i / B_K1 / B_KRow, n0, 0, 0, 0, i % B_K1))>{}];
a_this_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
make_tuple(i / A_K1 / A_KRow, m0, 0, 0, 0, i % A_K1))>{}];
});
a_other_thread_vec = interrow_permuter(a_this_thread_vec);
b_other_thread_vec = interrow_permuter(b_this_thread_vec);
using vec_row_type_a = typename vector_type<FloatA, WmmaK / 2>::type;
using vec_row_type_b = typename vector_type<FloatB, WmmaK / 2>::type;
b_thread_vec.template AsType<vec_row_type_b>()(Number<0>{}) =
b_this_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
b_thread_vec.template AsType<vec_row_type_b>()(Number<1>{}) =
b_other_thread_vec.template AsType<vec_row_type_b>()(Number<0>{});
a_thread_vec.template AsType<vec_row_type_a>()(Number<0>{}) =
a_this_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
a_thread_vec.template AsType<vec_row_type_a>()(Number<1>{}) =
a_other_thread_vec.template AsType<vec_row_type_a>()(Number<0>{});
#if 0
if(get_thread_local_1d_id() == 0 || get_thread_local_1d_id() == 16)
{
// printf("a_this_thread_vec: %04x %04x %04x %04x %04x %04x %04x %04x\n",
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<0>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<1>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<2>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<3>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<4>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<5>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<6>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_this_thread_vec.template AsType<FloatA>()(Number<7>{})))));
// printf("a_other_thread_vec: %04x %04x %04x %04x %04x %04x %04x %04x\n",
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<0>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<1>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<2>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<3>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<4>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<5>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<6>{})))),
// *(reinterpret_cast<uint16_t*>(&(a_other_thread_vec.template AsType<FloatA>()(Number<7>{})))));
printf("a_thread_vec: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x\n",
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<0>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<1>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<2>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<3>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<4>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<5>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<6>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<7>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<8>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<9>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<10>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<11>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<12>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<13>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<14>{})))),
*(reinterpret_cast<uint16_t*>(&(a_thread_vec.template AsType<FloatA>()(Number<15>{})))));
}
#endif
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
@@ -422,33 +490,23 @@ struct BlockwiseGemmWMMA
}
protected:
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / A_K1 / A_KRow>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1 * A_KRow>{},
Number<WmmaK>{},
Number<A_K1 * A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1 * B_KRow>{},
Number<WmmaK>{},
Number<B_K1 * B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<WmmaK / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
I1));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<WmmaK / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
I1));
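With the KRow axis collapsed to length 1 and the strides rebased on WmmaK / A_KRow, each repeat now owns one contiguous half-row. A hypothetical spot check under assumed values (WmmaK = 16, A_K1 = 8, A_KRow = 2, the fp16 WMMA path):

// Sketch: the new a_thread_desc_ strides become
// (A_K1, WmmaK / A_KRow, A_K1, A_K1, A_K1, 1) = (8, 8, 8, 8, 8, 1),
// so offset(k0, m0, 0, 0, 0, k1) = k0 * 8 + m0 * 8 + k1.
constexpr int a_thread_offset(int k0, int m0, int k1) { return k0 * 8 + m0 * 8 + k1; }
static_assert(a_thread_offset(0, 1, 3) == 11, "repeat m0 starts at element 8");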
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -463,9 +521,9 @@ struct BlockwiseGemmWMMA
using type =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_block_desc_),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
@@ -478,7 +536,7 @@ struct BlockwiseGemmWMMA
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatA,
FloatA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_block_desc_),
decltype(a_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
@@ -499,9 +557,9 @@ struct BlockwiseGemmWMMA
using type =
ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_block_desc_),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
@@ -514,7 +572,7 @@ struct BlockwiseGemmWMMA
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatB,
FloatB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_block_desc_),
decltype(b_thread_desc_),
tensor_operation::element_wise::PassThrough,
Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
@@ -448,6 +448,68 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <typename InputDataType, index_t RegPackNumber>
struct InterRowPermuter
{
};
template <>
struct InterRowPermuter<ck::half_t, 2>
{
using InputArray = vector_type<ck::half_t, 2>;
using OutputArray = vector_type<ck::half_t, 2>;
__device__ static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* output_half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const input_half_2 = reinterpret_cast<uint32_t const&>(Input);
output_half_2[0] = __builtin_amdgcn_permlanex16(
output_half_2[0], input_half_2, 0x76543210, 0xfedcba98, 1, 0);
#if 0
if(get_thread_local_1d_id() == 0)
{
printf("After permlanex, input: %04x, output: %04x\n", input_half_2, output_half_2);
}
#endif
return Output;
}
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
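With select masks 0x76543210 / 0xfedcba98, __builtin_amdgcn_permlanex16 performs an identity permutation across the row boundary: lane i of row 0 fetches from lane i of row 1 and vice versa. A hypothetical standalone HIP kernel (not part of the commit) to spot-check that behaviour, launched with one 32-lane wave:

#include <cstdint>
#include <hip/hip_runtime.h>

// Sketch: each lane publishes its own id and reads back the partner-row lane's id.
__global__ void permlanex16_check(uint32_t* out)
{
    uint32_t v = threadIdx.x;
    uint32_t r = __builtin_amdgcn_permlanex16(v, v, 0x76543210, 0xfedcba98, 1, 0);
    out[threadIdx.x] = r; // expect: lane 0 -> 16, lane 16 -> 0, lane 1 -> 17, ...
}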
template <index_t N>
struct InterRowPermuter<ck::half_t, N>
{
static constexpr int VEC_WIDTH = 2;
static_assert(!(N % VEC_WIDTH), "N must be a multiple of 2.");
using InputArray = vector_type<ck::half_t, N>;
using OutputArray = vector_type<ck::half_t, N>;
__device__ static OutputArray convert(InputArray const& Input)
{
InterRowPermuter<ck::half_t, 2> converter;
OutputArray Output;
using Vec_InputArray = vector_type<ck::half_t, 2>;
using Vec_OutputArray = vector_type<ck::half_t, 2>;
Vec_OutputArray* output_half_2_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* input_half_2_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { output_half_2_ptr[i] = converter(input_half_2_ptr[i]); });
return Output;
}
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
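A hedged usage sketch of the generic specialization, for the 8-element fp16 fragment this commit feeds it (WmmaK / A_KRow = 8):

// Sketch (device code): permute an 8-wide half-row fragment in 2-element chunks.
__device__ void permute_fragment_example(ck::vector_type<ck::half_t, 8>& frag)
{
    ck::tensor_operation::element_wise::InterRowPermuter<ck::half_t, 8> permuter;
    frag = permuter(frag); // frag now holds the partner row's 8 values
}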
@@ -498,15 +498,23 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
// Debug this part
constexpr auto A_KRow = 2;
constexpr auto A_K0PerRow = ABlockDesc_{}.GetLength(I0) / A_KRow;
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
// return make_naive_tensor_descriptor_packed(make_tuple(Number<A_K0PerRow>{},
// Number<MRepeat>{},
// Number<MWaves>{},
// Number<A_KRow>{},
// Number<MPerWmma>{},
// Number<A_K1>{}));
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(
make_unmerge_transform(make_tuple(Number<A_K0PerRow>{}, Number<A_KRow>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
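Concretely, the unmerge now factors the K0 axis into (K0 / A_KRow, A_KRow) instead of appending a unit KRow axis. A worked shape example under assumed tile parameters (A_K0 = 8, A_KRow = 2, MRepeat = 2, MWaves = 4, MPerWmma = 16, A_K1 = 8):

// Sketch: AK0_M_AK1 = (8, 128, 8) maps to
// (A_K0PerRow, MRepeat, MWaves, A_KRow, MPerWmma, A_K1) = (4, 2, 4, 2, 16, 8),
// since 8 = 4 * 2 and 128 = 2 * 4 * 16.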
@@ -537,17 +545,31 @@ struct GridwiseGemmMultipleD_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
#if 1
constexpr auto B_KRow = 2;
constexpr auto B_K0PerRow = BBlockDesc_{}.GetLength(I0) / B_KRow;
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(
make_unmerge_transform(make_tuple(Number<B_K0PerRow>{}, Number<B_KRow>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
#endif
#if 0
constexpr auto B_KRow = 2;
constexpr auto B_K0PerRow = BBlockDesc_{}.GetLength(I0) / B_KRow;
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
return make_naive_tensor_descriptor_packed(make_tuple(Number<B_K0PerRow>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<B_KRow>{},
Number<NPerWmma>{},
Number<B_K1>{}));
#endif
}
else
{
@@ -1136,7 +1136,17 @@ struct ThreadwiseTensorSliceTransfer_v4
auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
#if 0
printf("Tid: %03d, Inele_Offset: %d, Coord: (%d, %d, %d, %d, %d, %d)\n",
get_thread_local_1d_id(),
src_data_coord.GetOffset(),
src_data_coord.GetIndex().At(Number<0>{}),
src_data_coord.GetIndex().At(Number<1>{}),
src_data_coord.GetIndex().At(Number<2>{}),
src_data_coord.GetIndex().At(Number<3>{}),
src_data_coord.GetIndex().At(Number<4>{}),
src_data_coord.GetIndex().At(Number<5>{}));
#endif
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;
@@ -1178,6 +1188,12 @@ struct ThreadwiseTensorSliceTransfer_v4
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
#if 0
printf("Tid: %03d, Inele_Offset: %d\n",
get_thread_local_1d_id(),
dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx));
#endif
});
}