Commit c713d224 authored by aska-0096

Update low-level abstraction of blockwise GEMM WMMA

parent 2ec3f4c3
@@ -37,8 +37,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
GemmDefault,
2, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
@@ -34,7 +34,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#endif
const int nrepeat = 50;
for(int i = 0; i < nrepeat; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
}
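For context, a minimal standalone sketch (HIP, with an assumed dummy kernel; not the CK launcher itself) of how such a repeat loop is typically bracketed with events to report an average launch time:

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {} // stand-in for the real GEMM kernel

int main()
{
    const int nrepeat = 50;
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
    {
        dummy_kernel<<<dim3(1), dim3(64), 0, nullptr>>>();
    }
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float total_ms = 0.f;
    hipEventElapsedTime(&total_ms, start, stop);
    std::printf("avg kernel time: %f ms\n", total_ms / nrepeat);

    hipEventDestroy(start);
    hipEventDestroy(stop);
    return 0;
}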
@@ -80,6 +80,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = max number of elements accessed per vector
static constexpr auto K1Number = Number<K1>{};
@@ -136,18 +137,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
else
{
constexpr auto A_KRow = WmmaK / K1;
const auto A_KWmma = K / WmmaK;
constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock;
// in: M(0), K(1) <-> out: A_KWmma(0), MBlock*MRepeat(1), MWaves(2), A_K0PerWmma(3), A_KRow(4), MPerWmma(5), A_K1(6)
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
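To make the new K-side split concrete, here is a standalone sketch (plain C++, not the CK descriptor API; WmmaK = 16, K1 = 8, K = 64 are assumed example values) of how the unmerge above decomposes a flat k index into (A_KWmma, A_K0PerWmma, A_KRow, A_K1); the B-side transform below mirrors it:

#include <cstdio>

int main()
{
    constexpr int WmmaK = 16, K1 = 8, K = 64;        // assumed example values
    constexpr int A_KRow      = 2;
    constexpr int A_K0PerWmma = WmmaK / A_KRow / K1; // = 1
    const int A_KWmma         = K / WmmaK;           // = 4

    // unmerge: flat k -> (kWmma, k0PerWmma, kRow, k1), last dim fastest
    for(int k = 0; k < K; k += 19) // sample a few indices
    {
        const int k1    = k % K1;
        const int kRow  = (k / K1) % A_KRow;
        const int k0    = (k / (K1 * A_KRow)) % A_K0PerWmma;
        const int kWmma = k / (K1 * A_KRow * A_K0PerWmma);
        std::printf("k=%2d -> (kWmma=%d, k0=%d, kRow=%d, k1=%d)\n",
                    k, kWmma, k0, kRow, k1);
    }
    (void)A_KWmma;
    return 0;
}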
@@ -187,18 +191,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
else
{
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
// in: N(0), K(1) <-> out: B_KWmma(0), NBlock*NRepeat(1), NWaves(2), B_K0PerWmma(3), B_KRow(4), NPerWmma(5), B_K1(6)
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
@@ -372,7 +379,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
}
}();
@@ -297,15 +297,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
#if 0
constexpr auto a_block_origin_idx = generate_sequence_v2(
[]() constexpr {
return Number<0>{};
},
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf;
// preload data into LDS
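As an aside, the disabled generate_sequence_v2 attempt above can be emulated in standalone C++ (std::index_sequence here is a stand-in, not the CK utility) to build a rank-N zero origin instead of hardcoding seven I0s:

#include <tuple>
#include <utility>

template <std::size_t... Is>
constexpr auto make_zero_origin_impl(std::index_sequence<Is...>)
{
    // ((void)Is, 0) evaluates to 0 once per pack element
    return std::make_tuple(((void)Is, 0)...);
}

template <std::size_t N>
constexpr auto make_zero_origin()
{
    return make_zero_origin_impl(std::make_index_sequence<N>{});
}

static_assert(std::get<6>(make_zero_origin<7>()) == 0, "7-D zero origin");

int main() { return 0; }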
@@ -404,7 +396,7 @@ struct GridwiseGemmPipeline_v1<1, true, false>
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
@@ -172,10 +172,23 @@ struct GridwiseGemm_Wmma
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1));
make_tuple(Number<KWmmaPerblock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
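A standalone sketch (plain C++; KPerBlock = 64, WmmaK = 16, K1 = 8, MRepeat = 4 are assumed example values) of the stride table this 7-D descriptor encodes: degenerate wave/row/lane dimensions of length 1 reuse the stride of their faster neighbor, and an offset is the index-stride dot product:

#include <array>
#include <cstdio>

int main()
{
    constexpr int KPerBlock = 64, WmmaK = 16, K1 = 8, MRepeat = 4; // assumed
    constexpr int KWmmaPerBlock = KPerBlock / WmmaK; // = 4
    constexpr int K0PerWmma     = WmmaK / 2 / K1;    // = 1

    // lengths: KWmma, MRepeat, MWave(1), K0PerWmma, KRow(1), MPerWmma(1), K1
    const std::array<int, 7> len{KWmmaPerBlock, MRepeat, 1, K0PerWmma, 1, 1, K1};
    const std::array<int, 7> str{MRepeat * K0PerWmma * K1, // KWmma
                                 K0PerWmma * K1,           // MRepeat
                                 K0PerWmma * K1,           // MWave (length 1)
                                 K1,                       // K0PerWmma
                                 K1,                       // KRow (length 1)
                                 K1,                       // MPerWmma (length 1)
                                 1};                       // K1

    auto offset = [&](const std::array<int, 7>& idx) {
        int off = 0;
        for(int d = 0; d < 7; ++d)
            off += idx[d] * str[d];
        return off;
    };

    // e.g. (kWmma=1, mRepeat=2, 0, 0, 0, 0, k1=3) -> 1*32 + 2*8 + 3 = 51
    std::printf("offset = %d\n", offset({1, 2, 0, 0, 0, 0, 3}));
    (void)len;
    return 0;
}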
@@ -206,10 +219,23 @@ struct GridwiseGemm_Wmma
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
make_tuple(Number<KWmmaPerblock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();
@@ -229,7 +255,7 @@ struct GridwiseGemm_Wmma
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
@@ -249,7 +275,7 @@ struct GridwiseGemm_Wmma
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();
@@ -264,23 +290,26 @@ struct GridwiseGemm_Wmma
constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5);
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
// Err: merge transform causes a non-constexpr issue
@@ -301,26 +330,12 @@ struct GridwiseGemm_Wmma
// Sequence<4>{}));
// Workaround, Freeze transform
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<MRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}));
}
}();
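The packed-descriptor workaround above can be read as: since make_naive_tensor_descriptor_packed lays dimensions out row-major, each stride is simply the product of all faster lengths. A standalone sketch (plain C++, example lengths assumed to mirror the A wave descriptor):

#include <array>
#include <cstdio>

int main()
{
    // lengths: KWmma*K0PerWmma, MRepeat, MWave, KRow, MPerWmma, K1 (assumed)
    const std::array<int, 6> len{4, 4, 1, 1, 1, 8};
    std::array<int, 6> str{};
    int s = 1;
    for(int d = 5; d >= 0; --d) // rightmost dimension is fastest
    {
        str[d] = s;
        s *= len[d];
    }
    for(int d = 0; d < 6; ++d)
        std::printf("dim %d: length=%d stride=%d\n", d, len[d], str[d]);
    std::printf("element space size = %d\n", s);
    return 0;
}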
@@ -334,44 +349,33 @@ struct GridwiseGemm_Wmma
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_BKRow_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I5);
// KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_freeze_transform(I0),
make_pass_through_transform(Number<KWmma>{}),
make_pass_through_transform(Number<NRepeat>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}));
}
}();
@@ -415,9 +419,9 @@ struct GridwiseGemm_Wmma
else
{
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I4),
a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I5));
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
}
};
@@ -430,9 +434,9 @@ struct GridwiseGemm_Wmma
else
{
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4),
b_grid_desc.GetLength(I5),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5));
b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
}
};
@@ -599,7 +603,8 @@ struct GridwiseGemm_Wmma
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else
{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
* a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
}
}();
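With the 7-D grid descriptor, the loop-count K is recovered as the product of the four K sub-extents (I0, I3, I4, I6). A standalone arithmetic sketch with assumed example extents:

#include <cstdio>

int main()
{
    // assumed example extents of dims I0, I3, I4, I6
    constexpr int KWmma = 4, K0PerWmma = 1, KRow = 2, K1 = 8;
    constexpr int K = KWmma * K0PerWmma * KRow * K1; // = 64
    static_assert(K == 64, "K = KWmma * K0PerWmma * KRow * K1");
    std::printf("K = %d\n", K);
    return 0;
}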
@@ -652,6 +657,7 @@ struct GridwiseGemm_Wmma
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
@@ -664,11 +670,12 @@ struct GridwiseGemm_Wmma
Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
@@ -676,6 +683,7 @@ struct GridwiseGemm_Wmma
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32) / 16,
get_thread_local_1d_id() % 16,
0));
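The thread origin above decomposes the flat thread id assuming a 32-lane wave split into two 16-lane rows; a standalone sketch of that decomposition (the B-side copy below uses the same scheme):

#include <cstdio>

int main()
{
    for(int tid : {0, 15, 16, 31, 32, 47, 63})
    {
        const int wave = tid / 32;        // wave index within the block
        const int row  = (tid % 32) / 16; // which 16-lane row of the wave
        const int lane = tid % 16;        // lane within the row
        std::printf("tid=%2d -> wave=%d row=%d lane=%2d\n", tid, wave, row, lane);
    }
    return 0;
}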
@@ -729,6 +737,7 @@ struct GridwiseGemm_Wmma
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> K0PerWmma -> KRow -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1Value;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());
@@ -741,11 +750,12 @@ struct GridwiseGemm_Wmma
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
@@ -753,6 +763,7 @@ struct GridwiseGemm_Wmma
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32) / 16,
get_thread_local_1d_id() % 16,
0));
@@ -1387,7 +1387,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// src_desc error, non constexpr?
// src_desc error, non-constexpr, caused by the merge transform
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
@@ -1396,8 +1396,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
SrcData v_this_row, v_theother_row;
// int type temp value due to intrinsic requirement
// TODO: this temp value will generate scratch memory if
// IntraRowSwizzlePerm is false
int temp = 0;
// apply element-wise operation
@@ -1419,7 +1417,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
1,
0);
v_theother_row = type_convert_sp<SrcData>(temp);
// if(get_thread_local_1d_id() == 0){
//     printf("src_offset:%d, dst_offset for this row: %d, "
//            "dst_offset for the other row: %d\n",
//            src_offset, dst_offset, dst_offset + DstScalarPerVector);}
if(get_thread_local_1d_id() % 32 < 16)
{
// apply type convert