Commit a6b2f1c1 authored by aska-0096's avatar aska-0096
Browse files

Add Inter-Row thread transfer

parent a0a469e4
...@@ -533,7 +533,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -533,7 +533,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_L0PerBlock_NPerBlock_L1(); constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_L0PerBlock_NPerBlock_L1();
// A1 matrix blockwise copy // A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatAcc, FloatAcc,
FloatA, FloatA,
decltype(acc_thread_desc_k0_m_k1), decltype(acc_thread_desc_k0_m_k1),
...@@ -542,7 +542,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -542,7 +542,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>, Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
2, 2,
n4>{tensor_operation::element_wise::PassThrough{}}; n4,
// dst Rowlane
// 0x76543210 0xfedcba98
// src Rowlane
0x76543210, 0xfedcba98>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy // B1 matrix blockwise copy
auto b1_blockwise_copy = auto b1_blockwise_copy =
...@@ -700,12 +704,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -700,12 +704,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum; mathext::exp(max - running_max_new) * sum;
// Intra-Row data permutation, make swizzled A input for WMMA
__builtin_amdgcn_permlane16(0xeca86420, 0xfdb97531);
// Low/high row move data to low/high half of thread buffer
/* thread copy*/
// Inter-Row data permutation, fullfill data duplication requirement
__builtin_amdgcn_permlanex16(0x76543210, 0xfedcba98);
// gemm1 // gemm1
{ {
// TODO: explore using dynamic buffer for a1 thread buffer // TODO: explore using dynamic buffer for a1 thread buffer
......
...@@ -1298,4 +1298,120 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ...@@ -1298,4 +1298,120 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
ElementwiseOperation element_op_; ElementwiseOperation element_op_;
}; };
// Specialized for WMMA.
// A single Wave32 is composed of two rows of 16 lanes each; data exchange is
// allowed between these two rows via permlanex16.
// Each RowLane's Dst buffer is filled from two Src buffers:
//   SrcA: the thread buffer held by this RowLane on this row
//   SrcB: the thread buffer held by this RowLane on the other row
// LowEightRowlaneIdx / HighEightRowLaneIdx are the packed 8x4-bit lane
// selectors passed straight through to __builtin_amdgcn_permlanex16.
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          index_t LowEightRowlaneIdx,
          index_t HighEightRowLaneIdx,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(
        const ElementwiseOperation& element_op)
        : element_op_{element_op}
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
    }

    // Copy a compile-time-known slice from src_buf to dst_buf, applying
    // element_op_ per element and duplicating data across the two 16-lane rows
    // of the wave: each lane writes its own value into one half of dst_buf and
    // receives the other row's value into the other half via permlanex16.
    template <typename SrcSliceOriginIdx,
              typename DstSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! SliceOrigin need to known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
                      "wrong! Buffer need to be StaticBuffer");

        // SrcDesc and src_slice_origin_idx are known at compile-time, so every
        // offset below collapses to a constant.
        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});

        // scalar per access on each dim
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);

            // copy data from src_buf into dst_vector
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                SrcData v;

                // apply element-wise operation
                element_op_(v, src_buf[Number<src_offset>{}]);

                // Lanes 0..15 of the wave form the low row, lanes 16..31 the
                // high row. BUGFIX: condition was `% 32 > 16`, which
                // misclassified lane 16 as a low-row lane, placing its direct
                // copy in the wrong half of dst_buf and corrupting the
                // inter-row exchange for that lane.
                if(get_thread_local_1d_id() % 32 >= 16)
                {
                    // High row: own value goes to the low half of dst_buf; the
                    // other row's value is pulled into the high half.
                    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
                    dst_buf(Number<dst_offset + dst_buf.size() / 2>{}) =
                        __builtin_amdgcn_permlanex16(
                            type_convert<DstData>(dst_buf(Number<dst_offset + dst_buf.size() / 2>{})),
                            type_convert<DstData>(v),
                            LowEightRowlaneIdx,
                            HighEightRowLaneIdx,
                            1,
                            0);
                }
                else
                {
                    // Low row: own value goes to the high half of dst_buf; the
                    // other row's value is pulled into the low half.
                    dst_buf(Number<dst_offset + dst_buf.size() / 2>{}) = type_convert<DstData>(v);
                    dst_buf(Number<dst_offset>{}) = __builtin_amdgcn_permlanex16(
                        type_convert<DstData>(dst_buf(Number<dst_offset>{})),
                        type_convert<DstData>(v),
                        LowEightRowlaneIdx,
                        HighEightRowLaneIdx,
                        1,
                        0);
                }
            });
        });
    }

    ElementwiseOperation element_op_;
};
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment