Commit 0374f8de authored by Chao Liu's avatar Chao Liu
Browse files

blockwise gemm does 3d*3d=4d

parent 4a661578
...@@ -378,5 +378,180 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1 ...@@ -378,5 +378,180 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_;
}; };
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. ABlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. ABlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc,
typename BBlockDesc,
typename CThreadDesc,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t AThreadCopyScalarPerVector_M1,
index_t BThreadCopyScalarPerVector_N1,
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
public:
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == c_thread_cluster_desc_.GetElementSize(),
"wrong! wrong blocksize");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
}
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
{
const auto thread_cluster_idx =
c_thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
constexpr index_t MPerLevel0Cluster = M1PerThread * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = N1PerThread * NLevel0ThreadCluster;
return make_multi_index(
0,
thread_cluster_idx[I0] * MPerLevel0Cluster + thread_cluster_idx[I2] * M1PerThread,
0,
thread_cluster_idx[I1] * NPerLevel0Cluster + thread_cluster_idx[I3] * N1PerThread);
}
__host__ __device__ static constexpr auto GetCThreadClusterDescriptor()
{
return make_cluster_descriptor_v2(Sequence<MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster>{},
Sequence<0, 1, 2, 3>{});
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_),
decltype(b_thread_desc_),
CThreadDesc>{};
constexpr index_t K = ABlockDesc{}.GetLength(I0);
static_for<0, K, KPerThreadLoop>{}([&](auto k) {
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
});
}
private:
static constexpr auto c_thread_cluster_desc_ = GetCThreadClusterDescriptor();
static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
// A[K, M0, M1]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<M0_>{}, Number<M1PerThread>{}));
// B[K, N0, N1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<N0_>{}, Number<N1PerThread>{}));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerThreadLoop, M0_, M1PerThread>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M1,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerThreadLoop, N0_, N1PerThread>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N1,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -59,6 +59,7 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl ...@@ -59,6 +59,7 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
} }
#endif #endif
#if 0
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
...@@ -504,6 +505,454 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -504,6 +505,454 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
}; };
#else
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_global into SGPR
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_idx_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_idx_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc,
make_tuple(
make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc,
make_tuple(
make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc),
decltype(c_m0_m1_n0_n1_thread_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
auto c_thread_buf =
make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const auto c_thread_data_idx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
c_thread_data_idx_on_block[I1],
n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
c_thread_data_idx_on_block[I3])}
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -1376,6 +1376,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 ...@@ -1376,6 +1376,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
{ {
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0, "wrong!");
} }
template <typename SrcRefToOriginDisplacement, template <typename SrcRefToOriginDisplacement,
......
...@@ -140,5 +140,91 @@ struct ThreadwiseGemm_km_kn_mn_v1r1 ...@@ -140,5 +140,91 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
} }
}; };
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto K = ADesc{}.GetLength(I0);
constexpr auto M0 = CDesc{}.GetLength(I0);
constexpr auto M1 = CDesc{}.GetLength(I1);
constexpr auto N0 = CDesc{}.GetLength(I2);
constexpr auto N1 = CDesc{}.GetLength(I3);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1));
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
});
});
});
});
}
};
} // namespace ck } // namespace ck
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment