Unverified commit fcbb9788 authored by Chao Liu, committed by GitHub

Dynamic tensor descriptor (#24)



* support dynamic tensor descriptor

* use buffer load OOB feature for padding case

* add navi support

* add int8x4 inference kernel
Co-authored-by: Chao Liu <chao@ixt-rack-81.local.lan>
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent bbcb67d0
......@@ -311,13 +311,13 @@ struct TransformedTensorDescriptor
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran));
auto idx_low_part = pick_container_element(idx_low, LowDimensionIds{}.At(itran));
// this assumes each lower (single) index is only associated with one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
idx_low_part = tran.CalculateLowerIndex(to_multi_index(idx_up_part));
});
return idx_low;
......@@ -333,20 +333,23 @@ struct TransformedTensorDescriptor
constexpr auto tran = Transforms{}.At(itran);
const auto idx_up_diff_part =
pick_array_element(idx_up_diff, UpDimensionIds{}.At(itran));
pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran));
const auto idx_up_old_part = pick_array_element(idx_up_old, UpDimensionIds{}.At(itran));
const auto idx_up_old_part =
pick_container_element(idx_up_old, UpDimensionIds{}.At(itran));
const auto idx_low_old_part =
pick_array_element(idx_low_old, LowDimensionIds{}.At(itran));
pick_container_element(idx_low_old, LowDimensionIds{}.At(itran));
auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds{}.At(itran));
auto idx_low_diff_part =
pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran));
// this assumes each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_diff_part = tran.CalculateLowerIndexDiff(
to_array(idx_up_diff_part), to_array(idx_up_old_part), to_array(idx_low_old_part));
idx_low_diff_part = tran.CalculateLowerIndexDiff(to_multi_index(idx_up_diff_part),
to_multi_index(idx_up_old_part),
to_multi_index(idx_low_old_part));
});
return idx_low_diff;
......@@ -506,12 +509,12 @@ struct TransformedTensorDescriptor
constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
constexpr auto low_lengths_part =
GetLowerTensorDescriptor().GetLengths(low_dims_part);
const auto idx_low_part = to_array(pick_array_element(idx_low, low_dims_part));
const auto idx_low_part =
to_multi_index(pick_container_element(idx_low, low_dims_part));
for(index_t i = 0; i < low_dims_part.Size(); ++i)
{
static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) {
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
}
});
}
});
......
......@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor,
index_t... LowerDimensionIds,
index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
{
return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>,
......@@ -78,7 +78,7 @@ __host__ __device__ constexpr auto
// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
......@@ -96,7 +96,7 @@ __host__ __device__ constexpr auto
// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
......@@ -172,41 +172,5 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript
return make_native_tensor_descriptor(new_lengths, new_strides);
}
// a cluster that maps a 1-d index to an N-d index
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
static constexpr index_t nDim = Lengths::Size();
static constexpr auto mDesc = transform_tensor_descriptor(
make_native_tensor_descriptor_packed(Lengths{}),
make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
make_tuple(ArrangeOrder{}),
make_tuple(Sequence<0>{}));
__host__ __device__ constexpr ClusterDescriptor()
{
static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
"wrong! size not the same");
static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
}
__host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }
__host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
{
return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
}
};
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
return ClusterDescriptor<Lengths, decltype(order)>{};
}
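// Hedged usage sketch (not part of this commit; the helper name, lengths, and arrange
// order below are illustrative assumptions only): a ClusterDescriptor is typically used
// to decompose a flat thread id into an N-d cluster index, as the blockwise transfer
// classes do.
__device__ inline auto example_map_thread_to_cluster()
{
    // 4 x 8 = 32 threads arranged in a 2-d cluster; ArrangeOrder controls which
    // dimension the flat thread id walks through fastest
    constexpr auto cluster_desc = make_cluster_descriptor(Sequence<4, 8>{}, Sequence<1, 0>{});

    static_assert(cluster_desc.GetElementSize() == 32, "illustration only");

    // returns a 2-d index identifying this thread's position in the cluster
    return cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
}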
} // namespace ck
#endif
......@@ -210,17 +210,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
p_a_block +
a_block_mtx.GetOffsetFromMultiIndex(k_begin,
m_repeat * MPerLevel1Cluster) +
ib * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread +
a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
threadwise_matrix_copy(a_block_mtx,
p_a_block +
a_block_mtx.GetOffsetFromMultiIndex(
k_begin, m_repeat * MPerLevel1Cluster) +
ib * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(
0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
}
......@@ -229,17 +228,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
p_b_block +
b_block_mtx.GetOffsetFromMultiIndex(k_begin,
n_repeat * NPerLevel1Cluster) +
ib * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread +
b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
threadwise_matrix_copy(b_block_mtx,
p_b_block +
b_block_mtx.GetOffsetFromMultiIndex(
k_begin, n_repeat * NPerLevel1Cluster) +
ib * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(
0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
}
......@@ -307,7 +305,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
"Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
"1 for now\n");
using Float4 = vector_type<float, 4>::MemoryType;
using Float4 = vector_type<float, 4>::type;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
......@@ -391,9 +389,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
threadwise_matrix_copy(
c_thread_sub_mtx,
p_c_thread +
c_thread_sub_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
n_repeat * NPerLevel1Cluster),
p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex(
m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster),
c_block_mtx,
p_c_block +
c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
......@@ -405,5 +402,5 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
};
} // namespace
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
namespace ck {
// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
InMemoryDataOperation DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
index_t ThreadTransferSrcResetCoordinateAfterRun,
index_t ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_id_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ static constexpr auto CalculateThreadDataBegin()
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
return thread_cluster_id * ThreadSliceLengths{};
}
template <typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcData* p_src,
const SrcIteratorHacks& src_iterator_hacks)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
}
}
__device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, p_dst);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
// SrcMoveSliceWindowIteratorHack controls index calculation when moving the slice window
template <typename SrcMoveSliceWindowIteratorHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_iterator_hack);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
SrcAddressSpace,
DstAddressSpace,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
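// Hedged usage sketch (an assumption mirroring how the gridwise GEMM below drives this
// class): one K-block step of a double-buffered copy typically looks like
//
//   auto blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<...>(
//       src_desc, src_block_slice_origin, dst_desc, dst_block_slice_origin);
//
//   blockwise_copy.RunRead(src_desc, p_src, src_iterator_hacks);  // global -> thread buffer
//   blockwise_copy.RunWrite(dst_desc, p_dst);                     // thread buffer -> LDS
//   blockwise_copy.MoveSrcSliceWindow(src_desc, step, move_slice_window_hack);
//
// Every call above is guarded so that only the first thread_cluster_desc_.GetElementSize()
// threads of the block participate when BlockSize is larger than the thread cluster.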
} // namespace ck
#endif
......@@ -95,26 +95,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
......@@ -336,9 +316,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_if<MRepeat == 2 && NRepeat == 2>{}([&](auto) {
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}).Else([&](auto) { Run_naive(p_a_block, p_b_block, p_c_thread); });
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
......
#ifndef CK_BLOCKWISE_GEMM_V2_HPP
#define CK_BLOCKWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "threadwise_gemm_v2.hpp"
namespace ck {
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, index calculation is greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
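// Hedged reference sketch (not part of this commit; the function name is hypothetical):
// plain scalar semantics of the blockwise GEMM above, C[M, N] += transpose(A[K, M]) * B[K, N],
// written as a host-side loop nest for clarity only (the class below distributes C over threads).
template <typename FloatA, typename FloatB, typename FloatC>
__host__ void reference_gemm_km_kn_mn(
    const FloatA* a, const FloatB* b, FloatC* c, index_t K, index_t M, index_t N)
{
    for(index_t k = 0; k < K; ++k)
        for(index_t m = 0; m < M; ++m)
            for(index_t n = 0; n < N; ++n)
                // A is stored [K, M] (i.e. transposed), B is [K, N], C is [M, N]
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}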
template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1
{
struct MatrixIndex
{
index_t row;
index_t col;
};
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr auto I1 = Number<1>{};
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
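// Hedged worked example (parameter values are assumptions for illustration only):
// with MPerThreadSubC = NPerThreadSubC = 4 and all four *Level*ThreadCluster values = 2,
// ThreadPerLevel0Cluster = 4 and BlockSize = ThreadPerLevel1Cluster = 16. Thread 7 then
// decomposes as level1_id = 1 -> (level1_m_id, level1_n_id) = (0, 1) and
// level0_id = 3 -> (level0_m_id, level0_n_id) = (1, 1), giving
// MatrixIndex{0 * 8 + 1 * 4, 1 * 8 + 1 * 4} = MatrixIndex{4, 12}.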
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
#pragma unroll
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// read A
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(
make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
mMyThreadOffsetA,
p_a_thread + a_thread_mtx.CalculateOffset(
make_tuple(0, m_repeat * MPerThreadSubC)));
}
#pragma unroll
// read B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
b_thread_copy.Run(p_b_block +
b_block_mtx.CalculateOffset(
make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(
make_tuple(0, n_repeat * NPerThreadSubC)));
}
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
}
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
make_tuple(Number<MPerThread>{}, Number<1>{}));
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(p_a_block_off, p_a_thread);
// read B_sub_0
b_thread_copy.Run(p_b_block_off, p_b_thread);
// read B_sub_1
b_thread_copy.Run(p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
#pragma unroll
// loop over rest of k
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
{
// read A_sub_0
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
p_a_thread);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// read B_sub_0
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
// read B_sub_1
b_thread_copy.Run(
p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(
p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
}
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_V3_HPP
#define CK_BLOCKWISE_GEMM_V3_HPP
#include "common_header.hpp"
#include "threadwise_gemm_v3.hpp"
namespace ck {
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, index calculation is greatly reduced:
// KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
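// Hedged worked example (values are assumptions for illustration only): the constructor
// below requires BlockSize == (K / KPerThread) * (H / HPerThread) * (W / WPerThread).
// E.g. K = 16, H = 32, W = 32 with KPerThread = 16, HPerThread = 2, WPerThread = 2 gives
// 1 * 16 * 16 = 256 threads, and GetBeginOfThreadMatrixC decomposes the flat thread id
// into (k_thread_id, h_thread_id, w_thread_id) over that 1 x 16 x 16 cluster.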
template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemm_km_kn_m0m1n0n1_v3
{
struct MatrixIndex
{
index_t k;
index_t h;
index_t w;
};
index_t mMyThreadOffsetA;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
"wrong! Cannot evenly divide work among\n");
constexpr auto KThreadCluster = K / KPerThread;
constexpr auto HThreadCluster = H / HPerThread;
constexpr auto WThreadCluster = W / WPerThread;
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA =
BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread));
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
constexpr auto num_w_threads = W / WPerThread;
constexpr auto num_h_threads = H / HPerThread;
constexpr auto num_hw_threads = num_w_threads * num_h_threads;
index_t k_thread_id = thread_id / num_hw_threads;
index_t hw_thread_id = thread_id % num_hw_threads;
index_t h_thread_id = hw_thread_id / num_w_threads;
index_t w_thread_id = hw_thread_id % num_w_threads;
return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
}
template <typename SrcDesc,
typename DstDesc,
index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseSliceCopy_a
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
}
};
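// Hedged illustration (assumption: Data = float, DataPerAccess = 4): vector_t is then
// vector_type<float, 4>::type, so the copy above moves the slice in 16-byte chunks,
// e.g. a 2 x 8 slice becomes 2 * (8 / 4) = 4 vector loads and 4 vector stores.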
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
constexpr auto KPerThreadSubC = 4;
static_assert(KPerThread % KPerThreadSubC == 0, "");
static_assert(HPerThread % 2 == 0, "");
static_assert(WPerThread % 2 == 0, "");
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
decltype(a_thread_mtx),
EPerThreadLoop,
KPerThreadSubC,
ThreadGemmADataPerRead_K>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
// loop over k
#pragma unroll
for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
{
#pragma unroll
for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
{
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) +
mMyThreadOffsetA,
p_a_thread);
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += 2)
{
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += 2)
{
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(
e_begin, 0, h_begin, w_begin)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(
k_begin, 0, h_begin, w_begin)));
}
}
}
}
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
{
Run_naive(p_a_block, p_b_thread, p_c_thread);
}
};
} // namespace ck
#endif
......@@ -5,6 +5,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
......@@ -68,9 +69,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseLoad.SetDstSliceOrigin(make_zero_multi_index<nDim>());
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetSrcSliceOrigin(make_zero_multi_index<nDim>());
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
}
......
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v2.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
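// Hedged worked example (all numbers are assumptions for illustration only): with
// KPerBlock = 8, MPerBlock = NPerBlock = 128, MPerThread = NPerThread = 4, both
// DstScalarPerVector values = 4 and FloatAB = float, max_lds_align = lcm(4, 4, 4, 4) = 4,
// each of the A and B block descriptors covers 8 * 128 = 1024 elements (no extra padding,
// since 128 is already a multiple of 4), and the double-buffered LDS requirement is
// 2 * (1024 + 1024) * sizeof(float) = 16384 bytes (16 KB).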
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
#if 0
const auto m_block_work_num = M / Number<MPerBlock>{};
const auto n_block_work_num = N / Number<NPerBlock>{};
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#else
// Hack: this forces the result into SGPR
const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
const index_t m_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#endif
const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// registers
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
decltype(c_m0m1_n0n1_thread_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// hack to control index calculation when iterating over the A and B matrices for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when moving the slice window for the A and B
// matrices for threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
}
if constexpr(HasMainKBlockLoop)
{
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if 1 iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
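// Hedged worked example of the schedule above (numbers are assumptions for illustration
// only, and the choice of HasMainKBlockLoop / HasDoubleTailKBlockLoop is assumed to be
// made by the host based on K / KPerBlock): with K = 32 and KPerBlock = 8 there are 4
// K-blocks. The preload stage loads block 0 into the "even" LDS buffer; the main do-while
// loop then runs once, computing GEMM on blocks 0 and 1 while prefetching blocks 1 and 2
// into the alternate buffers; the double tail computes GEMM on block 2 while loading
// block 3, then finishes with GEMM on block 3.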
// output: register to global memory
{
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
Number<MPerThread>{},
Number<NRepeat>{},
Number<NPerThread>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
constexpr auto tmp = make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
p_c_thread,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
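// Hedged worked example for the (m0, m1, n0, n1) split above (values are assumptions
// for illustration only): with MPerThread = 4 and MLevel0Cluster = MLevel1Cluster = 4,
// M1 = 4 * 4 * 4 = 64; a thread whose m_thread_data_on_global is 68 therefore writes
// its C sub-tile starting at (m0, m1) = (68 / 64, 68 % 64) = (1, 4) in the
// c_m0_m1_n0_n1 global view, and likewise for (n0, n1).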
}
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_k_m_global_desc = *p_a_k_m_global_desc;
const auto b_k_n_global_desc = *p_b_k_n_global_desc;
const auto c_m0_m1_n0_n1_global_desc = *p_c_m0_m1_n0_n1_global_desc;
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const void* p_b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const void* p_c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_k_m_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_k_m_global_desc);
const auto b_k_n_global_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_k_n_global_desc);
const auto c_m0_m1_n0_n1_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_m0_m1_n0_n1_global_desc);
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v3
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto E = EPerBlock * 3 * 3;
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto E = EPerBlock * 3 * 3;
// const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [K, Ho, Wo]
#if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: this forces the result into SGPR
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#endif
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_desc,
make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
FloatAB* p_a_block = p_shared_block;
// register allocation for output
FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over the A and B matrices for threadwise copy
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when moving the slice window of the A and B matrices
// for threadwise copy
constexpr auto a_e_k_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
FloatAB p_b_thread[b_thread_space_size * 2];
FloatAB* p_b_thread_double = p_b_thread;
// preload: A is staged once in LDS; the first B slice goes into the register double buffer
{
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double,
b_e_n_ho_wo_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
}
__syncthreads();
index_t b_block_data_begin = 0;
#if 1
if constexpr(HasMainKBlockLoop)
{
FloatAB* p_b_thread_even = p_b_thread_double;
FloatAB* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
// LDS double buffer: main body
// use a do-while loop instead of a for loop to simplify control flow
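// each iteration consumes two EPerBlock slices of B (one per register buffer), so the
// loop stops 2 * EPerBlock short of E and the tail below handles the last 1 or 2 slices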
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_odd,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_even,
p_c_thread);
b_block_data_begin += EPerBlock;
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_even,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_odd,
p_c_thread);
b_block_data_begin += EPerBlock;
} while(b_block_data_begin < E - 2 * EPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread_double + b_thread_space_size,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_double,
p_c_thread);
b_block_data_begin += EPerBlock;
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_double + b_thread_space_size,
p_c_thread);
}
else // if 1 iteration is left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_double,
p_c_thread);
}
#endif
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
p_c_thread,
c_k_n_ho_wo_global_desc,
p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
}
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
const auto b_e_n_ho_wo_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
const auto c_k_n_ho_wo_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif
......@@ -68,13 +68,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
......@@ -116,8 +116,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
const index_t m_block_data_on_global = block_work_id[Number<0>{}] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[Number<1>{}] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -143,7 +143,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0});
make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0));
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -169,7 +169,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0});
make_multi_index(0, n_block_data_on_global), make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -209,14 +209,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
......@@ -230,47 +230,55 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto a_block_slice_copy_step = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_step = Sequence<KPerBlock, 0>{};
Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size;
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
__syncthreads();
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
__syncthreads();
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
__syncthreads();
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
}
// LDS double buffer: tail
......@@ -282,8 +290,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
......@@ -296,15 +304,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if has 1 iteration left
{
......@@ -355,11 +364,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1})
make_multi_index(0, 0, 0, 0),
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(p_c_thread, p_c_global);
}
}
......@@ -433,13 +442,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
......@@ -447,21 +456,23 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[0];
constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[1];
constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[2];
constexpr auto M = c_m_n_global_desc.GetLengths()[0];
constexpr auto N = c_m_n_global_desc.GetLengths()[1];
constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[I0];
constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[I1];
constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[I2];
constexpr auto M = c_m_n_global_desc.GetLengths()[I0];
constexpr auto N = c_m_n_global_desc.GetLengths()[I1];
// don't do anything if K == 0
if(K == 0)
......@@ -487,8 +498,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
const index_t m_block_data_on_global = block_work_id[I0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[I1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -514,7 +525,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, 0, 0, m_block_data_on_global}, {0, 0, 0, 0});
make_multi_index(0, 0, 0, m_block_data_on_global), make_multi_index(0, 0, 0, 0));
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -540,7 +551,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, 0, 0, n_block_data_on_global}, {0, 0, 0, 0});
make_multi_index(0, 0, 0, n_block_data_on_global), make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -582,14 +593,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
......@@ -601,15 +612,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
{
for(index_t k1 = 0; k1 < K1; ++k1)
{
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
......@@ -621,20 +631,20 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
......@@ -660,8 +670,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
__syncthreads();
......@@ -673,16 +683,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
a_blockwise_copy.RunStoreThreadBuffer(
p_a_thread_buffer, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunStoreThreadBuffer(
p_b_thread_buffer, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space,
p_b_block_double + b_block_space,
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if has 1 iteration left
......@@ -750,11 +760,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v2
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1})
make_multi_index(0, 0, 0, 0),
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(p_c_thread, p_c_global);
}
}
......
#ifndef CK_GRIDWISE_TENSOR_CONTRACTION_HPP
#define CK_GRIDWISE_TENSOR_CONTRACTION_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockLengths,
index_t KPerBlock,
InMemoryDataOperation CGlobalMemoryDataOperation>
struct GridwiseTensorContraction_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// size of the double-buffered A and B blocks in LDS, using the same layout as Run() below
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
return 2 * (a_block_space + b_block_space) * sizeof(Float);
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
/// \todo sanity-check AGlobalDesc, BGlobalDesc, CGlobalDesc length consistency
/// \todo sanity-check CBlockLengths
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_global_desc = AGlobalDesc{};
constexpr auto b_global_desc = BGlobalDesc{};
constexpr auto c_global_desc = CGlobalDesc{};
constexpr auto K = a_global_desc.GetLengths()[0];
// don't do anything if K == 0
if(K == 0)
{
return;
}
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
AGlobalDesc,
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M,
ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
Sequence<0, 1>,
ABlockCopySrcVectorReadDim,
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
BGlobalDesc,
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N,
BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
Sequence<0, 1>,
BBlockCopySrcVectorReadDim,
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
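// each thread accumulates a (GemmMRepeat * MPerThread) x (GemmNRepeat * NPerThread) tile of C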
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
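// ping-pong between the two halves of the LDS double buffer: compute on "now", prefetch into "next"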
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if 2 iterations are left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if 1 iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// output: register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
CThreadCopySrcDstAccessOrder,
CThreadCopySrcDstVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1})
.Run(p_c_thread, p_c_global);
}
}
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
}
};
} // namespace ck
#endif
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// has no default constructor
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? ScalarPerVector : 1;
}
};
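// e.g. generate_sequence(lambda_scalar_per_access<1, 4>{}, Number<4>{}) is expected to yield
// Sequence<1, 4, 1, 1>: vector access along dim 1, scalar access on the other dims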
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? 1 : 0;
}
};
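// e.g. generate_sequence(lambda_scalar_step_in_vector<1>{}, Number<4>{}) is expected to yield
// Sequence<0, 1, 0, 0>: the per-scalar step within a vector access along dim 1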
// this version is less likely to have a scratch memory issue, because:
// 1. It does not keep a reference to the tensor descriptor
// 2. It does not construct a new tensor coordinate inside this->Run()
// Assumes src_slice_origin_idx is 0
// TODO: support non-zero src_slice_origin_idx
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3(
const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
: dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcSliceOriginIdx, typename DstIteratorHacks>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcData* p_src,
const DstDesc& dst_desc,
DstData* p_dst,
const DstIteratorHacks& dst_iterator_hacks)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward iterators
const auto dst_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
},
Number<nDim>{});
// make backward iterators
const auto dst_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
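// forward_sweep[i] is true when the combined index of the slower dims is even, giving a
// snake-like traversal so that each step moves the coordinate along exactly one dimension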
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
return dst_data_idx;
}();
// copy data
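// gather DstScalarPerVector scalars from the VGPR source at compile-time offsets, then
// issue a single (possibly OOB-guarded) vector store to dst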
vector_type<DstData, DstScalarPerVector> dst_vector;
using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::type;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset =
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
});
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global)
{
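// is_dst_valid guards padded / out-of-range coordinates: the buffer-addressing path is
// expected to drop the store in hardware, the fallback below skips it explicitly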
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<DstData, DstScalarPerVector>(
dst_vector.Vector(),
p_dst,
dst_slice_origin_coord_.GetOffset(),
is_dst_valid,
dst_desc.GetElementSpaceSize());
#else
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
}
#endif
}
else
{
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
}
}
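// move_on_dim[i]: advance dim i only when it is not at its last position and every
// faster-varying dim after it is, i.e. a mixed-radix carry over the access order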
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(dst_desc,
dst_slice_origin_coord_,
dst_forward_iterators[dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(dst_desc,
dst_slice_origin_coord_,
dst_backward_iterators[dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
}
}
template <typename SrcSliceOriginIdx>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcData* p_src,
const DstDesc& dst_desc,
DstData* p_dst)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_iterator_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(SrcDesc{}, SrcSliceOriginIdx{}, p_src, dst_desc, p_dst, dst_iterator_hacks);
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate the dst data index after the last access in Run(), in case the coordinate has
// not been reset at the end of Run()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
return dst_data_idx;
}();
// negate to get the step that moves the dst coordinate back to the slice origin
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; });
return reset_dst_data_step;
}();
return reset_dst_data_step;
}
// dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step);
}
private:
DstCoord dst_slice_origin_coord_;
}; // struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// this version is less likely to have a scratch memory issue, because:
// 1. It does not keep a reference to the tensor descriptor
// 2. It does not construct a new tensor coordinate inside this->Run()
// Assumes dst_slice_origin_idx is 0
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun,
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v2
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx)
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
}
template <typename DstSliceOriginIdx, typename SrcIteratorHacks>
__device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src,
const DstDesc&,
const DstSliceOriginIdx&,
DstData* p_dst,
const SrcIteratorHacks& src_iterator_hacks)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value,
"wrong! DstSliceOrigin need to known at compile-time");
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward iterators
const auto src_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, forward_step, src_iterator_hacks[I0][i]);
},
Number<nDim>{});
// make backward iterators
const auto src_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, backward_step, src_iterator_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
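// same snake-like sweep as in ThreadwiseDynamicTensorSliceTransfer_v1r3::Run() above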
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
return src_data_idx;
}();
// copy data
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
vector_type<SrcData, SrcScalarPerVector> src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Global)
{
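// is_src_valid guards padded / out-of-range reads: both paths return a zero vector for
// invalid coordinates, the buffer-addressing path presumably via its hardware OOB behavior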
#if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_vector.Scalars()[i];
});
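// same mixed-radix carry logic as in v1r3 above: advance exactly one dimension per access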
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(src_desc,
src_slice_origin_coord_,
src_forward_iterators[dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(src_desc,
src_slice_origin_coord_,
src_backward_iterators[dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
}
}
template <typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src,
const DstDesc&,
const DstSliceOriginIdx&,
DstData* p_dst)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, p_dst, src_iterator_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate the src data index after the last access in Run(), in case the coordinate has
// not been reset at the end of Run()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
return src_data_idx;
}();
// negate to get the step that moves the src coordinate back to the slice origin
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
return reset_src_data_step;
}();
return reset_src_data_step;
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
}
private:
SrcCoord src_slice_origin_coord_;
}; // struct ThreadwiseDynamicTensorSliceTransfer_v2
// this version does the following things to avoid "alloca" in LLVM-IR, which would cause scratch
// memory usage and sometimes useless instructions:
// 1. It does not keep a reference to the tensor descriptor
// 2. It does not construct a new tensor coordinate inside this->Run()
// 3. It does not use a pointer for the VGPR thread buffer
// 4. It calculates offsets into the thread buffer directly, instead of moving a coordinate
template <typename SliceLengths,
InMemoryDataOperation DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v3
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
{
static_assert(SrcAddressSpace == AddressSpace::Global or
SrcAddressSpace == AddressSpace::Lds,
"wrong!");
static_assert(DstAddressSpace == AddressSpace::Global or
DstAddressSpace == AddressSpace::Lds,
"wrong!");
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcData* p_src,
const SrcIteratorHacks& src_iterator_hacks)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward iterators
const auto src_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, forward_step, src_iterator_hacks[I0][i]);
},
Number<nDim>{});
// make backward iterators
const auto src_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, backward_step, src_iterator_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// decide whether to move forward or backward along each dimension (zig-zag sweep)
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
auto src_data_idx =
container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
return src_data_idx;
}();
// copy data
vector_type<SrcData, SrcScalarPerVector> src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
src_vector.Vector() = is_src_valid
? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
buffer_(Number<buffer_offset>{}) = src_vector.Scalars()[i];
});
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
src_desc,
src_slice_origin_coord_,
src_forward_iterators[src_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
src_desc,
src_slice_origin_coord_,
src_backward_iterators[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
}
}
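// Note on the traversal order used by RunRead()/RunWrite(): forward_sweep turns the access
// loop into a zig-zag (boustrophedon) sweep, so consecutive accesses differ along a single
// dimension and the coordinate only ever moves by one forward/backward iterator per step.
// A worked sketch for ordered access lengths {2, 3} (illustrative only):
//   outer index 0 -> inner dimension swept forward : (0,0) (0,1) (0,2)
//   outer index 1 -> inner dimension swept backward: (1,2) (1,1) (1,0)
// which is exactly the pattern selected by the parity test "tmp % 2 == 0" above.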
template <typename DstIteratorHacks>
__device__ void
RunWrite(const DstDesc& dst_desc, DstData* p_dst, const DstIteratorHacks& dst_iterator_hacks)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// dst scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward iterators
const auto dst_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
return forward_iterator;
},
Number<nDim>{});
// make backward iterators
const auto dst_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
return backward_iterator;
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// decide whether to move forward or backward along each dimension (zig-zag sweep)
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
auto dst_data_idx =
container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
return dst_data_idx;
}();
// copy data
// hardcoding for ds_write
// TODO refactor transfer_data() to encapsulate this
static_assert(DstAddressSpace == AddressSpace::Lds &&
DstInMemOp == InMemoryDataOperation::Set,
"wrong! hardcoded for ds_write");
vector_type<DstData, DstScalarPerVector> dst_vector;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}];
});
using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::type;
*reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
dst_vector.Vector();
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
dst_desc,
dst_slice_origin_coord_,
dst_forward_iterators[dst_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
dst_desc,
dst_slice_origin_coord_,
dst_backward_iterators[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
}
}
__device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, p_src, src_iterator_hacks);
}
__device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_iterator_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, p_dst, dst_iterator_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// decide whether the sweep moved forward or backward along each dimension during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate src data index after the last iteration in RunRead(), if it has not been reset
// by RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
return src_data_idx;
}();
// the reset step is the negative of that index, i.e. the step back to the slice origin
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
return reset_src_data_step;
}();
return reset_src_data_step;
}
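// Worked sketch of the reset step (illustrative only): with ordered access lengths {2, 3}
// and one scalar per access, the zig-zag sweep ends at ordered index (1, 0), because the
// inner dimension is swept backward on the last pass. GetSrcCoordinateResetStep() then
// returns (-1, 0), which MoveSrcSliceWindow() folds into its step so no separate reset
// move is needed.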
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// decide whether the sweep moved forward or backward along each dimension during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate dst data index after the last iteration in RunWrite(), if it has not been reset
// by RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
return dst_data_idx;
}();
// the reset step is the negative of that index, i.e. the step back to the slice origin
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; });
return reset_dst_data_step;
}();
return reset_dst_data_step;
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
template <typename SrcMoveSliceWindowIteratorHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
}
// dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step);
}
private:
static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
SrcCoord src_slice_origin_coord_;
DstCoord dst_slice_origin_coord_;
};
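// A minimal usage sketch of this transfer (all names below are hypothetical; the template
// arguments follow the parameter list above):
//   auto transfer = ThreadwiseDynamicTensorSliceTransfer_v3<...>(
//       src_desc, src_thread_origin, dst_desc, dst_thread_origin);
//   transfer.RunRead(src_desc, p_global_src);         // global memory -> VGPR thread buffer
//   transfer.RunWrite(dst_desc, p_lds_dst);           // VGPR thread buffer -> LDS
//   transfer.MoveSrcSliceWindow(src_desc, move_step); // slide the window for the next tile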
} // namespace ck
#endif
......@@ -39,7 +39,7 @@ struct ThreadwiseMatrixSliceCopy
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
using vector_t = typename vector_type<Data, DataPerAccess>::type;
for(index_t i = 0; i < NSliceRow; ++i)
{
......@@ -153,9 +153,8 @@ struct ThreadwiseGemmTransANormalBNormalC
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
static_if<has_amd_asm>{}([&](auto fwd) {
Run_amd_asm(p_a, p_b, fwd(p_c));
}).Else([&](auto) { Run_source(p_a, p_b, p_c); });
static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
.Else([&](auto) { Run_source(p_a, p_b, p_c); });
#else
Run_source(p_a, p_b, p_c);
#endif
......
#ifndef CK_THREADWISE_GEMM_V2_HPP
#define CK_THREADWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
{
static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto desc = Desc{};
constexpr auto M = desc.GetLength(I0);
constexpr auto N = desc.GetLength(I1);
static_for<0, M, 1>{}([&](auto i) {
static_for<0, N, 1>{}([&](auto j) {
constexpr auto offset = desc.CalculateOffset(make_tuple(i, j));
p_thread[offset] = Float(0);
});
});
}
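// Usage sketch (the descriptor shape is made up; any packed compile-time 2D descriptor
// works):
//   constexpr auto c_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
//       make_tuple(Number<4>{}, Number<8>{}));
//   float p_c_thread[c_thread_desc.GetElementSpaceSize()];
//   threadwise_matrix_set_zero_v2(c_thread_desc, p_c_thread);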
template <typename SrcDesc,
typename DstDesc,
index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy_v2
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
}
};
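// Usage sketch (descriptors are hypothetical compile-time 2D descriptors; copies a 4 x 8
// sub-block with 4-wide vector accesses):
//   ThreadwiseMatrixSliceCopy_v2<SrcThreadDesc, DstThreadDesc, 4, 8, 4>::Run(p_src, p_dst);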
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template <typename ADesc,
typename BDesc,
typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1
{
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
});
});
});
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
static_assert(N == 4 || N == 2, "wrong! this config is not supported by asm yet");
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
if constexpr(N == 2)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
amd_assembly_outer_product_1x2(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_c[c_offset_0],
p_c[c_offset_1]);
}
else if constexpr(N == 4)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
});
});
}
#endif
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
Run_amd_asm(p_a, p_b, p_c);
#else
Run_source(p_a, p_b, p_c);
#endif
}
};
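// Usage sketch (descriptors are hypothetical; A is K x M, B is K x N, C is M x N, all
// known at compile-time):
//   ThreadwiseGemm_km_kn_mn_v1<AThreadDesc, BThreadDesc, CThreadDesc>::Run(
//       p_a_thread, p_b_thread, p_c_thread);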
} // namespace ck
#endif
#ifndef CK_THREADWISE_GEMM_V3_HPP
#define CK_THREADWISE_GEMM_V3_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread)
{
static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto desc = Desc{};
constexpr auto K = desc.GetLength(I0);
constexpr auto H = desc.GetLength(I2);
constexpr auto W = desc.GetLength(I3);
static_for<0, K, 1>{}([&](auto i) {
static_for<0, H, 1>{}([&](auto j) {
static_for<0, W, 1>{}([&](auto k) {
constexpr auto offset = desc.CalculateOffset(make_tuple(i, 0, j, k));
p_thread[offset] = Float(0);
});
});
});
}
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template <typename ADesc,
typename BDesc,
typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v3
{
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// H and W are hard-coded to 2 for now; the descriptor-derived values are kept (commented
// out) for reference
// constexpr auto H = BDesc{}.GetLength(I2);
// constexpr auto W = BDesc{}.GetLength(I3);
constexpr auto H = 2;
constexpr auto W = 2;
constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1);
static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
if constexpr(H == 2 && W == 2)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
else if constexpr(H == 4 && W == 1)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
else
{
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr auto b_offset =
BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset =
CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
p_c[c_offset] += inner_product_with_conversion<FloatC>{}(p_a[a_offset],
p_b[b_offset]);
});
});
}
});
});
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
Run_source(p_a, p_b, p_c);
}
};
} // namespace ck
#endif
......@@ -54,8 +54,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
: ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
: ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_multi_index<nDim>(),
make_zero_multi_index<nDim>())
{
}
......@@ -82,113 +82,104 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}([&](
auto long_vector_access_id) {
// data id w.r.t slicing-window
auto long_vector_data_begin_id = long_vector_access_id;
long_vector_data_begin_id(vector_access_dim) =
long_vector_size * long_vector_access_id[vector_access_dim];
// buffer to hold a src long-vector
SrcData p_src_long_vector[long_vector_size];
#if 1
// zero out buffer
for(index_t i = 0; i < long_vector_size; ++i)
{
p_src_long_vector[i] = 0;
}
#endif
// load data from src to the long-vector buffer
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
const index_t buffer_offset = i * src_data_per_access;
const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check whether the src data mapping is valid; only the first element of this src
// vector is checked. It is the user's responsibility to make sure all elements in the
// src vector share the same valid/invalid mapping
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>(p_src,
src_coord.GetOffset(),
src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
SrcDesc::GetElementSpace(),
p_src_long_vector,
buffer_offset,
true,
long_vector_size);
}
// SrcData to DstData conversion
DstData p_dst_long_vector[long_vector_size];
for(index_t i = 0; i < long_vector_size; ++i)
{
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
}
// store data from the long-vector buffer to dst
for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
const index_t buffer_offset = i * dst_data_per_access;
const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check whether the dst data mapping is valid; only the first element of this dst
// vector is checked. It is the user's responsibility to make sure all elements in the
// dst vector share the same valid/invalid mapping
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>(p_dst_long_vector,
buffer_offset,
true,
long_vector_size,
p_dst,
dst_coord.GetOffset(),
dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
DstDesc::GetElementSpace());
}
});
ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
[&](auto long_vector_access_id) {
// data id w.r.t slicing-window
auto long_vector_data_begin_id = long_vector_access_id;
long_vector_data_begin_id(vector_access_dim) =
long_vector_size * long_vector_access_id[vector_access_dim];
// buffer to hold a src long-vector
SrcData p_src_long_vector[long_vector_size];
// load data from src to the long-vector buffer
static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * src_data_per_access;
const index_t buffer_offset = i * src_data_per_access;
const auto src_coord =
mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check whether the src data mapping is valid; only the first element of this
// src vector is checked. It is the user's responsibility to make sure all
// elements in the src vector share the same valid/invalid mapping
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>(p_src,
src_coord.GetOffset(),
src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
SrcDesc::GetElementSpace(),
p_src_long_vector,
buffer_offset,
true,
long_vector_size);
});
// SrcData to DstData conversion
DstData p_dst_long_vector[long_vector_size];
static_for<0, long_vector_size, 1>{}([&](auto i) {
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
});
// store data from the long-vector buffer to dst
static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
auto scalar_id = make_zero_multi_index<nDim>();
scalar_id(vector_access_dim) = i * dst_data_per_access;
const index_t buffer_offset = i * dst_data_per_access;
const auto dst_coord =
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
// Check whether the dst data mapping is valid; only the first element of this
// dst vector is checked. It is the user's responsibility to make sure all
// elements in the dst vector share the same valid/invalid mapping
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>(p_dst_long_vector,
buffer_offset,
true,
long_vector_size,
p_dst,
dst_coord.GetOffset(),
dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
DstDesc::GetElementSpace());
});
});
}
template <typename T, bool PositiveDirection>
__device__ void MoveSrcSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>)
{
const auto step_sizes = to_array(step_sizes_);
const auto step_sizes = to_multi_index(step_sizes_);
static_if<PositiveDirection>{}([&](auto) {
mSrcSliceOrigin += to_array(step_sizes);
}).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
static_if<PositiveDirection>{}([&](auto) { mSrcSliceOrigin += to_multi_index(step_sizes); })
.Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
}
template <typename T, bool PositiveDirection>
__device__ void MoveDstSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>)
{
const auto step_sizes = to_array(step_sizes_);
const auto step_sizes = to_multi_index(step_sizes_);
static_if<PositiveDirection>{}([&](auto) {
mDstSliceOrigin += step_sizes;
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
static_if<PositiveDirection>{}([&](auto) { mDstSliceOrigin += step_sizes; })
.Else([&](auto) { mDstSliceOrigin -= step_sizes; });
}
private:
......
......@@ -2,20 +2,10 @@
#define CK_AMD_BUFFER_ADDRESSING_HPP
#include "float_type.hpp"
#include "amd_buffer_addressing_v2.hpp"
namespace ck {
// For 128 bit SGPRs to supply resource constant in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
template <typename T>
union BufferResourceConstant
{
int32x4_t data;
T* address[2];
int32_t range[4];
int32_t config[4];
};
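// The same 128-bit descriptor is now filled through the BufferResource union from
// amd_buffer_addressing_v2.hpp. A sketch of how its dwords are populated (field names are
// taken from the uses below; the exact layout is an assumption):
//   BufferResource<float> res;
//   res.address[0] = p_wave;                       // dwords 0-1: 64-bit wavewise base address
//   res.range[2]   = n_elements * sizeof(float);   // dword 2  : range in bytes
//   res.config[3]  = CK_BUFFER_RESOURCE_3RD_DWORD; // dword 3  : format / config bits
//   // res.data is the int32x4_t view that is passed to the buffer intrinsics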
__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc,
index_t vindex,
index_t offset,
......@@ -35,44 +25,17 @@ __llvm_amdgcn_buffer_load_f32x4(int32x4_t srsrc,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4f32");
__device__ half_t __llvm_amdgcn_buffer_load_f16(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.f16");
__device__ half2_t __llvm_amdgcn_buffer_load_f16x2(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v2f16");
__device__ half4_t __llvm_amdgcn_buffer_load_f16x4(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4f16");
__device__ ushort __llvm_amdgcn_buffer_load_bf16(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.bf16");
__device__ ushort2_t
__llvm_amdgcn_buffer_load_bf16x2(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v2bf16");
__device__ ushort4_t
__llvm_amdgcn_buffer_load_bf16x4(int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.load.v4bf16");
__device__ half_t
__llvm_amdgcn_raw_buffer_load_f16(int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
__device__ ushort
__llvm_amdgcn_raw_buffer_load_bf16(int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.bf16");
__device__ void __llvm_amdgcn_buffer_store_f32(float vdata,
int32x4_t srsrc,
......@@ -95,67 +58,43 @@ __device__ void __llvm_amdgcn_buffer_store_f32x4(float4_t vdata,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
__device__ void __llvm_amdgcn_buffer_store_f16(half_t vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.f16");
__device__ void __llvm_amdgcn_buffer_store_f16x2(half2_t vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v2f16");
__device__ void __llvm_amdgcn_buffer_store_f16x4(half4_t vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f16");
__device__ void __llvm_amdgcn_buffer_store_bf16(ushort vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.bf16");
__device__ void
__llvm_amdgcn_buffer_store_bf16x2(ushort2_t vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v2bf16");
__llvm_amdgcn_raw_buffer_store_f16(half_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
__device__ void
__llvm_amdgcn_buffer_store_bf16x4(ushort4_t vdata,
int32x4_t srsrc,
index_t vindex,
index_t offset,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4bf16");
__llvm_amdgcn_raw_buffer_store_bf16(ushort vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.bf16");
#if CK_USE_AMD_BUFFER_ATOMIC_FADD
#if CK_HIP_VERSION_FLAT >= 3010020405
// starting with ROCm 3.10, the return type becomes float
__device__ float
#else
__device__ void
#endif
__llvm_amdgcn_buffer_atomic_add_f32(float vdata,
int32x4_t srsrc,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
#endif
// buffer_load requires:
// 1) p_src_thread must be in global memory space, p_dst_thread must be vgpr
// 2) p_src_thread to be a wavewise pointer.
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ typename vector_type<T, VectorSize>::MemoryType
amd_buffer_load(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_elemenst_space);
__device__ typename vector_type<T, VectorSize>::type amd_buffer_load(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_elemenst_space);
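// Usage sketch (pointer/offset names are hypothetical):
//   float4_t v = amd_buffer_load<float, 4>(p_src_wave,         // wavewise base pointer
//                                          thread_data_offset, // per-thread offset, in elements
//                                          is_valid,           // false -> 0 is returned
//                                          buffer_n_elements); // buffer range, in elements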
// buffer_store requires:
// 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory
......@@ -185,36 +124,27 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<float> src_wave_buffer_resource;
BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if 1 // debug
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
return __llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset
: 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
#else
return src_thread_data_valid
? __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false)
: 0;
float tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? tmp : float(0);
#endif
}
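// Note on the OOB-check offset trick used above and in the specializations below: when the
// thread's data is invalid, 0x7fffffff is added to the byte offset, which pushes the access
// past range[2] of the buffer resource; the hardware's buffer out-of-bounds handling then
// returns 0 for loads (and drops stores) with no extra compare/select instructions.
// Illustrative numbers (assumed, not measured): with range[2] = 1024 bytes and
// src_thread_addr_offset = 16, a valid lane reads byte 16 while an invalid lane reads byte
// 0x7fffffff + 16, which is out of range and therefore yields 0.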
......@@ -224,29 +154,27 @@ __device__ float2_t amd_buffer_load<float, 2>(const float* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<float> src_wave_buffer_resource;
BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
return __llvm_amdgcn_buffer_load_f32x2(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset
: 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
float2_t tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? tmp : float2_t(0);
#endif
}
......@@ -256,29 +184,27 @@ __device__ float4_t amd_buffer_load<float, 4>(const float* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<float> src_wave_buffer_resource;
BufferResource<float> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<float*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(float);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
return __llvm_amdgcn_buffer_load_f32x4(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset
: 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#else
float4_t tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? tmp : float4_t(0);
#endif
}
......@@ -288,33 +214,32 @@ __device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<half_t> src_wave_buffer_resource;
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
#if !CK_WORKAROUND_SWDEV_231101
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
return __llvm_amdgcn_buffer_load_f16(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset
: 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_f16(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return __llvm_amdgcn_raw_buffer_load_f16(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
return src_thread_data_valid ? p_src_wave[src_thread_data_offset] : 0;
half_t zero(0);
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return src_thread_data_valid ? __llvm_amdgcn_raw_buffer_load_f16(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0)
: zero;
#endif
}
......@@ -324,32 +249,32 @@ __device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<half_t> src_wave_buffer_resource;
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
float dst_out_tmp =
__llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
return *reinterpret_cast<half2_t*>(&dst_out_tmp);
#else
half2_t zeros(0);
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<half2_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
......@@ -358,32 +283,32 @@ __device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<half_t> src_wave_buffer_resource;
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
float2_t dst_out_tmp =
__llvm_amdgcn_buffer_load_f32x2(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
return *reinterpret_cast<half4_t*>(&dst_out_tmp);
#else
half4_t zeros(0);
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<half4_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
......@@ -392,32 +317,32 @@ __device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<half_t> src_wave_buffer_resource;
BufferResource<half_t> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<half_t*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
float4_t dst_out_tmp =
__llvm_amdgcn_buffer_load_f32x4(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
return *reinterpret_cast<half8_t*>(&dst_out_tmp);
#else
half8_t zeros(0);
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<half8_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
......@@ -426,34 +351,32 @@ __device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<ushort> src_wave_buffer_resource;
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
#if !CK_WORKAROUND_SWDEV_231101
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
return __llvm_amdgcn_buffer_load_bf16(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset
: 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return __llvm_amdgcn_buffer_load_bf16(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return __llvm_amdgcn_raw_buffer_load_bf16(
src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0);
#else
return src_thread_data_valid ? p_src_wave[src_thread_data_offset] : 0;
ushort zero(0);
// current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and
// everything is passed to Voffset
return src_thread_data_valid ? __llvm_amdgcn_raw_buffer_load_bf16(
src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0)
: zero;
#endif
}
......@@ -463,32 +386,32 @@ __device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<ushort> src_wave_buffer_resource;
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
float dst_out_tmp =
__llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
return *reinterpret_cast<ushort2_t*>(&dst_out_tmp);
#else
ushort2_t zeros(0);
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<ushort2_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
......@@ -497,32 +420,32 @@ __device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<ushort> src_wave_buffer_resource;
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
float2_t dst_out_tmp =
__llvm_amdgcn_buffer_load_f32x2(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
return *reinterpret_cast<ushort4_t*>(&dst_out_tmp);
#else
ushort4_t zeros(0);
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<ushort4_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
......@@ -531,32 +454,32 @@ __device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_wave,
bool src_thread_data_valid,
index_t src_data_range)
{
BufferResourceConstant<ushort> src_wave_buffer_resource;
BufferResource<ushort> src_wave_buffer_resource;
// wavewise base address (64 bit)
src_wave_buffer_resource.address[0] = const_cast<ushort*>(p_src_wave);
// wavewise range (32 bit)
src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort);
// wavewise setting (32 bit)
src_wave_buffer_resource.config[3] = 0x00027000;
src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
float4_t dst_out_tmp =
__llvm_amdgcn_buffer_load_f32x4(src_wave_buffer_resource.data,
0,
src_thread_data_valid ? src_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
#endif
return *reinterpret_cast<ushort8_t*>(&dst_out_tmp);
#else
ushort8_t zeros(0);
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
return src_thread_data_valid ? *reinterpret_cast<ushort8_t*>(&dst_out_tmp) : zeros;
#endif
}
template <>
......@@ -566,26 +489,18 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<float> dst_wave_buffer_resource;
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if 1 // debug
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32(*p_src_thread,
......@@ -594,7 +509,6 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#endif
#else
if(dst_thread_data_valid)
{
......@@ -611,25 +525,18 @@ __device__ void amd_buffer_store<float, 2>(const float* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<float> dst_wave_buffer_resource;
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src_thread),
......@@ -638,6 +545,16 @@ __device__ void amd_buffer_store<float, 2>(const float* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset,
false,
false);
}
#endif
}
......@@ -648,25 +565,18 @@ __device__ void amd_buffer_store<float, 4>(const float* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<float> dst_wave_buffer_resource;
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src_thread),
......@@ -675,6 +585,16 @@ __device__ void amd_buffer_store<float, 4>(const float* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src_thread),
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset,
false,
false);
}
#endif
}
......@@ -685,40 +605,34 @@ __device__ void amd_buffer_store<half_t, 1>(const half_t* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<half_t> dst_wave_buffer_resource;
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
#if !CK_WORKAROUND_SWDEV_231101
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f16(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f16(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#endif
// the current code cannot separate Soffset from Voffset, so Soffset is hard-coded to 0
// and everything is passed through Voffset
__llvm_amdgcn_raw_buffer_store_f16(*p_src_thread,
dst_wave_buffer_resource.data,
dst_addr_shift + dst_thread_addr_offset,
0,
0);
#else
if(dst_thread_data_valid)
{
p_dst_wave[dst_thread_data_offset] = *p_src_thread;
// the current code cannot separate Soffset from Voffset, so Soffset is hard-coded to 0
// and everything is passed through Voffset
__llvm_amdgcn_raw_buffer_store_f16(
*p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0);
}
#endif
}
......@@ -730,27 +644,20 @@ __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<half_t> dst_wave_buffer_resource;
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
const float* p_src_tmp = reinterpret_cast<const float*>(p_src_thread);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
......@@ -759,6 +666,12 @@ __device__ void amd_buffer_store<half_t, 2>(const half_t* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
......@@ -769,27 +682,20 @@ __device__ void amd_buffer_store<half_t, 4>(const half_t* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<half_t> dst_wave_buffer_resource;
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src_thread);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
......@@ -798,6 +704,12 @@ __device__ void amd_buffer_store<half_t, 4>(const half_t* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x2(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
......@@ -808,27 +720,20 @@ __device__ void amd_buffer_store<half_t, 8>(const half_t* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<half_t> dst_wave_buffer_resource;
BufferResource<half_t> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
const float4_t* p_src_tmp = reinterpret_cast<const float4_t*>(p_src_thread);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
......@@ -837,6 +742,12 @@ __device__ void amd_buffer_store<half_t, 8>(const half_t* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x4(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
......@@ -847,40 +758,30 @@ __device__ void amd_buffer_store<ushort, 1>(const ushort* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<ushort> dst_wave_buffer_resource;
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
#if !CK_WORKAROUND_SWDEV_231101
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_bf16(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_bf16(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#endif
__llvm_amdgcn_raw_buffer_store_bf16(*p_src_thread,
dst_wave_buffer_resource.data,
dst_addr_shift + dst_thread_addr_offset,
0,
0);
#else
if(dst_thread_data_valid)
{
p_dst_wave[dst_thread_data_offset] = *p_src_thread;
__llvm_amdgcn_raw_buffer_store_bf16(
*p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0);
}
#endif
}
......@@ -892,27 +793,20 @@ __device__ void amd_buffer_store<ushort, 2>(const ushort* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<ushort> dst_wave_buffer_resource;
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
const float* p_src_tmp = reinterpret_cast<const float*>(p_src_thread);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
......@@ -921,6 +815,12 @@ __device__ void amd_buffer_store<ushort, 2>(const ushort* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
......@@ -931,27 +831,20 @@ __device__ void amd_buffer_store<ushort, 4>(const ushort* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<ushort> dst_wave_buffer_resource;
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src_thread);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
......@@ -960,6 +853,12 @@ __device__ void amd_buffer_store<ushort, 4>(const ushort* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x2(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
......@@ -970,27 +869,20 @@ __device__ void amd_buffer_store<ushort, 8>(const ushort* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<ushort> dst_wave_buffer_resource;
BufferResource<ushort> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
const float4_t* p_src_tmp = reinterpret_cast<const float4_t*>(p_src_thread);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_store_f32x4(*p_src_tmp,
......@@ -999,9 +891,16 @@ __device__ void amd_buffer_store<ushort, 8>(const ushort* p_src_thread,
dst_addr_shift + dst_thread_addr_offset,
false,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32x4(
*p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
}
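// Editor's note (sketch, not part of the original diff): the ushort (bf16) vector
// stores above reuse the f32 store intrinsics; this only works because the payload
// bit widths match, which the reinterpret_cast of the source pointer relies on.
static_assert(2 * sizeof(ushort) == sizeof(float) && 8 * sizeof(ushort) == sizeof(float4_t),
              "bf16 vector stores reuse f32 intrinsics; bit widths must match");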
#if CK_USE_AMD_BUFFER_ATOMIC_FADD
template <>
__device__ void amd_buffer_atomic_add<float, 1>(const float* p_src_thread,
float* p_dst_wave,
......@@ -1009,24 +908,18 @@ __device__ void amd_buffer_atomic_add<float, 1>(const float* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<float> dst_wave_buffer_resource;
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_atomic_add_f32(*p_src_thread,
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff,
false);
#else
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
__llvm_amdgcn_buffer_atomic_add_f32(*p_src_thread,
......@@ -1034,6 +927,12 @@ __device__ void amd_buffer_atomic_add<float, 1>(const float* p_src_thread,
0,
dst_addr_shift + dst_thread_addr_offset,
false);
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_atomic_add_f32(
*p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false);
}
#endif
}
......@@ -1044,28 +943,18 @@ __device__ void amd_buffer_atomic_add<float, 2>(const float* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<float> dst_wave_buffer_resource;
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
for(index_t i = 0; i < 2; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(
p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? (dst_thread_addr_offset + i * sizeof(float)) : 0xffffffff,
false);
}
#else
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
for(index_t i = 0; i < 2; ++i)
......@@ -1077,6 +966,18 @@ __device__ void amd_buffer_atomic_add<float, 2>(const float* p_src_thread,
i * sizeof(float),
false);
}
#else
if(dst_thread_data_valid)
{
for(index_t i = 0; i < 2; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset + i * sizeof(float),
false);
}
}
#endif
}
......@@ -1087,28 +988,18 @@ __device__ void amd_buffer_atomic_add<float, 4>(const float* p_src_thread,
bool dst_thread_data_valid,
index_t dst_data_range)
{
BufferResourceConstant<float> dst_wave_buffer_resource;
BufferResource<float> dst_wave_buffer_resource;
// wavewise base address (64 bit)
dst_wave_buffer_resource.address[0] = p_dst_wave;
// wavewise range (32 bit)
dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float);
// wavewise setting (32 bit)
dst_wave_buffer_resource.config[3] = 0x00027000;
dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
for(index_t i = 0; i < 4; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(
p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_thread_data_valid ? (dst_thread_addr_offset + i * sizeof(float)) : 0xffffffff,
false);
}
#else
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
for(index_t i = 0; i < 4; ++i)
......@@ -1120,8 +1011,21 @@ __device__ void amd_buffer_atomic_add<float, 4>(const float* p_src_thread,
i * sizeof(float),
false);
}
#else
if(dst_thread_data_valid)
{
for(index_t i = 0; i < 4; ++i)
{
__llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i],
dst_wave_buffer_resource.data,
0,
dst_thread_addr_offset + i * sizeof(float),
false);
}
}
#endif
}
#endif // CK_USE_AMD_BUFFER_ATOMIC_FADD
} // namespace ck
#endif
#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP
#define CK_AMD_BUFFER_ADDRESSING_V2_HPP
#include "float_type.hpp"
namespace ck {
template <typename T>
union BufferResource
{
// 128-bit buffer resource descriptor, held in four SGPRs and supplied to buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data;
T* address[2];
int32_t range[4];
int32_t config[4];
};
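// Editor's note (sketch, not part of the original source): all four views of the
// union alias the same 128 bits, which buffer instructions consume as one group of
// four SGPRs.
static_assert(sizeof(BufferResource<float>) == 16,
              "buffer resource descriptor is expected to be 128 bit");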
template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
{
BufferResource<T> wave_buffer_resource;
// wavewise base address (64 bit)
wave_buffer_resource.address[0] = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit)
wave_buffer_resource.range[2] = data_space_size * sizeof(T);
// wavewise setting (32 bit)
wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
return wave_buffer_resource.data;
}
// load
__device__ int8_t
__llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
__device__ int16_t
__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
__device__ int32_t
__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32x2_t
__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
__device__ int32x4_t
__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
__device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
__device__ float2_t
__llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
__device__ float4_t
__llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// store
__device__ void
__llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
__device__ void
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void
__llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
__device__ void
__llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
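// Editor's sketch (hypothetical helper, not part of the original source): a minimal
// round trip through make_wave_buffer_resource and one of the raw intrinsics declared
// above. The last argument of every raw intrinsic is the cache-policy word ("glc_slc");
// this file always passes 0, i.e. default caching.
__device__ inline float example_raw_buffer_load_one_float(const float* p_src_wave,
                                                          index_t src_thread_data_offset,
                                                          index_t src_element_space)
{
    const int32x4_t rsrc = make_wave_buffer_resource(p_src_wave, src_element_space);
    return __llvm_amdgcn_raw_buffer_load_fp32(
        rsrc, src_thread_data_offset * sizeof(float), 0, 0);
}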
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(float), 0);
return tmp.Vector();
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
return __llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return __llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<int32_t, 8> tmp;
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(int32_t), 0);
return tmp.Vector();
}
}
}
template <typename T, index_t N>
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
__llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
}
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer
// It is the user's responsibility to make sure both conditions hold.
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space);
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
return amd_buffer_load_impl_v2<T, N>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
using vector_t = typename vector_type<T, N>::type;
vector_t tmp =
amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_data_valid ? tmp : vector_t(0);
#endif
}
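// Editor's sketch (hypothetical kernel fragment, not part of the original source;
// it assumes vector_type<float, 1>::type is plain float): how a caller is expected
// to satisfy the contract above. p_src is the wavewise base pointer, identical for
// every lane, while the element offset and the validity predicate are per-thread.
__global__ void example_guarded_load_kernel(const float* p_src, float* p_dst, index_t n)
{
    const index_t tid   = blockIdx.x * blockDim.x + threadIdx.x;
    const bool is_valid = tid < n; // out-of-range lanes read 0 instead of faulting
    const float v       = amd_buffer_load_v2<float, 1>(p_src, tid, is_valid, n);
    if(is_valid)
    {
        p_dst[tid] = v;
    }
}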
// buffer_store requires:
// 1) p_dst_wave must point to global memory space
// 2) p_dst_wave must be a wavewise pointer
// It is the user's responsibility to make sure both conditions hold.
template <typename T, index_t N>
__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_element_space)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space);
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_data_valid)
{
amd_buffer_store_impl_v2<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
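// Editor's sketch (hypothetical kernel fragment, not part of the original source):
// the store-side counterpart. Lanes whose predicate is false are either shifted out
// of range (offset trick enabled) or skipped entirely (fallback branch), so nothing
// is written past n.
__global__ void example_guarded_store_kernel(const float* p_src, float* p_dst, index_t n)
{
    const index_t tid   = blockIdx.x * blockDim.x + threadIdx.x;
    const bool is_valid = tid < n;
    amd_buffer_store_v2<float, 1>(is_valid ? p_src[tid] : 0.f, p_dst, tid, is_valid, n);
}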
} // namespace ck
#endif
......@@ -5,21 +5,44 @@
namespace ck {
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \
v_fmac_f32 %0, %2, %3 \n \
v_fmac_f32 %1, %2, %4 \n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
asm volatile("\n \
v_mac_f32 %0, %2, %3 \n \
v_mac_f32 %1, %2, %4 \n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#endif
}
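// Editor's sketch (not part of the original source): a plain-C++ reference of the
// computation performed by the inline assembly above, useful for validation. For
// scalar floats the outer-product row degenerates to two fused multiply-adds.
__device__ inline void amd_assembly_outer_product_1x2_reference(
    float a, float b0, float b1, float& c0, float& c1)
{
    c0 += a * b0;
    c1 += a * b1;
}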
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \
v_fmac_f32 %0, %4, %5 \n \
v_fmac_f32 %1, %4, %6 \n \
v_fmac_f32 %2, %4, %7 \n \
v_fmac_f32 %3, %4, %8 \n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
v_mac_f32 %1, %4, %6 \n \
......@@ -28,9 +51,11 @@ __device__ void amd_assembly_outer_product_1x4(
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#endif
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
......@@ -38,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
v_dot2_f32_f16 %0, %2, %3, %0\n \
v_dot2_f32_f16 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1) // Dest registers
: "v"(a), // 1st Src register for 1 half2 registers
"v"(b0), // 2nd Src register
"v"(b1),
"0"(c0), // 3rd Src register
"1"(c1));
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}
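// Editor's sketch (not part of the original source; it assumes half2_t is an
// ext_vector of two _Float16 values): the per-output meaning of v_dot2_f32_f16,
// a two-wide half dot product accumulated in fp32.
__device__ inline float dot2_f32_f16_reference(half2_t a, half2_t b, float c)
{
    return c + static_cast<float>(a[0]) * static_cast<float>(b[0]) +
           static_cast<float>(a[1]) * static_cast<float>(b[1]);
}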
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
......@@ -61,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo
v_dot2_f32_f16 %0, %3, %5, %0\n \
v_dot2_f32_f16 %1, %3, %7, %1\n \
"
: "=v"(c0), "=v"(c1) // Dest registers
: "=v"(c0), "=v"(c1)
: "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
"v"(p_b1_half2[1]),
"0"(c0),
"1"(c1)); // 3rd Src Acc registers for 2 half2 registers
"1"(c1));
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a,
half2_t b0,
half2_t b1,
......@@ -89,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a,
v_dot2_f32_f16 %2, %4, %7, %2\n \
v_dot2_f32_f16 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
: "v"(a), // 1st Src register for 1 half2 registers
"v"(b0), // 2nd Src register
"v"(b1),
"v"(b2),
"v"(b3),
"0"(c0), // 3rd Src register
"1"(c1),
"2"(c2),
"3"(c3));
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a,
half4_t b0,
half4_t b1,
......@@ -129,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
v_dot2_f32_f16 %2, %5, %11, %2\n \
v_dot2_f32_f16 %3, %5, %13, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(p_a_half2[0]),
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
"v"(p_b1_half2[1]),
"v"(p_b2_half2[0]),
"v"(p_b2_half2[1]),
"v"(p_b3_half2[0]),
"v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers
"v"(p_b3_half2[1]),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3)); // 3rd Src Acc registers for 2 half2 registers
"3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
#endif
}
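// Editor's sketch (not part of the original source; it assumes int8x4_t is bit-
// compatible with four packed signed bytes): the scalar meaning of v_dot4_i32_i8
// and __builtin_amdgcn_sdot4, four signed 8-bit products accumulated into int32.
__device__ inline int32_t dot4_i32_i8_reference(const int8x4_t& a, const int8x4_t& b, int32_t c)
{
    const int8_t* pa = reinterpret_cast<const int8_t*>(&a);
    const int8_t* pb = reinterpret_cast<const int8_t*>(&b);
    for(index_t k = 0; k < 4; ++k)
    {
        c += static_cast<int32_t>(pa[k]) * static_cast<int32_t>(pb[k]);
    }
    return c;
}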
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0,
int8x4_t b1,
int8x4_t b2,
int8x4_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
#endif
}
} // namespace ck