Commit 6e3b47c3 authored by ltqin

init version, 4 buffer

parent 83e6a4b9
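Note on the change: this commit moves the A operand off the staged-LDS path and onto the register-resident path the kernel already uses for B. Each thread now stages A and B tiles through four VGPR StaticBuffers and ping-pongs between the two halves, so global loads for one half can overlap MFMA work on the other. Below is a minimal host-side sketch of that schedule, not CK code; `FourBufferPipeline`, `load`, and `gemm` are illustrative stand-ins, and the overlap itself is something the compiler/hardware provides rather than the sequential semantics shown here.

```cpp
#include <array>
#include <cstddef>

// Sketch of the 4-buffer ping-pong schedule; T stands in for a per-thread tile.
template <typename T>
struct FourBufferPipeline
{
    std::array<T, 4> a_buf{}, b_buf{}; // four register buffers per operand

    template <typename Load, typename Gemm>
    void run(Load load, Gemm gemm, std::size_t k_iters)
    {
        // preload: fill the "first half" (buffers 0 and 1)
        load(a_buf[0], b_buf[0]);
        load(a_buf[1], b_buf[1]);

        for(std::size_t i = 0; i + 1 < k_iters; ++i)
        {
            load(a_buf[2], b_buf[2]); // fetch the "last half" ...
            load(a_buf[3], b_buf[3]);
            gemm(a_buf[0], b_buf[0]); // ... while buffers 0 and 1 are consumed
            gemm(a_buf[1], b_buf[1]);

            load(a_buf[0], b_buf[0]); // refill the "first half" ...
            load(a_buf[1], b_buf[1]);
            gemm(a_buf[2], b_buf[2]); // ... while buffers 2 and 3 are consumed
            gemm(a_buf[3], b_buf[3]);
        }

        // tail: fetch the final "last half", then drain all four buffers
        load(a_buf[2], b_buf[2]);
        load(a_buf[3], b_buf[3]);
        gemm(a_buf[0], b_buf[0]);
        gemm(a_buf[1], b_buf[1]);
        gemm(a_buf[2], b_buf[2]);
        gemm(a_buf[3], b_buf[3]);
    }
};
```

The kernel's `read_first_half_data`/`read_last_half_data` lambdas and the paired `blockwise_gemm.Run` calls in the gridwise diff below follow this same ordering.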
@@ -117,12 +117,12 @@ int main(int argc, char* argv[])
     ck::index_t StrideC = N;
 #else
     ck::index_t M = 16;
-    ck::index_t N = 16;
-    ck::index_t K = 32;
-    ck::index_t StrideA = 8;
-    ck::index_t StrideB = 8;
-    ck::index_t StrideC = 16;
+    ck::index_t N = 32;
+    ck::index_t K = 16;
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = K;
+    ck::index_t StrideC = N;
 #endif
     if(argc == 4)
@@ -189,7 +189,7 @@ int main(int argc, char* argv[])
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        // a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+        //a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
         a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_k_n.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
     }
...
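The example now ties the strides to the problem dimensions instead of hard-coding them: with packed storage, the leading-dimension stride is the row length, so StrideA/StrideB track K and StrideC tracks N, and the strides stay consistent when M/N/K change. A small sketch of the resulting offset arithmetic; the row-major/leading-dimension reading of each matrix is assumed for illustration, and the example's actual layout template arguments decide which dimension each stride follows.

```cpp
#include <cstdio>

int main()
{
    // dimensions from the updated example
    const int M = 16, N = 32, K = 16;
    const int StrideA = K, StrideB = K, StrideC = N;

    // packed layouts: element offset = row index * leading stride + column index
    auto offA = [&](int m, int k) { return m * StrideA + k; }; // A viewed as M x K
    auto offB = [&](int n, int k) { return n * StrideB + k; }; // B with leading dim K
    auto offC = [&](int m, int n) { return m * StrideC + n; }; // C viewed as M x N

    std::printf("A(1,2)->%d B(1,2)->%d C(1,2)->%d\n", offA(1, 2), offB(1, 2), offC(1, 2));
    (void)M;
    (void)N;
}
```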
@@ -32,9 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     static constexpr index_t KPerBlock = K0PerBlock * KPack;

-    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
-
     static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};

     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -219,46 +216,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
             c_grid_desc_g_m0_n0_m1_n1_m2_n2);
     }

-    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
-    {
-        return transform_tensor_descriptor(
-            AK0MK1BlockDesc{},
-            make_tuple(
-                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
-                make_unmerge_transform(
-                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
-    }
-
-    __device__ void MoveABlockSliceWindow()
-    {
-        a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
-                                          make_multi_index(0, 0, 0, K0PerBlock * KPack));
-    }
-
-    __device__ void ResetABlockStartWindow()
-    {
-        a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
-    }
-
-    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
-
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
-    __device__ void Run(const ABlockBuffer& a_block_buf,
+    __device__ void Run(const ABlockBuffer& a_thread_buf,
                         const BBlockBuffer& b_thread_buf,
                         CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-            a_thread_desc_.GetElementSpaceSize());
-
         static_for<0, MRepeat, 1>{}([&](auto m0) {
             // read A
-            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                               make_tuple(m0, I0, I0, I0),
-                               a_block_buf,
-                               a_thread_desc_,
-                               make_tuple(I0, I0, I0, I0),
-                               a_thread_buf);
-
             static_for<0, NRepeat, 1>{}([&](auto n0) {
                 // read B
@@ -268,7 +232,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
                     constexpr index_t k0 = k / KPack;
                     static_for<0, KPack, 1>{}([&](auto i) {
                         a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
-                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
+                            [Number<a_thread_desc_.CalculateOffset(make_tuple(k0, m0, i))>{}];
                         b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                             [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
                     });
@@ -291,7 +255,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     private:
     // A[M0, M1, M2, KPerThread]
     static constexpr auto a_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // KPerThread
+                                                       Number<MRepeat>{},     // repeat
+                                                       Number<KPack>{}));

     // B[N0, N1, N2, KPerThread]
     static constexpr auto b_thread_desc_ =
@@ -302,18 +268,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
     // C[M, N, NumRegXdlops]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
-                                                         FloatAB,
-                                                         decltype(a_block_desc_m0_m1_m2_k),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, 1, 1, KPerThread>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         A_K1,
-                                                         A_K1>;
-
-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
 };
 } // namespace ck
...
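The per-thread A descriptor changes from a 4-D `(I1, I1, I1, KPerThread)` shape indexed with `(0, 0, 0, k + i)` to a packed 3-D `(K0PerThread, MRepeat, KPack)` shape indexed with `(k0, m0, i)`, matching how B is already addressed. For a packed naive descriptor, `CalculateOffset` is plain row-major flattening with the last index fastest. A standalone sketch with made-up extents, not the CK descriptor machinery itself:

```cpp
#include <cstdio>

// extents of the new packed per-thread A descriptor; values are illustrative
constexpr int K0PerThread = 2, MRepeat = 4, KPack = 8;

// packed (naive) descriptor: offset of (k0, m0, i), last index varies fastest
constexpr int calculate_offset(int k0, int m0, int i)
{
    return (k0 * MRepeat + m0) * KPack + i;
}

static_assert(calculate_offset(0, 0, 1) == 1, "KPack is the fastest dimension");
static_assert(calculate_offset(0, 1, 0) == KPack, "m0 strides by KPack");
static_assert(calculate_offset(1, 0, 0) == MRepeat * KPack, "k0 strides by MRepeat*KPack");

int main() { std::printf("offset(1, 2, 3) = %d\n", calculate_offset(1, 2, 3)); }
```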
@@ -113,7 +113,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
     static constexpr auto I7 = Number<7>{};

     static constexpr auto BaseMultK0 = 4;
-    static constexpr auto MultiK0 = BaseMultK0 * 2;
+    static constexpr auto MultiK0 = BaseMultK0 * 1;

     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
@@ -153,14 +153,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-
-        constexpr auto max_lds_align = K1;
-
-        constexpr auto a_block_space_size_aligned =
-            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
-
-        return (a_block_space_size_aligned) * sizeof(FloatAB);
+        return 1;
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
@@ -238,16 +231,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
         const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
             b_grid_desc_k0_n_k1,
-            make_tuple(make_unmerge_transform(
-                           make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
-                       make_unmerge_transform(make_tuple(
-                           N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
-                       make_pass_through_transform(K1)),
+            make_tuple(
+                make_unmerge_transform(
+                    make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
+                make_unmerge_transform(make_tuple(N / NPerBlock, NXdlPerWave, NWaves, NPerXDL)),
+                make_pass_through_transform(K1)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));

         return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
     }

+    __host__ __device__ static constexpr auto
+    MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
+    {
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
+
+        const auto a_griddesc_k0_mblockid_mrepeat_waves_mperxdlops_k1 = transform_tensor_descriptor(
+            a_grid_desc_k0_m_k1,
+            make_tuple(
+                make_unmerge_transform(
+                    make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
+                make_unmerge_transform(make_tuple(M / MPerBlock, MXdlPerWave, MWaves, MPerXDL)),
+                make_pass_through_transform(K1)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
+
+        return a_griddesc_k0_mblockid_mrepeat_waves_mperxdlops_k1;
+    }
+
     __device__ static auto GetWaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
@@ -262,12 +273,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
     __device__ static auto GetWaveKNIdx(const index_t thread_id)
     {
-        constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
+        constexpr auto wave_threadid_to_kn_idx_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))),
             make_tuple(Sequence<0, 1>{}),
             make_tuple(Sequence<0>{}));

-        return wave_threadid_to_nk_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+        return wave_threadid_to_kn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }
+
+    __device__ static auto GetWaveKMIdx(const index_t thread_id)
+    {
+        constexpr auto wave_threadid_to_km_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, MPerXDL))),
+            make_tuple(Sequence<0, 1>{}),
+            make_tuple(Sequence<0>{}));
+
+        return wave_threadid_to_km_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
     }

     __host__ __device__ static constexpr auto
@@ -375,6 +396,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
                                const CElementwiseOperation& c_element_op,
                                const Block2CTileMap& block_2_ctile_map)
     {
+        ignore = b_element_op;
+        ignore = a_element_op;
+        ignore = p_shared;
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -395,41 +419,56 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        const auto wave_id = GetWaveIdx();
+
+        const auto a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3 =
+            MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(a_grid_desc_k0_m_k1);

         // A matrix blockwise copy
-        auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                AElementwiseOperation,
-                                                ck::tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<K0PerBlock * MultiK0, MPerBlock, K1>,
-                                                ABlockTransferThreadClusterLengths_K0_M_K1,
-                                                ABlockTransferThreadClusterArrangeOrder,
-                                                FloatAB,
-                                                FloatAB,
-                                                decltype(a_grid_desc_k0_m_k1),
-                                                decltype(a_block_desc_k0_m_k1),
-                                                ABlockTransferSrcAccessOrder,
-                                                Sequence<1, 0, 2>,
-                                                ABlockTransferSrcVectorDim,
-                                                2,
-                                                ABlockTransferSrcScalarPerVector,
-                                                ABlockTransferDstScalarPerVector_K1,
-                                                1,
-                                                1,
-                                                AThreadTransferSrcResetCoordinateAfterRun,
-                                                true,
-                                                1>(
-                a_grid_desc_k0_m_k1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
-                a_element_op,
-                a_block_desc_k0_m_k1,
-                make_multi_index(0, 0, 0),
-                ck::tensor_operation::element_wise::PassThrough{});
-
-        ignore = b_element_op;
+        constexpr auto a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3 =
+            make_naive_tensor_descriptor_packed(make_tuple(I1,
+                                                           I1,
+                                                           Number<K0PerThread>{}, // K0PerThread
+                                                           I1,                    // NBlockId
+                                                           Number<MXdlPerWave>{}, // repeat
+                                                           I1,                    // waves
+                                                           I1,                    // MPerXdlops
+                                                           Number<K1>{}));
+
+        auto a_thread_buf = generate_tuple(
+            [&](auto i) {
+                ignore = i;
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    FloatAB,
+                                    a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
+                                    true>{};
+            },
+            Number<4>{});
+
+        const auto wave_k_m_id = GetWaveKMIdx(wave_id[I2]);
+
+        auto a_threadwise_copy =
+            ThreadwiseTensorSliceTransfer_v2<FloatAB,
+                                             FloatAB,
+                                             decltype(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3),
+                                             decltype(a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3),
+                                             Sequence<I1,
+                                                      I1,
+                                                      Number<K0PerThread>{},
+                                                      I1,
+                                                      Number<MXdlPerWave>{},
+                                                      I1,
+                                                      I1,
+                                                      Number<K1>{}>,
+                                             Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                             7,
+                                             ABlockTransferSrcScalarPerVector,
+                                             AThreadTransferSrcResetCoordinateAfterRun,
+                                             true>(
+                a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                make_multi_index(
+                    0, wave_k_m_id[I0], 0, block_work_idx[I0], 0, wave_id[I0], wave_k_m_id[I1], 0));
+
+        ignore = a_threadwise_copy;
+        ignore = a_thread_buf;

         // B matrix threadwise copy
         constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
             make_naive_tensor_descriptor_packed(make_tuple(I1,
@@ -451,7 +490,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
             },
             Number<4>{});

-        const auto wave_id = GetWaveIdx();
         const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);

 #if 0
@@ -505,7 +543,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
             BlockSize,
             FloatAB,
             FloatAcc,
-            decltype(a_block_desc_k0_m_k1),
+            decltype(a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3),
             decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
             MPerBlock,
             NPerBlock,
@@ -518,42 +556,83 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

-        // LDS allocation for A
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
-
         // gridwise GEMM pipeline
-        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
+        constexpr auto a_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
         constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

         // preload data to register and LDS
         {
             // Read
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  b_grid_buf,
-                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                  b_thread_buf(Number<0>{}));
-
-            // Move
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                 b_thread_slice_copy_step);
+            auto read_first_half_data = [&]() {
+                a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      a_grid_buf,
+                                      a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      a_thread_buf(Number<0>{}));
+                a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                                     a_thread_slice_copy_step);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      b_grid_buf,
+                                      b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      b_thread_buf(Number<0>{}));
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                                     b_thread_slice_copy_step);
+
+                a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      a_grid_buf,
+                                      a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      a_thread_buf(Number<1>{}));
+                a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                                     a_thread_slice_copy_step);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      b_grid_buf,
+                                      b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      b_thread_buf(Number<1>{}));
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                                     b_thread_slice_copy_step);
+            };
+
+            auto read_last_half_data = [&]() {
+                a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      a_grid_buf,
+                                      a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      a_thread_buf(Number<2>{}));
+                a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                                     a_thread_slice_copy_step);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      b_grid_buf,
+                                      b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      b_thread_buf(Number<2>{}));
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                                     b_thread_slice_copy_step);
+
+                a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      a_grid_buf,
+                                      a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      a_thread_buf(Number<3>{}));
+                a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
+                                                     a_thread_slice_copy_step);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      b_grid_buf,
+                                      b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                      b_thread_buf(Number<3>{}));
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                                     b_thread_slice_copy_step);
+            };
+
+            read_first_half_data();

             // Initialize C
             c_thread_buf.Clear();
-
-            // a data write to lds
-            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-            // load 2nd a matrix data
-            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  b_grid_buf,
-                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                  b_thread_buf(Number<1>{}));
-            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                 b_thread_slice_copy_step);
         }

         // main body
         if constexpr(HasMainK0BlockLoop)
         {
@@ -562,65 +641,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
             index_t i = 0;
             do
             {
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-                blockwise_gemm.ResetABlockStartWindow();
-                block_sync_lds();
-                s_nop();
-
-                static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
-                    // 1st
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf(Number<2>{}));
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf(Number<3>{}));
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    s_nop();
-
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-
-                    // 2nd
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-
-                    // 3rd
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf(Number<0>{}));
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf(Number<1>{}));
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    s_nop();
-
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-
-                    // 4th
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                });
-
-                block_sync_lds();
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-
-                // move a and b window
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
-                                                    a_block_slice_copy_step);
+                // 1st
+                read_last_half_data();
+
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<0>{}), b_thread_buf(Number<0>{}), c_thread_buf);
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<1>{}), b_thread_buf(Number<1>{}), c_thread_buf);
+
+                read_first_half_data();
+
+                s_nop();
+
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<2>{}), b_thread_buf(Number<2>{}), c_thread_buf);
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<3>{}), b_thread_buf(Number<3>{}), c_thread_buf);

                 i += 1;
             } while(i < (K0BlockMainLoop - 1));
@@ -628,65 +666,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
             // tail
             {
-                block_sync_lds();
-                blockwise_gemm.ResetABlockStartWindow();
-
-                static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
-                    // 1st
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf(Number<2>{}));
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-
-                    // 2nd
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf(Number<3>{}));
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-
-                    // 3rd
-                    if constexpr(i < MultiK0 - BaseMultK0)
-                    {
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_buf(Number<0>{}));
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                    }
-
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-
-                    // 4th
-                    if constexpr(i < MultiK0 - BaseMultK0)
-                    {
-                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              b_grid_buf,
-                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                              b_thread_buf(Number<1>{}));
-                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                             b_thread_slice_copy_step);
-                    }
-
-                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                });
+                // 1st
+                read_last_half_data();
+                s_nop();
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<0>{}), b_thread_buf(Number<0>{}), c_thread_buf);
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<1>{}), b_thread_buf(Number<1>{}), c_thread_buf);
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<2>{}), b_thread_buf(Number<2>{}), c_thread_buf);
+                blockwise_gemm.Run(
+                    a_thread_buf(Number<3>{}), b_thread_buf(Number<3>{}), c_thread_buf);
             }
         }
...
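GetWaveKMIdx is the A-side twin of GetWaveKNIdx: the lane id inside a wave is treated as the merge of (K0PerXdlops, MPerXDL), so recovering the per-lane (k0, m) origin used in a_threadwise_copy's start index is a plain div/mod. A standalone sketch, assuming an illustrative 64-lane wave split as 4 x 16; the CK adaptor does the same arithmetic through its transform machinery.

```cpp
#include <cstdio>

// illustrative values: a 64-lane wave split as K0PerXdlops x MPerXDL = 4 x 16
constexpr int K0PerXdlops = 4, MPerXDL = 16;

struct KM
{
    int k;
    int m;
};

// inverse of merge(K0PerXdlops, MPerXDL): div/mod on the lane id
constexpr KM wave_km_idx(int lane) { return {lane / MPerXDL, lane % MPerXDL}; }

int main()
{
    for(int lane : {0, 17, 63})
    {
        const KM km = wave_km_idx(lane);
        std::printf("lane %2d -> k0 = %d, m = %d\n", lane, km.k, km.m);
    }
}
```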