Commit 6e3b47c3 authored by ltqin

init version, 4 buffer

parent 83e6a4b9
......@@ -117,12 +117,12 @@ int main(int argc, char* argv[])
ck::index_t StrideC = N;
#else
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 32;
ck::index_t N = 32;
ck::index_t K = 16;
ck::index_t StrideA = 8;
ck::index_t StrideB = 8;
ck::index_t StrideC = 16;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideC = N;
#endif
if(argc == 4)
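The new defaults derive the strides from the problem sizes instead of hard-coding them. A plain-C++ sketch (not part of the commit) of the convention these values encode, assuming the usual CK example layouts — A row-major M x K, B column-major K x N, C row-major M x N — so each leading dimension is the extent of the fastest-varying axis:

#include <cstdint>

using index_t = std::int32_t;

// Element offsets under the assumed layouts; the stride arguments are the
// new defaults StrideA = K, StrideB = K, StrideC = N.
index_t offset_a(index_t m, index_t k, index_t stride_a) { return m * stride_a + k; }
index_t offset_b(index_t k, index_t n, index_t stride_b) { return n * stride_b + k; }
index_t offset_c(index_t m, index_t n, index_t stride_c) { return m * stride_c + n; }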
......@@ -189,7 +189,7 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
// a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
//a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
}
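The default case now fills A with small random values in [-5, 5] and B with ones, so every output element degenerates to a row-sum of A: C(m, n) = sum_k A(m, k). A hypothetical host-side check exploiting this (check_rowsum is an illustrative name, not code from the commit):

#include <cassert>
#include <vector>

void check_rowsum(const std::vector<float>& a, const std::vector<float>& c,
                  int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
    {
        float row_sum = 0.f;
        for(int k = 0; k < K; ++k)
            row_sum += a[m * K + k]; // row-major A, StrideA = K
        for(int n = 0; n < N; ++n)
            assert(c[m * N + n] == row_sum); // exact for small integer-valued inputs
    }
}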
......
......@@ -32,9 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
static constexpr index_t KPerBlock = K0PerBlock * KPack;
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
......@@ -219,46 +216,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__device__ void MoveABlockSliceWindow()
{
a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
make_multi_index(0, 0, 0, K0PerBlock * KPack));
}
__device__ void ResetABlockStartWindow()
{
a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
__device__ void Run(const ABlockBuffer& a_thread_buf,
const BBlockBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
......@@ -268,7 +232,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
constexpr index_t k0 = k / KPack;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
[Number<a_thread_desc_.CalculateOffset(make_tuple(k0, m0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
});
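The A offset now indexes the reshaped 3-D packed thread descriptor (K0PerThread, MRepeat, KPack) instead of the old 4-D (1, 1, 1, KPerThread) layout. For a packed naive descriptor the offset is plain row-major linearization; a minimal C++ model of what CalculateOffset(make_tuple(k0, m0, i)) evaluates to:

// Model of a_thread_desc_.CalculateOffset for a packed
// (K0PerThread, MRepeat, KPack) descriptor: row-major linearization.
constexpr int a_thread_offset(int k0, int m0, int i, int m_repeat, int k_pack)
{
    return (k0 * m_repeat + m0) * k_pack + i;
}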
......@@ -291,7 +255,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
private:
// A[K0PerThread, MRepeat, KPack]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // K0PerThread
Number<MRepeat>{}, // repeat
Number<KPack>{}));
// B[K0PerThread, NRepeat, KPack]
static constexpr auto b_thread_desc_ =
......@@ -302,18 +268,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
};
} // namespace ck
......
......@@ -113,7 +113,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static constexpr auto I7 = Number<7>{};
static constexpr auto BaseMultK0 = 4;
static constexpr auto MultiK0 = BaseMultK0 * 2;
static constexpr auto MultiK0 = BaseMultK0 * 1;
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
......@@ -153,14 +153,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned) * sizeof(FloatAB);
return 1;
}
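With A now staged entirely in VGPRs there is nothing left to place in LDS, so the function returns a 1-byte placeholder instead of the aligned A-tile size. What the deleted path computed, modeled in plain C++ (names mirror the removed lines; this is a sketch, not CK's implementation):

// Model of math::integer_least_multiple: round x up to a multiple of align.
constexpr int integer_least_multiple(int x, int align)
{
    return ((x + align - 1) / align) * align;
}

// Old LDS budget: the A block's element space, aligned to K1 elements,
// times the element size in bytes.
constexpr int a_lds_bytes(int a_elem_space_size, int k1, int bytes_per_elem)
{
    return integer_least_multiple(a_elem_space_size, k1) * bytes_per_elem;
}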
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -238,16 +231,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(
make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(N / NPerBlock, NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
}
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto a_griddesc_k0_mblockid_mrepeat_waves_mperxdlops_k1 = transform_tensor_descriptor(
a_grid_desc_k0_m_k1,
make_tuple(
make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(M / MPerBlock, MXdlPerWave, MWaves, MPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return a_griddesc_k0_mblockid_mrepeat_waves_mperxdlops_k1;
}
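The new A grid descriptor mirrors the B one: K0 is unmerged into (K0/K0PerBlock, K0PerXdlops, K0PerThread) and M into (M/MPerBlock, MXdlPerWave, MWaves, MPerXDL). The rewritten B factor N/NPerBlock equals the old N/(NXdlPerWave * NWaves * NPerXDL) whenever NPerBlock == NXdlPerWave * NWaves * NPerXDL, which presumably holds for the intended instantiations. An unmerge is mixed-radix digit extraction with the last factor fastest; a plain-C++ sketch over the K0 axis:

#include <array>

// Model of make_unmerge_transform over K0 with factors
// (K0/K0PerBlock, K0PerXdlops, K0PerThread); the last factor varies fastest.
std::array<int, 3> unmerge_k0(int k0, int k0_per_xdlops, int k0_per_thread)
{
    return {k0 / (k0_per_xdlops * k0_per_thread), // K0/K0PerBlock digit
            (k0 / k0_per_thread) % k0_per_xdlops, // K0PerXdlops digit
            k0 % k0_per_thread};                  // K0PerThread digit
}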
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
......@@ -262,12 +273,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
__device__ static auto GetWaveKNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
constexpr auto wave_threadid_to_kn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_nk_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
return wave_threadid_to_kn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetWaveKMIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_km_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, MPerXDL))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_km_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
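GetWaveKMIdx is the A-side twin of GetWaveKNIdx: merging (K0PerXdlops, MPerXDL) over the lane id means CalculateBottomIndex reduces to a div/mod pair. A plain-C++ model (m_per_xdl stands for whatever MPerXDL instantiates to, e.g. 16 or 32):

struct KMIdx { int k0; int m; };

// Model of GetWaveKMIdx: lane id -> (K index within the xdlops, M lane).
KMIdx wave_km_idx(int lane_id, int m_per_xdl)
{
    return {lane_id / m_per_xdl, lane_id % m_per_xdl};
}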
__host__ __device__ static constexpr auto
......@@ -375,6 +396,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
ignore = b_element_op;
ignore = a_element_op;
ignore = p_shared;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -395,41 +419,56 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
const auto wave_id = GetWaveIdx();
const auto a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3 =
MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(a_grid_desc_k0_m_k1);
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock * MultiK0, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
1>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
ignore = b_element_op;
constexpr auto a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
I1,
Number<K0PerThread>{}, // K0PerThread
I1, // MBlockId
Number<MXdlPerWave>{}, // repeat
I1, // waves
I1, // MPerXdlops
Number<K1>{}));
auto a_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
true>{};
},
Number<4>{});
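This is the "4 buffer" of the commit message: generate_tuple with the index ignored produces four identical VGPR-resident StaticBuffers for A, matching the four B buffers, so the pipeline can fill one pair while MFMAs drain the other. A rough std::array analogue, assuming StaticBuffer behaves like a fixed-size thread-private array:

#include <array>

// Four identical thread-private buffers: 0/1 form the "first half",
// 2/3 the "last half" of the ring.
template <typename T, int Size>
std::array<std::array<T, Size>, 4> make_quad_buffers()
{
    return {};
}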
const auto wave_k_m_id = GetWaveKMIdx(wave_id[I2]);
auto a_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3),
decltype(a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3),
Sequence<I1,
I1,
Number<K0PerThread>{},
I1,
Number<MXdlPerWave>{},
I1,
I1,
Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_multi_index(
0, wave_k_m_id[I0], 0, block_work_idx[I0], 0, wave_id[I0], wave_k_m_id[I1], 0));
ignore = a_threadwise_copy;
ignore = a_thread_buf;
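The copy's 8-D start coordinate places this thread at its (wave, lane) position in the unmerged A view; only the leading K-block dimension moves during the main loop. A model of the coordinate layout (the dimension names are descriptive, not from the source):

// Dims: (k_block, k_in_xdlops, k0_per_thread, m_block, m_repeat, m_wave, m_lane, k1)
struct AStartCoord { int idx[8]; };

AStartCoord a_start_coord(int wave_k, int m_block, int m_wave, int m_lane)
{
    // k_block starts at 0 and advances by the slice-copy step each iteration;
    // the repeat, per-thread K, and K1 offsets start at 0.
    return {{0, wave_k, 0, m_block, 0, m_wave, m_lane, 0}};
}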
// B matrix threadwise copy
constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
......@@ -451,7 +490,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
},
Number<4>{});
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
#if 0
......@@ -505,7 +543,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3),
decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
MPerBlock,
NPerBlock,
......@@ -518,42 +556,83 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
// gridwise GEMM pipeline
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
constexpr auto a_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
// preload data to registers
{
// Read
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{}));
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
auto read_first_half_data = [&]() {
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf(Number<0>{}));
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf(Number<1>{}));
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
};
auto read_last_half_data = [&]() {
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf(Number<2>{}));
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf(Number<3>{}));
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
};
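Each lambda loads one half of the four-buffer ring: read_first_half_data fills (A, B) pairs 0 and 1, read_last_half_data fills pairs 2 and 3, and every load is followed by a one-step move of the K source window. A compilable model of that structure (a sketch, not the kernel code):

#include <cstdio>

// One A copy plus one B copy into buffer `buf`, then advance the K window.
void load_pair(int buf) { std::printf("load A,B into buffer %d; advance K window\n", buf); }

void read_first_half_model() { load_pair(0); load_pair(1); }
void read_last_half_model()  { load_pair(2); load_pair(3); }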
read_first_half_data();
// Initialize C
c_thread_buf.Clear();
// write A data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// load 2nd B matrix data
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
// main body
if constexpr(HasMainK0BlockLoop)
{
......@@ -562,65 +641,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
blockwise_gemm.ResetABlockStartWindow();
block_sync_lds();
static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
// 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 2nd
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 3rd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
s_nop();
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 4th
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
});
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
// move a and b window
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step);
// 1st
read_last_half_data();
s_nop();
blockwise_gemm.Run(
a_thread_buf(Number<0>{}), b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<1>{}), b_thread_buf(Number<1>{}), c_thread_buf);
read_first_half_data();
s_nop();
blockwise_gemm.Run(
a_thread_buf(Number<2>{}), b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<3>{}), b_thread_buf(Number<3>{}), c_thread_buf);
i += 1;
} while(i < (K0BlockMainLoop - 1));
......@@ -628,65 +666,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// tail
{
block_sync_lds();
blockwise_gemm.ResetABlockStartWindow();
static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
// 1st
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<2>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 2nd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<3>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 3rd
if constexpr(i < MultiK0 - BaseMultK0)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<0>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 4th
if constexpr(i < MultiK0 - BaseMultK0)
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf(Number<1>{}));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}
blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<3>{}), c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
});
// 1st
read_last_half_data();
s_nop();
blockwise_gemm.Run(
a_thread_buf(Number<0>{}), b_thread_buf(Number<0>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<1>{}), b_thread_buf(Number<1>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<2>{}), b_thread_buf(Number<2>{}), c_thread_buf);
blockwise_gemm.Run(
a_thread_buf(Number<3>{}), b_thread_buf(Number<3>{}), c_thread_buf);
}
}
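Putting the prologue, main loop, and tail together, the new schedule overlaps global loads of one buffer half with MFMA work on the other half, loaded one step earlier. A compilable trace of the intended interleaving (a model under the assumption HasMainK0BlockLoop holds, not the kernel itself):

#include <cstdio>

void load_half(int b0, int b1) { std::printf("load buffers %d,%d\n", b0, b1); }
void gemm(int b)               { std::printf("mfma on buffer %d\n", b); }

void pipeline_model(int k0_block_main_loop)
{
    load_half(0, 1);                 // prologue: read_first_half_data()
    for(int i = 0; i < k0_block_main_loop - 1; ++i)
    {
        load_half(2, 3);             // read_last_half_data()
        gemm(0); gemm(1);            // compute the half loaded earlier
        load_half(0, 1);             // read_first_half_data()
        gemm(2); gemm(3);
    }
    load_half(2, 3);                 // tail: read_last_half_data()
    gemm(0); gemm(1); gemm(2); gemm(3);
}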
......