Commit cfc80c01 authored by ltqin

Merge branch 'develop' into ck_conv_bww_fp16

parents 69ea9ad9 6d4450ef
@@ -9,6 +9,7 @@
 #include "blockwise_tensor_slice_transfer_v4r1.hpp"
 #include "blockwise_tensor_slice_transfer_v6r3.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
+#include "gridwise_gemm_pipeline_v1.hpp"
 
 namespace ck {
@@ -24,7 +25,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainKBlockLoop>
+          bool HasMainK0BlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -50,7 +51,7 @@ __global__ void
 {
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(
+    GridwiseGemm::template Run<HasMainK0BlockLoop>(
         p_a_grid,
         p_b_grid,
         p_c_grid,
@@ -109,7 +110,8 @@ template <
     index_t CShuffleMXdlPerWavePerShuffle,
     index_t CShuffleNXdlPerWavePerShuffle,
     typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
-    index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
+    index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
+    index_t NumPrefetch = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 {
     static constexpr auto I0 = Number<0>{};
@@ -242,6 +244,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;
 
+        // check NumPrefetch
+        if constexpr(NumPrefetch == 1)
+        {
+            // 1-stage prefetch always supported
+        }
+        else if constexpr(NumPrefetch == 2)
+        {
+            // 2-stage prefetch currently only support even number of K0 loop
+            // TODO: add support for odd number of K0 loop
+            if(!((K0 / K0PerBlock) % 2 == 0))
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
         // check M01, N01
         constexpr auto M1 = Number<MPerBlock>{};
         constexpr auto N1 = Number<NPerBlock>{};
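Note: the added check admits only `NumPrefetch == 1` or `NumPrefetch == 2`, and the 2-stage case additionally requires an even number of K0 loops (`K0 / K0PerBlock`). A minimal host-side sketch of the same predicate, with hypothetical names and illustrative values (not part of this diff):

```cpp
#include <cassert>

// Hypothetical stand-alone version of the NumPrefetch/K0 check added above.
// K0, K0PerBlock and NumPrefetch mirror the kernel's compile-time parameters.
constexpr bool IsSupportedPrefetch(int K0, int K0PerBlock, int NumPrefetch)
{
    const int k0_loops = K0 / K0PerBlock;

    if(NumPrefetch == 1)
        return true;              // 1-stage prefetch always supported
    if(NumPrefetch == 2)
        return k0_loops % 2 == 0; // 2-stage prefetch needs an even K0 loop count
    return false;                 // anything else is rejected
}

int main()
{
    static_assert(IsSupportedPrefetch(32, 4, 1), "1-stage always ok");
    static_assert(IsSupportedPrefetch(32, 4, 2), "32/4 = 8 loops, even");
    static_assert(!IsSupportedPrefetch(36, 4, 2), "36/4 = 9 loops, odd");
    static_assert(!IsSupportedPrefetch(32, 4, 3), "only 1 or 2 stages supported");
    return 0;
}
```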
@@ -267,9 +288,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         return grid_size;
     }
 
+    // TODO move this function into GEMM-pipeline class
     __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
     {
-        const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
+        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
 
         return has_main_k0_block_loop;
     }
@@ -303,7 +325,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
-    MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
+    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -324,17 +346,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
 
-        const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
+        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
             make_single_stage_tensor_adaptor(
                 make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
                 make_tuple(Sequence<0, 1, 2, 3>{}),
                 make_tuple(Sequence<0>{}));
 
-        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+        const auto cblockid_to_m0_n0_block_cluster_adaptor =
             chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
-                                  c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);
+                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
 
-        return c_blockid_to_m0_n0_block_cluster_adaptor;
+        return cblockid_to_m0_n0_block_cluster_adaptor;
     }
 
     using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
         remove_cvref_t<decltype(
@@ -351,9 +373,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
             C1GridDesc_M_N{}))>;
 
-    using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
+    using DefaultBlock2CTileMap =
+        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
 
-    template <bool HasMainKBlockLoop>
+    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
         Run(const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
@@ -509,51 +532,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
 
-        // preload data into LDS
-        {
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-            b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-        }
-
-        // Initialize C
-        c_thread_buf.Clear();
-
-        // main body
-        if constexpr(HasMainKBlockLoop)
-        {
-            index_t k0_block_data_begin = 0;
-
-            do
-            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
-
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-
-                block_sync_lds();
-
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
-
-                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-                block_sync_lds();
-
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
-
-                k0_block_data_begin += K0PerBlock;
-            } while(k0_block_data_begin < (K0 - K0PerBlock));
-        }
-
-        // tail
-        {
-            block_sync_lds();
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-        }
+        // gridwise GEMM pipeline
+        const auto gridwise_gemm_pipeline =
+            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
+                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
+                                    remove_cvref_t<decltype(a_blockwise_copy)>,
+                                    remove_cvref_t<decltype(a_grid_buf)>,
+                                    remove_cvref_t<decltype(a_block_buf)>,
+                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
+                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
+                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
+                                    remove_cvref_t<decltype(b_blockwise_copy)>,
+                                    remove_cvref_t<decltype(b_grid_buf)>,
+                                    remove_cvref_t<decltype(b_block_buf)>,
+                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
+                                    remove_cvref_t<decltype(blockwise_gemm)>,
+                                    remove_cvref_t<decltype(c_thread_buf)>,
+                                    NumPrefetch,
+                                    HasMainK0BlockLoop>{};
+
+        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
+                                   a_block_desc_k0_m_k1,
+                                   a_blockwise_copy,
+                                   a_grid_buf,
+                                   a_block_buf,
+                                   a_block_slice_copy_step,
+                                   b_grid_desc_k0_n_k1,
+                                   b_block_desc_k0_n_k1,
+                                   b_blockwise_copy,
+                                   b_grid_buf,
+                                   b_block_buf,
+                                   b_block_slice_copy_step,
+                                   blockwise_gemm,
+                                   c_thread_buf,
+                                   K0BlockMainLoop);
 
         // shuffle C and write out
         {
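The removed preload / main-body / tail sequence is what `GridwiseGemmPipeline_v1` now encapsulates behind a single `Run()` call, parameterized by `NumPrefetch`. As a rough idea of why the 2-stage variant wants an even K0 loop count, here is a simplified host-side sketch of a ping-pong pipeline; the names are hypothetical and this is not the actual `GridwiseGemmPipeline_v1` implementation:

```cpp
#include <cassert>
#include <vector>

// Simplified, host-side sketch of a 2-stage (ping-pong) prefetch pipeline:
// while tile i is consumed from one scratch buffer, tile i+1 is fetched into
// the other. Tiles are processed strictly in pairs, which is why the check
// above only accepts an even number of K0 loops for NumPrefetch == 2.
int RunTwoStagePipeline(const std::vector<int>& k0_tiles)
{
    const std::size_t num_loop = k0_tiles.size();
    assert(num_loop >= 2 && num_loop % 2 == 0);

    int scratch[2] = {};
    scratch[0] = k0_tiles[0]; // prologue: prefetch tile 0 into buffer 0

    int acc = 0; // stands in for c_thread_buf
    for(std::size_t i = 0; i + 1 < num_loop; i += 2)
    {
        scratch[1] = k0_tiles[i + 1]; // prefetch the next tile into buffer 1
        acc += scratch[0];            // "GEMM" on buffer 0

        if(i + 2 < num_loop)
            scratch[0] = k0_tiles[i + 2]; // prefetch the tile after next into buffer 0
        acc += scratch[1];                // "GEMM" on buffer 1 (tail when i + 2 == num_loop)
    }
    return acc;
}

int main() { return RunTwoStagePipeline({1, 2, 3, 4}) == 10 ? 0 : 1; }
```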
@@ -64,9 +64,10 @@ template <typename SliceLengths,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                            // RunRead(), will be fused with MoveSrcSliceWindow to
                                            // save addr computation
-          bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
+          bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
                                            // RunWrite(), will be fused with MoveDstSliceWindow to
                                            // save addr computation
+          index_t NumThreadScratch = 1>
 struct ThreadwiseTensorSliceTransfer_v3r1
 {
     static constexpr index_t nDim = SliceLengths::Size();
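`NumThreadScratch` gives the transfer object several independent register scratch buffers; the `RunRead`/`RunWrite` overloads below select one via a `thread_scratch_id`, so a caller can, for example, issue `RunRead(..., Number<1>{})` for the next tile while the data in scratch 0 is still waiting for its `RunWrite`. A self-contained toy analogue of that mechanism (not the CK class itself):

```cpp
#include <array>
#include <cstdio>

// Toy analogue of the NumThreadScratch mechanism: one transfer object owns
// N scratch buffers, and RunRead/RunWrite pick one by index, so "read into
// scratch 1" can overlap with "write out of scratch 0".
template <int NumThreadScratch = 1>
struct ToyTransfer
{
    void RunRead(int value, int scratch_id) { scratch_.at(scratch_id) = value; }
    int RunWrite(int scratch_id) const { return scratch_.at(scratch_id); }

    std::array<int, NumThreadScratch> scratch_{};
};

int main()
{
    ToyTransfer<2> transfer; // double-buffered, like NumThreadScratch = 2

    transfer.RunRead(/*tile 0*/ 10, /*scratch*/ 0);
    transfer.RunRead(/*tile 1*/ 11, /*scratch*/ 1); // issued before tile 0 is drained

    std::printf("%d %d\n", transfer.RunWrite(0), transfer.RunWrite(1)); // prints: 10 11
    return 0;
}
```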
@@ -78,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
 
+    static constexpr auto I0 = Number<0>{};
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
         const SrcDesc& src_desc,
         const Index& src_slice_origin,
@@ -102,9 +105,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }
 
-    template <typename SrcBuffer, typename SrcStepHacks>
-    __device__ void
-    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
+    template <typename SrcBuffer, index_t ThreadScratchId = 0>
+    __device__ void RunRead(const SrcDesc& src_desc,
+                            const SrcBuffer& src_buf,
+                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                           SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -114,9 +118,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
             "wrong! SrcBuffer and SrcData data type are inconsistent");
 
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
...@@ -138,8 +139,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -138,8 +139,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
}); });
return make_tensor_coordinate_step( return make_tensor_coordinate_step(src_desc, forward_step_idx);
src_desc, forward_step_idx, src_step_hacks[I0][i]);
}, },
Number<nDim>{}); Number<nDim>{});
...@@ -152,8 +152,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -152,8 +152,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
}); });
return make_tensor_coordinate_step( return make_tensor_coordinate_step(src_desc, backward_step_idx);
src_desc, backward_step_idx, src_step_hacks[I1][i]);
}, },
Number<nDim>{}); Number<nDim>{});
@@ -215,8 +214,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             });
 
             // copy data from src_vector_container into src_thread_scratch_
-            src_thread_scratch_.template SetAsType<src_vector_t>(
-                src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
+            src_thread_scratch_tuple_(thread_scratch_id)
+                .template SetAsType<src_vector_t>(
+                    src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
 
             constexpr auto move_on_dim = [&]() constexpr
             {
@@ -263,12 +263,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             }
         }
     }
 
-    __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch()
+    template <index_t ThreadScratchId>
+    __device__ void
+    TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
     {
 #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
         static_ford<SliceLengths>{}([&](auto idx) {
             // convert from SrcData to DstData here
-            dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
+            dst_thread_scratch_(idx) =
+                type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
         });
 #else
         // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
@@ -318,7 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         const auto src_vector_refs = generate_tie(
             [&](auto i) -> const src_vector_t& {
                 // i increment corresponds to movement in DstVectorDim
-                return src_thread_scratch_.GetVectorTypeReference(
+                return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
                     data_idx_seq + i * dst_scalar_step_in_vector);
             },
             Number<num_src_vector>{});
@@ -342,19 +345,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         {
             static_ford<SliceLengths>{}([&](auto idx) {
                 // convert from SrcData to DstData here
-                dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
+                dst_thread_scratch_(idx) =
+                    type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
             });
         }
 #endif
     }
 
-    template <typename DstBuffer, typename DstStepHacks>
-    __device__ void
-    RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
+    template <typename DstBuffer, index_t ThreadScratchId = 0>
+    __device__ void RunWrite(const DstDesc& dst_desc,
+                             DstBuffer& dst_buf,
+                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         // if there is transpose, it's done here
         // TODO move this elsewhere
-        TransferDataFromSrcThreadScratchToDstThreadScratch();
+        TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
 
         static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                           DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -364,9 +369,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
             "wrong! SrcBuffer or DstBuffer data type is wrong");
 
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
         // src scalar per access on each dim
         // TODO: don't use this
         constexpr auto dst_scalar_per_access = generate_sequence(
@@ -388,8 +390,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                     forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
+                return make_tensor_coordinate_step(dst_desc, forward_step_idx);
             },
             Number<nDim>{});
@@ -402,8 +403,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                     backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
+                return make_tensor_coordinate_step(dst_desc, backward_step_idx);
             },
             Number<nDim>{});
@@ -515,39 +515,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         }
     }
 
-    template <typename SrcBuffer>
-    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
-    {
-        constexpr index_t ntransform_src = remove_cvref_t<SrcDesc>::GetNumOfTransform();
-
-        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
-
-        constexpr auto src_step_hacks =
-            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-        RunRead(src_desc, src_buf, src_step_hacks);
-    }
-
-    template <typename DstBuffer>
-    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
-    {
-        // TODO: why need remove_cvref_t ?
-        constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();
-
-        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
-
-        constexpr auto dst_step_hacks =
-            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-        RunWrite(dst_desc, dst_buf, dst_step_hacks);
-    }
-
     __device__ static constexpr auto GetSrcCoordinateResetStep()
     {
-        constexpr auto I0 = Number<0>{};
-
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
@@ -606,8 +575,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetDstCoordinateResetStep()
     {
-        constexpr auto I0 = Number<0>{};
-
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto dst_scalar_per_access = generate_sequence(
@@ -679,25 +646,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
     }
 
-    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
-    template <typename SrcMoveSliceWindowStepHack>
-    __device__ void
-    MoveSrcSliceWindow(const SrcDesc& src_desc,
-                       const Index& src_slice_origin_step_idx,
-                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
-    {
-        // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(
-            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
-
-        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
-    }
-
     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                        const Index& dst_slice_origin_step_idx)
@@ -815,19 +763,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
     static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
 
-    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
-                                    SrcData,
-                                    SrcScalarPerVector,
-                                    decltype(src_thread_scratch_desc_),
-                                    true>
-        src_thread_scratch_;
-
-    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
-                                    DstData,
-                                    DstScalarPerVector,
-                                    decltype(dst_thread_scratch_desc_),
-                                    true>
-        dst_thread_scratch_;
+    using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
+                                                             SrcData,
+                                                             SrcScalarPerVector,
+                                                             decltype(src_thread_scratch_desc_),
+                                                             true>;
+
+    using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
+                                                             DstData,
+                                                             DstScalarPerVector,
+                                                             decltype(dst_thread_scratch_desc_),
+                                                             true>;
+
+    StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
+
+    DstThreadScratch dst_thread_scratch_;
 
     SrcCoord src_coord_;
     DstCoord dst_coord_;
@@ -920,10 +920,10 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t N>
 __device__ typename vector_type_maker<T, N>::type::type
-amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
-                                                   index_t src_thread_element_offset,
-                                                   bool src_thread_element_valid,
-                                                   index_t src_element_space_size)
+amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
+                                            index_t src_thread_element_offset,
+                                            bool src_thread_element_valid,
+                                            index_t src_element_space_size)
 {
     const int32x4_t src_wave_buffer_resource =
         make_wave_buffer_resource(p_src_wave, src_element_space_size);
@@ -49,7 +49,7 @@ template <typename X, typename... Xs>
 __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
 {
     using data_type = remove_cvref_t<X>;
-    return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
+    return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
 }
 
 // make empty array
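The outer braces around the initializer are dropped here. Assuming `ck::Array` is an aggregate wrapping a plain C array (an assumption, not shown in this diff), both spellings are legal because of brace elision; a standalone illustration:

```cpp
#include <cstdio>

// Standalone illustration of brace elision in aggregate initialization.
// ToyArray plays the role of an aggregate like ck::Array (an assumption
// about its layout, not taken from this commit).
template <typename T, int N>
struct ToyArray
{
    T data[N];
};

int main()
{
    ToyArray<int, 3> a{{1, 2, 3}}; // explicit braces for the inner array member
    ToyArray<int, 3> b{1, 2, 3};   // same result via brace elision; simpler to forward into

    std::printf("%d %d\n", a.data[0], b.data[2]); // prints: 1 3
    return 0;
}
```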
@@ -35,6 +35,7 @@
 #include "transpose_vectors.hpp"
 #include "inner_product.hpp"
 #include "element_wise_operation.hpp"
+#include "debug.hpp"
 
 // TODO: remove this
 #if CK_USE_AMD_INLINE_ASM
@@ -56,7 +56,7 @@ struct DynamicBuffer
         static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                       "wrong! X need to be multiple T");
 
-#if CK_USE_AMD_BUFFER_ADDRESSING
+#if CK_USE_AMD_BUFFER_LOAD
         bool constexpr use_amd_buffer_addressing = true;
 #else
         bool constexpr use_amd_buffer_addressing = false;
...@@ -68,8 +68,7 @@ struct DynamicBuffer ...@@ -68,8 +68,7 @@ struct DynamicBuffer
if constexpr(InvalidElementUseNumericalZeroValue) if constexpr(InvalidElementUseNumericalZeroValue)
{ {
return amd_buffer_load_invalid_element_return_return_zero<remove_cvref_t<T>, return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>, t_per_x>(
t_per_x>(
p_data_, i, is_valid_element, element_space_size_); p_data_, i, is_valid_element, element_space_size_);
} }
else else
@@ -125,7 +124,7 @@ struct DynamicBuffer
         if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
         {
-#if CK_USE_AMD_BUFFER_ADDRESSING
+#if CK_USE_AMD_BUFFER_STORE
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
             amd_buffer_store<remove_cvref_t<T>, t_per_x>(
@@ -291,7 +290,7 @@ struct DynamicBuffer
         static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");
 
-#if CK_USE_AMD_BUFFER_ADDRESSING
+#if CK_USE_AMD_BUFFER_ATOMIC_ADD
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
         amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
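The single `CK_USE_AMD_BUFFER_ADDRESSING` switch is replaced by three per-operation macros, `CK_USE_AMD_BUFFER_LOAD`, `CK_USE_AMD_BUFFER_STORE`, and `CK_USE_AMD_BUFFER_ATOMIC_ADD`, so buffer loads, stores, and atomic adds can be toggled independently. A sketch of how such flags might be declared in a config header (the default values here are assumptions, not taken from this commit):

```cpp
// Illustrative configuration sketch: each buffer-addressing path can now be
// toggled on its own instead of through one global switch. The default values
// shown here are assumptions, not the values set by this commit.
#ifndef CK_USE_AMD_BUFFER_LOAD
#define CK_USE_AMD_BUFFER_LOAD 1
#endif

#ifndef CK_USE_AMD_BUFFER_STORE
#define CK_USE_AMD_BUFFER_STORE 1
#endif

#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 1
#endif

// DynamicBuffer then selects the path per operation, e.g. for loads:
//   #if CK_USE_AMD_BUFFER_LOAD
//       bool constexpr use_amd_buffer_addressing = true;
//   #else
//       bool constexpr use_amd_buffer_addressing = false;
//   #endif
```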
@@ -13,5 +13,38 @@ struct integral_constant
     __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
 };
 
+template <typename TX, TX X, typename TY, TY Y>
+__host__ __device__ constexpr auto operator+(integral_constant<TX, X>, integral_constant<TY, Y>)
+{
+    return integral_constant<decltype(X + Y), X + Y>{};
+}
+
+template <typename TX, TX X, typename TY, TY Y>
+__host__ __device__ constexpr auto operator-(integral_constant<TX, X>, integral_constant<TY, Y>)
+{
+    static_assert(Y <= X, "wrong!");
+    return integral_constant<decltype(X - Y), X - Y>{};
+}
+
+template <typename TX, TX X, typename TY, TY Y>
+__host__ __device__ constexpr auto operator*(integral_constant<TX, X>, integral_constant<TY, Y>)
+{
+    return integral_constant<decltype(X * Y), X * Y>{};
+}
+
+template <typename TX, TX X, typename TY, TY Y>
+__host__ __device__ constexpr auto operator/(integral_constant<TX, X>, integral_constant<TY, Y>)
+{
+    static_assert(Y > 0, "wrong!");
+    return integral_constant<decltype(X / Y), X / Y>{};
+}
+
+template <typename TX, TX X, typename TY, TY Y>
+__host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_constant<TY, Y>)
+{
+    static_assert(Y > 0, "wrong!");
+    return integral_constant<decltype(X % Y), X % Y>{};
+}
+
 } // namespace ck
 #endif
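These operators let `Number<>` arithmetic stay in the type domain, so expressions such as `Number<6>{} + Number<2>{}` remain compile-time constants with the result encoded in the type. A self-contained demonstration of the same pattern (a trimmed copy for illustration, not the CK header):

```cpp
// Self-contained demonstration of the compile-time arithmetic the new
// operators enable (trimmed copy of the pattern above, for illustration).
template <typename T, T v>
struct integral_constant
{
    static constexpr T value = v;
};

template <typename TX, TX X, typename TY, TY Y>
constexpr auto operator+(integral_constant<TX, X>, integral_constant<TY, Y>)
{
    return integral_constant<decltype(X + Y), X + Y>{};
}

template <typename TX, TX X, typename TY, TY Y>
constexpr auto operator/(integral_constant<TX, X>, integral_constant<TY, Y>)
{
    static_assert(Y > 0, "wrong!");
    return integral_constant<decltype(X / Y), X / Y>{};
}

int main()
{
    // With ck::Number<N> being an integral_constant<index_t, N>, this is what
    // expressions like Number<6>{} + Number<2>{} resolve to.
    constexpr auto sum      = integral_constant<int, 6>{} + integral_constant<int, 2>{};
    constexpr auto quotient = integral_constant<int, 6>{} / integral_constant<int, 2>{};

    static_assert(sum.value == 8, "6 + 2 == 8 at compile time");
    static_assert(quotient.value == 3, "6 / 2 == 3 at compile time");

    return 0;
}
```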