"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "c20aabc3b1af609e633964ea0e50b790347d6b06"
Commit 510b3a21 authored by Chao Liu

move AddressSpace info from copy operator into DynamicBuffer and StaticBuffer

parent aac345ab
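
In short: the address space moves from being a template parameter of every copy/transfer operator to being a property of the buffer types themselves, which the operators then query via `GetAddressSpace()`. A minimal standalone sketch of the pattern follows (plain C++, not CK code; `copy_element` and its buffers are illustrative only):

```cpp
// Standalone sketch of the pattern this commit adopts (illustrative, not CK
// code): the buffer type carries its address space as a template parameter,
// and generic copy code dispatches on it at compile time instead of taking
// SrcAddressSpace/DstAddressSpace template parameters of its own.
#include <cstdio>

enum class AddressSpace { Generic, Global, Lds, Sgpr, Vgpr };

template <AddressSpace AS, typename T>
struct DynamicBuffer
{
    T* p_data_;
    static constexpr AddressSpace GetAddressSpace() { return AS; }
    const T& operator[](int i) const { return p_data_[i]; }
    T& operator()(int i) { return p_data_[i]; }
};

// A defaulted leading parameter is fine here because T is deduced at the call.
template <AddressSpace AS = AddressSpace::Generic, typename T>
constexpr auto make_dynamic_buffer(T* p)
{
    return DynamicBuffer<AS, T>{p};
}

// The copy operator no longer needs address-space template parameters; it
// inspects the buffers it is handed.
template <typename SrcBuffer, typename DstBuffer>
void copy_element(const SrcBuffer& src, DstBuffer& dst, int i)
{
    if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Global &&
                 DstBuffer::GetAddressSpace() == AddressSpace::Lds)
    {
        dst(i) = src[i]; // CK would emit a global load + ds_write on this path
    }
    else
    {
        dst(i) = src[i]; // generic fallback
    }
}

int main()
{
    float global_mem[4] = {1, 2, 3, 4};
    float lds_mem[4]    = {};

    auto global_buf = make_dynamic_buffer<AddressSpace::Global>(global_mem);
    auto lds_buf    = make_dynamic_buffer<AddressSpace::Lds>(lds_mem);

    copy_element(global_buf, lds_buf, 2);
    std::printf("%f\n", lds_mem[2]); // prints 3.000000
}
```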
@@ -29,8 +29,6 @@ template <index_t BlockSize,
           index_t DstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           index_t ThreadTransferSrcResetCoordinateAfterRun,
@@ -153,8 +151,6 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
                   DstScalarPerVector,
                   SrcScalarStrideInVector,
                   DstScalarStrideInVector,
-                  SrcAddressSpace,
-                  DstAddressSpace,
                   ThreadTransferSrcResetCoordinateAfterRun,
                   ThreadTransferDstResetCoordinateAfterRun>;
...
@@ -115,8 +115,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                         const BBlockBuffer& b_block_buf,
                         CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -176,8 +178,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                 Sequence<0, 1, 2>,
                 2,
                 AThreadCopyScalarPerVector_M1,
-                AddressSpace::Generic,
-                AddressSpace::Vgpr,
                 1>;

     using BThreadCopy =
@@ -189,8 +189,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                 Sequence<0, 1, 2>,
                 2,
                 BThreadCopyScalarPerVector_N1,
-                AddressSpace::Generic,
-                AddressSpace::Vgpr,
                 1>;

     CIndex c_thread_origin_data_idx_;
@@ -211,6 +209,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
 // 3. C:
 //   1. CThreadDesc is known at compile-time
 //   2. CThreadBuffer is StaticBuffer
+// Also assume:
+//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
@@ -312,8 +312,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                         const BBlockBuffer& b_block_buf,
                         CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -481,8 +483,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                 Sequence<0, 1, 2>,
                 2,
                 AThreadCopyScalarPerVector_M1,
-                AddressSpace::Generic,
-                AddressSpace::Vgpr,
                 1>;

     using BThreadCopy =
@@ -494,8 +494,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                 Sequence<0, 1, 2>,
                 2,
                 BThreadCopyScalarPerVector_N1,
-                AddressSpace::Generic,
-                AddressSpace::Vgpr,
                 1>;

     CIndex c_thread_origin_data_idx_;
...
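
The "2x2 pipelined read and fma (ABBA optimization)" comment added above is only hinted at in this diff. One plausible reading, shown as a runnable toy (not CK code), is that the four M0 x N0 sub-tile FMAs are visited in an A-B-B-A order so that consecutive steps share one operand and each newly read fragment is reused immediately:

```cpp
// Hypothetical schematic of the A-B-B-A sub-tile order (illustrative only).
// With M0 = N0 = 2 there are two A fragments, two B fragments, and four
// sub-tile FMAs; on a GPU the palindromic order lets fragment reads overlap
// the math, since only one new fragment has to be in flight at a time.
#include <cstdio>

int main()
{
    float a[2] = {1, 2};   // two A fragments (stand-ins for VGPR reads)
    float b[2] = {3, 4};   // two B fragments
    float c[2][2] = {};

    c[0][0] += a[0] * b[0]; // 'A'
    c[0][1] += a[0] * b[1]; // 'B' -- reuses a[0], only b[1] is new
    c[1][1] += a[1] * b[1]; // 'B' -- reuses b[1], only a[1] is new
    c[1][0] += a[1] * b[0]; // 'A' -- reuses a[1]

    std::printf("%f %f %f %f\n", c[0][0], c[0][1], c[1][0], c[1][1]);
}
```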
@@ -49,8 +49,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
             Sequence<0, 1>,
             1,
             ThreadGemmADataPerRead_K,
-            AddressSpace::Generic,
-            AddressSpace::Vgpr,
             1>;

     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
@@ -140,7 +138,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

         // thread A buffer for GEMM
-        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
                                                                     FloatB,
...
@@ -35,8 +35,7 @@ __global__ void
         const CGlobalDesc c_m0_m1_n0_n1_global_desc,
         const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm::Run(
-        p_a_global,
+    GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
                       a_k_m_global_desc,
@@ -85,8 +84,7 @@ __global__ void
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

-    GridwiseGemm::Run(
-        p_a_global,
+    GridwiseGemm::Run(p_a_global,
                       p_b_global,
                       p_c_global,
                       a_k_m_global_desc,
@@ -171,8 +169,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ static void Run(
-        const FloatAB* __restrict__ p_a_global,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,
@@ -188,9 +185,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        const auto a_global_buf = make_dynamic_buffer(p_a_global);
-        const auto b_global_buf = make_dynamic_buffer(p_b_global);
-        auto c_global_buf       = make_dynamic_buffer(p_c_global);
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_a_global);
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_b_global);
+        auto c_global_buf       = make_dynamic_buffer<AddressSpace::Global>(p_c_global);

         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
@@ -241,8 +238,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_M,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
@@ -270,8 +265,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_N,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             BThreadTransferSrcResetCoordinateAfterRun,
@@ -346,8 +339,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf =
-            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),
@@ -368,11 +361,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_double);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_double);
-
-        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_double + a_block_space_size);
-        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_double + b_block_space_size);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double);
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double);
+
+        auto a_block_odd_buf =
+            make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_space_size);
+        auto b_block_odd_buf =
+            make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_space_size);

         // LDS double buffer: preload data into LDS
         {
@@ -497,8 +492,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>{
@@ -517,8 +510,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ static void Run(
-        const FloatAB* __restrict__ p_a_global,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,
@@ -532,8 +524,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         __shared__ FloatAB p_shared_block[shared_block_size];

-        Run(
-            p_a_global,
+        Run(p_a_global,
             p_b_global,
             p_c_global,
             a_k_m_global_desc,
...
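
The even/odd LDS buffers above implement classic double buffering: while the GEMM consumes one copy of the A/B tiles, the next tiles are transferred into the other copy. A minimal host-side sketch of the same ping-pong structure follows (`load_tile`/`compute_tile` are hypothetical stand-ins for the block transfers and the blockwise GEMM; on the GPU the prefetch of tile t+1 overlaps the compute on tile t, with block-level synchronization between the stages):

```cpp
// Ping-pong (double) buffering sketch, illustrative only. Assumes num_tiles >= 1.
#include <cstddef>
#include <cstdio>

template <typename Load, typename Compute>
void run_double_buffered(std::size_t num_tiles, float* buf_even, float* buf_odd,
                         Load load_tile, Compute compute_tile)
{
    load_tile(0, buf_even); // preload the first tile into the even buffer

    for(std::size_t t = 0; t + 1 < num_tiles; ++t)
    {
        float* cur  = (t % 2 == 0) ? buf_even : buf_odd;
        float* next = (t % 2 == 0) ? buf_odd : buf_even;

        load_tile(t + 1, next); // fill the other buffer...
        compute_tile(cur);      // ...while the current one is consumed
    }

    // tail: the last tile sits in the even buffer iff num_tiles is odd
    compute_tile((num_tiles % 2 == 1) ? buf_even : buf_odd);
}

int main()
{
    float even = 0, odd = 0, acc = 0;
    run_double_buffered(
        5, &even, &odd,
        [](std::size_t t, float* buf) { *buf = static_cast<float>(t); },
        [&](float* buf) { acc += *buf; });
    std::printf("%f\n", acc); // 0+1+2+3+4 = 10
}
```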
@@ -84,9 +84,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        const auto a_global_buf = make_dynamic_buffer(p_a_global);
-        const auto b_global_buf = make_dynamic_buffer(p_b_global);
-        auto c_global_buf       = make_dynamic_buffer(p_c_global);
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_a_global);
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_b_global);
+        auto c_global_buf       = make_dynamic_buffer<AddressSpace::Global>(p_c_global);

         constexpr auto E = EPerBlock * 3 * 3;
@@ -196,8 +196,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_K,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
@@ -220,19 +218,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BBlockTransferSrcAccessOrder,
             BBlockTransferSrcVectorDim,
             BBlockTransferSrcScalarPerVector,
-            AddressSpace::Global,
-            AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
             true>(b_e_n_ho_wo_global_desc,
                   make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

-        FloatAB* p_a_block = p_shared_block;
-        auto a_block_buf   = make_dynamic_buffer(p_a_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block);

         // register allocation for output
-        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            c_thread_buf;

         // initialize output thread tensor
         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
@@ -254,8 +249,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BGlobalMoveSliceWindowIteratorHacks{};

         // double register buffer for b
-        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
-            b_thread_odd_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            b_thread_even_buf, b_thread_odd_buf;

         // LDS double buffer: preload data
         {
@@ -362,8 +357,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>(
...
@@ -54,8 +54,6 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp,
           index_t DstScalarStrideInVector,
           bool DstResetCoordinateAfterRun,
@@ -211,8 +209,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                 dst_desc, dst_slice_origin_coord_);

-            if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
-                         DstAddressSpace == AddressSpace::Global)
+            if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Vgpr &&
+                         DstBuffer::GetAddressSpace() == AddressSpace::Global)
             {
 #if CK_USE_AMD_BUFFER_ADDRESSING
                 amd_buffer_store_v2<DstData, DstScalarPerVector>(
@@ -403,8 +401,6 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t SrcVectorDim,
           index_t SrcScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           bool SrcResetCoordinateAfterRun,
           typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
@@ -541,8 +537,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             }();

             // copy data
-            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
-
             typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;

             using src_vector_t =
@@ -551,7 +545,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                 src_desc, src_slice_origin_coord_);

-            if constexpr(SrcAddressSpace == AddressSpace::Global)
+            if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Global)
             {
 #if CK_USE_AMD_BUFFER_ADDRESSING
                 src_vector.template AsType<src_vector_t>()(Number<0>{}) =
@@ -748,8 +742,6 @@ template <typename SliceLengths,
           index_t DstScalarPerVector,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                            // RunRead(), will be fused with MoveSrcSliceWindow to
                                            // save addr computation
@@ -774,13 +766,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
           dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
     {
-        static_assert(SrcAddressSpace == AddressSpace::Global or
-                          SrcAddressSpace == AddressSpace::Lds,
-                      "wrong!");
-        static_assert(DstAddressSpace == AddressSpace::Global or
-                          DstAddressSpace == AddressSpace::Lds,
-                      "wrong!");
-
         // TODO: fix this
         static_assert(is_same<SrcData, DstData>::value,
                       "wrong! current implementation assume SrcData and DstData are same type");
@@ -801,6 +786,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                            const SrcBuffer& src_buf,
                            const SrcIteratorHacks& src_iterator_hacks)
     {
+        static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
+                          SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
+                      "wrong!");
+
         static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
                               remove_cv_t<remove_reference_t<SrcData>>>::value,
                       "wrong! SrcBuffer and SrcData data type are inconsistent");
@@ -897,7 +886,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                 src_desc, src_slice_origin_coord_);

-            if constexpr(SrcAddressSpace == AddressSpace::Global)
+            if constexpr(SrcBuffer::GetAddressSpace() == AddressSpace::Global)
             {
 #if CK_USE_AMD_BUFFER_ADDRESSING
                 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
@@ -983,6 +972,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                             DstBuffer& dst_buf,
                             const DstIteratorHacks& dst_iterator_hacks)
     {
+        static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
+                          DstBuffer::GetAddressSpace() == AddressSpace::Lds,
+                      "wrong!");
+
         static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
                               remove_cv_t<remove_reference_t<DstData>>>::value,
                       "wrong! SrcBuffer or DstBuffer data type is wrong");
@@ -1078,7 +1071,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
             // copy data
             // hardcoding for ds_write
             // TODO refactor transfer_data() to encapsulate this
-            static_assert(DstAddressSpace == AddressSpace::Lds &&
+            static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Lds &&
                               DstInMemOp == InMemoryDataOperation::Set,
                           "wrong! hardcoded for ds_write");
@@ -1356,7 +1349,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
     static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

-    StaticBuffer<SrcData, buffer_size_> buffer_;
+    StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;

     SrcCoord src_slice_origin_coord_;
     DstCoord dst_slice_origin_coord_;
@@ -1384,8 +1377,6 @@ template <
           typename DimAccessOrder,
           index_t SrcVectorDim,
           index_t SrcScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
...
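
A side effect visible in the v3 hunks above: the address-space checks can no longer live in the constructor (the class no longer has address-space template parameters), so they move into RunRead/RunWrite, where SrcBuffer and DstBuffer are template parameters of the member function itself. A minimal sketch of the idiom, with illustrative names rather than CK's exact signatures:

```cpp
// The assertion is evaluated only when RunRead is instantiated with a
// concrete buffer type, which is exactly when the address space is known.
enum class AddressSpace { Generic, Global, Lds, Sgpr, Vgpr };

struct Transfer
{
    template <typename SrcBuffer>
    void RunRead(const SrcBuffer& /*src_buf*/)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global ||
                          SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
                      "wrong! RunRead expects a Global or Lds source buffer");
        // ... issue the actual loads ...
    }
};
```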
@@ -8,7 +8,8 @@
 #include "container_element_picker.hpp"
 #include "data_type.hpp"
 #include "float_type.hpp"
-#include "buffer.hpp"
+#include "static_buffer.hpp"
+#include "dynamic_buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
...
@@ -159,6 +159,7 @@ enum AddressSpace
     Generic,
     Global,
     Lds,
+    Sgpr,
     Vgpr
 };
...
-#ifndef CK_BUFFER_HPP
-#define CK_BUFFER_HPP
+#ifndef CK_DYNAMIC_BUFFER_HPP
+#define CK_DYNAMIC_BUFFER_HPP

-#include "statically_indexed_array.hpp"

 namespace ck {

-template <typename T, index_t N>
-struct StaticBuffer : public StaticallyIndexedArray<T, N>
-{
-    using type = T;
-    using base = StaticallyIndexedArray<T, N>;
-
-    __host__ __device__ constexpr StaticBuffer() : base{} {}
-
-    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
-    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
-};
-
-template <typename T, index_t N>
-__host__ __device__ constexpr auto make_static_buffer(Number<N>)
-{
-    return StaticBuffer<T, N>{};
-}
-
-template <typename T>
+template <AddressSpace BufferAddressSpace, typename T>
 struct DynamicBuffer
 {
     using type = T;
@@ -33,6 +12,11 @@ struct DynamicBuffer
     __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

+    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
+    {
+        return BufferAddressSpace;
+    }
+
     __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
     __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
@@ -91,10 +75,10 @@ struct DynamicBuffer
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
 };

-template <typename T>
+template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T>
 __host__ __device__ constexpr auto make_dynamic_buffer(T* p)
 {
-    return DynamicBuffer<T>{p};
+    return DynamicBuffer<BufferAddressSpace, T>{p};
 }

 } // namespace ck
...
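
Because `BufferAddressSpace` is the leading parameter with a `Generic` default and `T` is deduced from the pointer argument (legal for function templates), existing `make_dynamic_buffer(p)` call sites keep compiling while new code opts into a specific space. A usage sketch against the factory above (assumes the header is included and we are inside namespace ck; the pointer is illustrative):

```cpp
float* p = nullptr; // stand-in for a real allocation

auto generic_buf = make_dynamic_buffer(p);                       // DynamicBuffer<Generic, float>
auto global_buf  = make_dynamic_buffer<AddressSpace::Global>(p); // DynamicBuffer<Global, float>

static_assert(decltype(generic_buf)::GetAddressSpace() == AddressSpace::Generic, "");
static_assert(decltype(global_buf)::GetAddressSpace() == AddressSpace::Global, "");
```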
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <AddressSpace BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr AddressSpace GetAddressSpace()
{
return BufferAddressSpace;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<BufferAddressSpace, T, N>{};
}
} // namespace ck
#endif
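
Usage mirrors the dynamic factory, except the element count arrives as a compile-time `Number<N>` (in the kernels above it comes from a descriptor's `GetElementSpaceSize()`), so the whole buffer can live in registers. A short sketch, again assuming the header above is included inside namespace ck:

```cpp
// A 16-element register-resident buffer of floats.
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, float>(Number<16>{});

static_assert(decltype(a_thread_buf)::IsStaticBuffer(), "");
static_assert(decltype(a_thread_buf)::GetAddressSpace() == AddressSpace::Vgpr, "");
```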
@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 0
+#elif 1
     constexpr index_t N  = 1;
     constexpr index_t C  = 16;
     constexpr index_t HI = 540;
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
     print_array("ConvStrides", to_multi_index(ConvStrides{}));
     print_array("ConvDilations", to_multi_index(ConvDilations{}));

-#if 1
+#if 0
     using in_data_t = float;
     constexpr index_t in_vector_size = 1;
     using acc_data_t = float;
@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
                                                                           LeftPads{},
                                                                           RightPads{},
                                                                           nrepeat);
-#elif 1
+#elif 0
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                          in_vector_size,
                                                                          acc_data_t,
...