Commit c713d224 authored by aska-0096's avatar aska-0096
Browse files

Update low level abstration of blockwise gemm wmma

parent 2ec3f4c3
...@@ -80,6 +80,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -80,6 +80,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels // K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
...@@ -136,18 +137,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -136,18 +137,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
} }
else else
{ {
constexpr auto A_KRow = WmmaK / K1; constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK; const auto A_KWmma = K / WmmaK;
const auto M0 = M / MPerBlock; const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))), make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -187,18 +191,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -187,18 +191,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
} }
else else
{ {
constexpr auto B_KRow = WmmaK / K1; constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK; const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock; const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_grid_desc_n_k, b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)), make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))), make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{})); make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
} }
} }
...@@ -372,7 +379,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -372,7 +379,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
else else
{ {
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5); arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
} }
}(); }();
......
...@@ -297,15 +297,7 @@ struct GridwiseGemmPipeline_v1<1, false, true> ...@@ -297,15 +297,7 @@ struct GridwiseGemmPipeline_v1<1, false, true>
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop)
{ {
#if 0 constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
constexpr auto a_block_origin_idx = generate_sequence_v2(
[]() constexpr {
return Number<0>{};
},
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
auto a_block_buf_switch = a_block_buf; auto a_block_buf_switch = a_block_buf;
// preload data into LDS // preload data into LDS
...@@ -404,7 +396,7 @@ struct GridwiseGemmPipeline_v1<1, true, false> ...@@ -404,7 +396,7 @@ struct GridwiseGemmPipeline_v1<1, true, false>
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop)
{ {
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0); constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf; auto b_block_buf_switch = b_block_buf;
// preload data into LDS // preload data into LDS
......
...@@ -172,10 +172,23 @@ struct GridwiseGemm_Wmma ...@@ -172,10 +172,23 @@ struct GridwiseGemm_Wmma
else else
{ {
constexpr auto KWmmaPerblock = KPerBlock / WmmaK; constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1), make_tuple(Number<KWmmaPerblock>{},
make_tuple(Number<MRepeat>{} * K1, K1, K1, K1, K1, I1)); Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
} }
}(); }();
...@@ -206,10 +219,23 @@ struct GridwiseGemm_Wmma ...@@ -206,10 +219,23 @@ struct GridwiseGemm_Wmma
else else
{ {
constexpr auto KWmmaPerblock = KPerBlock / WmmaK; constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1), make_tuple(Number<KWmmaPerblock>{},
make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1)); Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
} }
}(); }();
...@@ -229,7 +255,7 @@ struct GridwiseGemm_Wmma ...@@ -229,7 +255,7 @@ struct GridwiseGemm_Wmma
{ {
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0); return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -249,7 +275,7 @@ struct GridwiseGemm_Wmma ...@@ -249,7 +275,7 @@ struct GridwiseGemm_Wmma
{ {
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0); return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
} }
}(); }();
...@@ -264,23 +290,26 @@ struct GridwiseGemm_Wmma ...@@ -264,23 +290,26 @@ struct GridwiseGemm_Wmma
constexpr auto a_wave_desc = [&]() { constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds) if constexpr(AEnableLds)
{ {
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
ABlockDesc_{}, ABlockDesc_{},
make_tuple(make_pass_through_transform(Number<A_K0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})), Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})), make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
else else
{ {
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I5); constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
// Err: merge transform cause non-constexpr issue // Err: merge transform cause non-constexpr issue
...@@ -301,26 +330,12 @@ struct GridwiseGemm_Wmma ...@@ -301,26 +330,12 @@ struct GridwiseGemm_Wmma
// Sequence<4>{})); // Sequence<4>{}));
// Workaround, Freeze transform // Workaround, Freeze transform
return transform_tensor_descriptor( return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
ABlockDesc_{}, Number<MRepeat>{},
make_tuple(make_freeze_transform(I0), I1,
make_pass_through_transform(Number<KWmma>{}), Number<A_KRow>{},
make_pass_through_transform(Number<MRepeat>{}), I1,
make_pass_through_transform(I1), Number<A_K1>{}));
make_pass_through_transform(I1),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -336,42 +351,31 @@ struct GridwiseGemm_Wmma ...@@ -336,42 +351,31 @@ struct GridwiseGemm_Wmma
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor( return transform_tensor_descriptor(
BBlockDesc_{}, BBlockDesc_{},
make_tuple(make_pass_through_transform(Number<B_K0>{}), make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})), Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})), make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
} }
else else
{ {
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I5); constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform // Workaround, Freeze transform
return transform_tensor_descriptor( return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
BBlockDesc_{}, Number<NRepeat>{},
make_tuple(make_freeze_transform(I0), I1,
make_pass_through_transform(Number<KWmma>{}), Number<B_KRow>{},
make_pass_through_transform(Number<NRepeat>{}), I1,
make_pass_through_transform(I1), Number<B_K1>{}));
make_pass_through_transform(I1),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<3>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<>{},
Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}));
} }
}(); }();
...@@ -415,9 +419,9 @@ struct GridwiseGemm_Wmma ...@@ -415,9 +419,9 @@ struct GridwiseGemm_Wmma
else else
{ {
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I4), a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I5)); a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
} }
}; };
...@@ -430,9 +434,9 @@ struct GridwiseGemm_Wmma ...@@ -430,9 +434,9 @@ struct GridwiseGemm_Wmma
else else
{ {
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I4), b_grid_desc.GetLength(I5),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I5)); b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
} }
}; };
...@@ -599,7 +603,8 @@ struct GridwiseGemm_Wmma ...@@ -599,7 +603,8 @@ struct GridwiseGemm_Wmma
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
} }
else{ else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5); return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
* a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
} }
}(); }();
...@@ -652,6 +657,7 @@ struct GridwiseGemm_Wmma ...@@ -652,6 +657,7 @@ struct GridwiseGemm_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>( auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize()); a_block_desc.GetElementSpaceSize());
...@@ -664,11 +670,12 @@ struct GridwiseGemm_Wmma ...@@ -664,11 +670,12 @@ struct GridwiseGemm_Wmma
Sequence<Number<KWmmaPerBlock>{}, Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{}, Number<MRepeat>{},
I1, I1,
Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
...@@ -676,6 +683,7 @@ struct GridwiseGemm_Wmma ...@@ -676,6 +683,7 @@ struct GridwiseGemm_Wmma
make_multi_index(0, make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma), m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32, get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16, (get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16, get_thread_local_1d_id() % 16,
0)); 0));
...@@ -729,6 +737,7 @@ struct GridwiseGemm_Wmma ...@@ -729,6 +737,7 @@ struct GridwiseGemm_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>( auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize()); b_block_desc.GetElementSpaceSize());
...@@ -741,11 +750,12 @@ struct GridwiseGemm_Wmma ...@@ -741,11 +750,12 @@ struct GridwiseGemm_Wmma
Sequence<Number<KWmmaPerBlock>{}, Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{}, Number<NRepeat>{},
I1, I1,
Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6>,
5, 6,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
...@@ -753,6 +763,7 @@ struct GridwiseGemm_Wmma ...@@ -753,6 +763,7 @@ struct GridwiseGemm_Wmma
make_multi_index(0, make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma), n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32, get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16, (get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16, get_thread_local_1d_id() % 16,
0)); 0));
......
...@@ -1387,7 +1387,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1387,7 +1387,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector // copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// src_desc error, non constexpr? // src_desc error, non constexpr, caused by merge transform
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
...@@ -1396,8 +1396,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1396,8 +1396,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
SrcData v_this_row, v_theother_row; SrcData v_this_row, v_theother_row;
// int type temp value due to intrinsic requirement // int type temp value due to intrinsic requirement
// TODO: This temp value will generate the scratch memory if
// IntraRowSwizzlePerm is flase
int temp = 0; int temp = 0;
// apply element-wise operation // apply element-wise operation
...@@ -1419,7 +1417,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ...@@ -1419,7 +1417,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
1, 1,
0); 0);
v_theother_row = type_convert_sp<SrcData>(temp); v_theother_row = type_convert_sp<SrcData>(temp);
// if (get_thread_local_1d_id() == 0){
// printf("src_offset:%d, dst_offset for this row: %d, dst_offset
// for the other row: %d \n",
// src_offset, dst_offset, dst_offset+DstScalarPerVector);}
if(get_thread_local_1d_id() % 32 < 16) if(get_thread_local_1d_id() % 32 < 16)
{ {
// apply type convert // apply type convert
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment