Unverified commit 78b987fb, authored by Chao Liu, committed by GitHub

Use DynamicBuffer instead of raw pointer (#32)

* Use DynamicBuffer to hold raw pointers (to global and LDS memory)

* Add a workaround for a compiler issue (inefficient ISA) in ds_write for int8x4, int8x8, and int8x16
parent 01055d95
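Before the diff itself, a rough idea of the abstraction this change introduces. The hunks below construct buffers with make_dynamic_buffer<AddressSpace::...>(pointer, element_space_size) and access them through Get/Set calls that take an "is valid" flag. The sketch below is not the real ck::DynamicBuffer; it is a minimal host-side stand-in, with member names and the zero-fill-on-invalid behaviour inferred from how the buffers are used in this diff, meant only to show the shape of the interface that replaces the raw pointers.

#include <cstddef>
#include <cstring>

enum class AddressSpace { Generic, Global, Lds, Vgpr };

// Minimal stand-in for an address-space-tagged buffer wrapping a raw pointer.
template <AddressSpace BufferAddressSpace, typename T>
struct DynamicBufferSketch
{
    T* p_data_;
    std::size_t element_space_size_;

    // Read a (possibly vectorized) value X located "offset" elements into the buffer.
    // An invalid coordinate yields a zero-initialized X, mimicking what the GPU
    // buffer-load path returns for out-of-range accesses.
    template <typename X>
    X Get(std::size_t offset, bool is_valid) const
    {
        X x{};
        if(is_valid)
        {
            std::memcpy(&x, p_data_ + offset, sizeof(X));
        }
        return x;
    }

    // Write a (possibly vectorized) value X at "offset"; invalid coordinates are dropped.
    template <typename X>
    void Set(std::size_t offset, bool is_valid, const X& x)
    {
        if(is_valid)
        {
            std::memcpy(p_data_ + offset, &x, sizeof(X));
        }
    }
};

template <AddressSpace BufferAddressSpace, typename T>
DynamicBufferSketch<BufferAddressSpace, T> make_dynamic_buffer_sketch(T* p, std::size_t element_space_size)
{
    // usage: auto buf = make_dynamic_buffer_sketch<AddressSpace::Global>(ptr, size);
    return DynamicBufferSketch<BufferAddressSpace, T>{p, element_space_size};
}

Because the address space is now part of the buffer's type, the transfer classes further down can drop their SrcAddressSpace/DstAddressSpace template parameters and still select the appropriate load/store path per specialization.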
@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+            FloatAB,
+            FloatAB,
+            FloatC,
             remove_reference_t<AGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<BGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<CGlobalDesc>,
-            FloatC*,
             remove_reference_t<CBlockClusterDesc>,
-            integral_constant<bool, true>,
-            integral_constant<bool, true>>;
+            true,
+            true>;

         ave_time = launch_and_time_kernel(kernel,
             nrepeat,
@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
             p_c_global,
-            c_block_cluster_desc,
-            integral_constant<bool, true>{},
-            integral_constant<bool, true>{});
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
+            c_block_cluster_desc);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+            FloatAB,
+            FloatAB,
+            FloatC,
             remove_reference_t<AGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<BGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<CGlobalDesc>,
-            FloatC*,
             remove_reference_t<CBlockClusterDesc>,
-            integral_constant<bool, true>,
-            integral_constant<bool, false>>;
+            true,
+            false>;

         ave_time = launch_and_time_kernel(kernel,
             nrepeat,
@@ -192,28 +190,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
             p_c_global,
-            c_block_cluster_desc,
-            integral_constant<bool, true>{},
-            integral_constant<bool, false>{});
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
+            c_block_cluster_desc);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+            FloatAB,
+            FloatAB,
+            FloatC,
             remove_reference_t<AGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<BGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<CGlobalDesc>,
-            FloatC*,
             remove_reference_t<CBlockClusterDesc>,
-            integral_constant<bool, false>,
-            integral_constant<bool, true>>;
+            false,
+            true>;

         ave_time = launch_and_time_kernel(kernel,
             nrepeat,
@@ -221,28 +217,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
             p_c_global,
-            c_block_cluster_desc,
-            integral_constant<bool, false>{},
-            integral_constant<bool, true>{});
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
+            c_block_cluster_desc);
     }
     else
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+            FloatAB,
+            FloatAB,
+            FloatC,
             remove_reference_t<AGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<BGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<CGlobalDesc>,
-            FloatC*,
             remove_reference_t<CBlockClusterDesc>,
-            integral_constant<bool, false>,
-            integral_constant<bool, false>>;
+            false,
+            false>;

         ave_time = launch_and_time_kernel(kernel,
             nrepeat,
@@ -250,15 +244,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
             p_c_global,
-            c_block_cluster_desc,
-            integral_constant<bool, false>{},
-            integral_constant<bool, false>{});
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
+            c_block_cluster_desc);
     }

     return ave_time;
@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+            FloatAB,
+            FloatAB,
+            FloatC,
             remove_reference_t<AGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<BGlobalDesc>,
-            const FloatAB*,
             remove_reference_t<CGlobalDesc>,
-            FloatC*,
             remove_reference_t<CBlockClusterDesc>,
             true,
             true>;
@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
             p_a_global,
-            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
             p_b_global,
-            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             p_c_global,
+            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
         const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-            remove_reference_t<AGlobalDesc>,
             FloatAB,
-            remove_reference_t<BGlobalDesc>,
             FloatAB,
-            remove_reference_t<CGlobalDesc>,
             FloatC,
+            remove_reference_t<AGlobalDesc>,
+            remove_reference_t<BGlobalDesc>,
+            remove_reference_t<CGlobalDesc>,
             remove_reference_t<CBlockClusterDesc>,
             true,
             false>;
@@ -323,23 +315,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
             p_a_global,
-            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
             p_b_global,
-            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             p_c_global,
+            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
         const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-            remove_reference_t<AGlobalDesc>,
             FloatAB,
-            remove_reference_t<BGlobalDesc>,
             FloatAB,
-            remove_reference_t<CGlobalDesc>,
             FloatC,
+            remove_reference_t<AGlobalDesc>,
+            remove_reference_t<BGlobalDesc>,
+            remove_reference_t<CGlobalDesc>,
             remove_reference_t<CBlockClusterDesc>,
             false,
             true>;
@@ -351,23 +343,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
             p_a_global,
-            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
             p_b_global,
-            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             p_c_global,
+            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else
     {
         const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-            remove_reference_t<AGlobalDesc>,
             FloatAB,
-            remove_reference_t<BGlobalDesc>,
             FloatAB,
-            remove_reference_t<CGlobalDesc>,
             FloatC,
+            remove_reference_t<AGlobalDesc>,
+            remove_reference_t<BGlobalDesc>,
+            remove_reference_t<CGlobalDesc>,
             remove_reference_t<CBlockClusterDesc>,
             false,
             false>;
@@ -379,12 +371,12 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
             dim3(BlockSize),
             0,
             0,
-            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
             p_a_global,
-            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
             p_b_global,
-            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             p_c_global,
+            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
...
@@ -29,8 +29,6 @@ template <index_t BlockSize,
           index_t DstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           index_t ThreadTransferSrcResetCoordinateAfterRun,
@@ -79,24 +77,25 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                             const SrcIteratorHacks& src_iterator_hacks)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
         }
     }

-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunWrite(dst_desc, p_dst);
+            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
         }
     }
@@ -152,8 +151,6 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         DstScalarPerVector,
         SrcScalarStrideInVector,
         DstScalarStrideInVector,
-        SrcAddressSpace,
-        DstAddressSpace,
         ThreadTransferSrcResetCoordinateAfterRun,
         ThreadTransferDstResetCoordinateAfterRun>;
...
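The blockwise copy above is one instance of a pattern repeated throughout this commit: instead of taking a raw SrcData*/DstData* plus SrcAddressSpace/DstAddressSpace template parameters, RunRead/RunWrite are now templated on the buffer type and simply forward it. A toy forwarding wrapper (names invented for illustration, not the actual class) makes the point:

// Toy sketch: the wrapper no longer needs to know address spaces at all,
// because the buffer object it forwards already encodes them in its type.
template <typename ThreadwiseTransfer>
struct BlockwiseTransferSketch
{
    ThreadwiseTransfer threadwise_transfer_;

    template <typename SrcDesc, typename SrcBuffer>
    void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        threadwise_transfer_.RunRead(src_desc, src_buf);
    }

    template <typename DstDesc, typename DstBuffer>
    void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        threadwise_transfer_.RunWrite(dst_desc, dst_buf);
    }
};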
@@ -115,8 +115,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                            const BBlockBuffer& b_block_buf,
                            CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -176,8 +178,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
             Sequence<0, 1, 2>,
             2,
             AThreadCopyScalarPerVector_M1,
-            AddressSpace::Generic,
-            AddressSpace::Vgpr,
             1>;

     using BThreadCopy =
@@ -189,8 +189,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
             Sequence<0, 1, 2>,
             2,
             BThreadCopyScalarPerVector_N1,
-            AddressSpace::Generic,
-            AddressSpace::Vgpr,
             1>;

     CIndex c_thread_origin_data_idx_;
@@ -211,6 +209,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
 // 3. C:
 //    1. CThreadDesc is known at compile-time
 //    2. CThreadBuffer is StaticBuffer
+// Also assume:
+//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
@@ -312,8 +312,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                            const BBlockBuffer& b_block_buf,
                            CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -481,8 +483,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
             Sequence<0, 1, 2>,
             2,
             AThreadCopyScalarPerVector_M1,
-            AddressSpace::Generic,
-            AddressSpace::Vgpr,
             1>;

     using BThreadCopy =
@@ -494,8 +494,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
             Sequence<0, 1, 2>,
             2,
             BThreadCopyScalarPerVector_N1,
-            AddressSpace::Generic,
-            AddressSpace::Vgpr,
             1>;

     CIndex c_thread_origin_data_idx_;
...
@@ -49,8 +49,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
             Sequence<0, 1>,
             1,
             ThreadGemmADataPerRead_K,
-            AddressSpace::Generic,
-            AddressSpace::Vgpr,
             1>;

     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
@@ -140,7 +138,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

         // thread A buffer for GEMM
-        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
                                                                     FloatB,
...
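The blockwise GEMMs above also gain an address-space tag on their register-resident storage: StaticBuffer and make_static_buffer now take AddressSpace::Vgpr as their first template argument. A rough stand-in, again an assumption for illustration rather than the real ck::StaticBuffer (whose size argument is a compile-time value), could look like this:

#include <array>
#include <cstddef>

enum class AddressSpace { Generic, Global, Lds, Vgpr };

// Fixed-size, compile-time-sized storage; tagging it Vgpr documents that the
// data is meant to live entirely in registers (thread-private A/B fragments
// and C accumulators), so no pointer arithmetic or validity check is needed.
template <AddressSpace BufferAddressSpace, typename T, std::size_t N>
struct StaticBufferSketch
{
    std::array<T, N> data_{};

    constexpr T& operator[](std::size_t i) { return data_[i]; }
    constexpr const T& operator[](std::size_t i) const { return data_[i]; }
};

template <AddressSpace BufferAddressSpace, typename T, std::size_t N>
constexpr StaticBufferSketch<BufferAddressSpace, T, N> make_static_buffer_sketch()
{
    return StaticBufferSketch<BufferAddressSpace, T, N>{};
}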
@@ -14,54 +14,62 @@ namespace ck {
 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const BGlobalDesc b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const CBlockClusterDesc c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                           const FloatB* __restrict__ p_b_global,
+                           FloatC* __restrict__ p_c_global,
+                           const AGlobalDesc a_k_m_global_desc,
+                           const BGlobalDesc b_k_n_global_desc,
+                           const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                           const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 // pass tensor descriptor by __CONSTANT__ void pointer
 // __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
 // non-modifiable parameter address space, so compiler can enable corresponding optimization
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const void __CONSTANT__* p_b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const void __CONSTANT__* p_c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                           const FloatB* __restrict__ p_b_global,
+                           FloatC* __restrict__ p_c_global,
+                           const void __CONSTANT__* p_a_k_m_global_desc,
+                           const void __CONSTANT__* p_b_k_n_global_desc,
+                           const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                           const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*
     // second cast void* to Desc*
@@ -76,15 +84,15 @@ __global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_d
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #endif
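As the comment above the BY_VOID_POINTER kernel explains, the descriptors are not passed by value on that path; the host launcher copies each one into a small device buffer and passes its address as a const void __CONSTANT__* (see the launch hunks earlier in this commit), and the kernel casts it back before use. Condensed into a helper, the device-side step is just the double cast shown below; Desc is a placeholder for any of the descriptor types, the helper name is mine rather than the repository's, and the snippet assumes the repo's __CONSTANT__ macro and a HIP toolchain.

// Sketch only: recover a descriptor that was passed as a __CONSTANT__ void pointer.
// The intermediate cast to plain (const void*) strips the address-space qualifier
// before the reinterpret_cast, exactly as the kernel body above does inline.
template <typename Desc>
__device__ Desc load_descriptor(const void __CONSTANT__* p_desc)
{
    return *reinterpret_cast<const Desc*>((const void*)p_desc);
}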
@@ -161,22 +169,29 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                FloatAB* __restrict__ p_shared_block,
                                integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
         const auto N = b_k_n_global_desc.GetLength(I1);
@@ -226,8 +241,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_M,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
@@ -255,8 +268,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_N,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             BThreadTransferSrcResetCoordinateAfterRun,
@@ -331,8 +342,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf =
-            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),
@@ -353,25 +364,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        FloatAB* p_a_block_even = p_a_block_double;
-        FloatAB* p_b_block_even = p_b_block_double;
-
-        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
-
-        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
-        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
@@ -394,16 +403,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
@@ -417,16 +426,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
@@ -445,15 +454,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

             __syncthreads();
@@ -488,8 +497,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>{
@@ -502,32 +509,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             make_tuple(I0, I0, I0, I0),
             c_thread_buf,
             c_m0_m1_n0_n1_global_desc,
-            p_c_global,
+            c_global_buf,
             c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

         __shared__ FloatAB p_shared_block[shared_block_size];

-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
+        Run(p_a_global,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
             p_c_global,
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
             c_block_cluster_desc,
             p_shared_block,
             integral_constant<bool, HasMainKBlockLoop>{},
...
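The hunks above rewire the LDS double-buffering main loop onto the new buffer objects without changing its schedule: preload the even LDS tile, then alternate so that the global-to-LDS copy for the next K slab overlaps the blockwise GEMM on the current one (with __syncthreads() separating the write to a tile from the GEMM that reads it). Below is a self-contained, host-only sketch of just that even/odd schedule, with made-up names and printf standing in for the GEMM; it is a reading aid, not kernel code.

#include <array>
#include <cstdio>

int main()
{
    constexpr int num_tiles = 8;
    std::array<int, 2> staged{};                 // stand-ins for the even/odd LDS tiles

    auto load = [&](int tile, int slot) { staged[slot] = tile; };
    auto gemm = [&](int slot) { std::printf("gemm on tile %d (slot %d)\n", staged[slot], slot); };

    int cur = 0;
    load(0, cur);                                // preload the first tile
    for(int t = 1; t < num_tiles; ++t)
    {
        load(t, cur ^ 1);                        // stage tile t into the other slot
        gemm(cur);                               // compute on tile t-1 in the current slot
        cur ^= 1;                                // swap roles (on the GPU, a barrier goes here)
    }
    gemm(cur);                                   // tail: compute on the last staged tile
    return 0;
}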
@@ -84,6 +84,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_e_k_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
+
         constexpr auto E = EPerBlock * 3 * 3;
         // const auto E = a_e_k_global_desc.GetLength(I0);
@@ -192,8 +199,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_K,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
@@ -216,19 +221,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BBlockTransferSrcAccessOrder,
             BBlockTransferSrcVectorDim,
             BBlockTransferSrcScalarPerVector,
-            AddressSpace::Global,
-            AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
             true>(b_e_n_ho_wo_global_desc,
                   make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

-        FloatAB* p_a_block = p_shared_block;
-
-        auto a_block_buf = make_dynamic_buffer(p_a_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block,
+                                                                  a_e_k_desc.GetElementSpaceSize());

         // register allocation for output
-        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            c_thread_buf;

         // initialize output thread tensor
         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
@@ -250,21 +253,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BGlobalMoveSliceWindowIteratorHacks{};

         // double regsiter buffer for b
-        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
-            b_thread_odd_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            b_thread_even_buf, b_thread_odd_buf;

         // LDS double buffer: preload data
         {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_even_buf,
                                       b_e_n_ho_wo_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
+            a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
         }

         __syncthreads();
@@ -282,7 +285,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                        b_thread_slice_copy_step);

                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
@@ -298,7 +301,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                        b_thread_slice_copy_step);

                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_even_buf,
@@ -321,7 +324,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                    b_thread_slice_copy_step);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_odd_buf,
@@ -358,8 +361,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>(
@@ -370,7 +371,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             make_tuple(I0, I0, I0, I0),
             c_thread_buf,
             c_k_n_ho_wo_global_desc,
-            p_c_global,
+            c_global_buf,
             c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
     }
...
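A small C++ note before the threadwise-transfer file that follows: the new buffer accesses are written as src_buf.template Get<...>(...) and dst_buf.template Set<...>(...) because, inside these class templates, the buffer type is a dependent type and the template keyword is mandatory before a member-template call. A minimal, self-contained illustration (Buf here is a stand-in, not the real DynamicBuffer):

#include <cstdio>

struct Buf
{
    template <typename X>
    X Get(int v) const { return static_cast<X>(v); }
};

template <typename B>
int read_twice(const B& buf)
{
    // without "template", "buf.Get<int>(21)" would be parsed as comparisons
    return buf.template Get<int>(21) * 2;
}

int main()
{
    Buf b;
    std::printf("%d\n", read_twice(b)); // prints 42
    return 0;
}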
...@@ -54,8 +54,6 @@ template <typename SrcData, ...@@ -54,8 +54,6 @@ template <typename SrcData,
typename DimAccessOrder, typename DimAccessOrder,
index_t DstVectorDim, index_t DstVectorDim,
index_t DstScalarPerVector, index_t DstScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp, InMemoryDataOperation DstInMemOp,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun, bool DstResetCoordinateAfterRun,
...@@ -72,7 +70,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -72,7 +70,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3( __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3(
const DstDesc& dst_desc, const Index& dst_slice_origin_idx) const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
: dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx)) : dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -80,15 +78,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -80,15 +78,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{ {
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks> template <typename SrcSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer,
typename DstIteratorHacks>
__device__ void Run(const SrcDesc&, __device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&, const SrcSliceOriginIdx&,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstData* p_dst, DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks) const DstIteratorHacks& dst_iterator_hacks)
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
...@@ -191,12 +192,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -191,12 +192,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
return dst_data_idx; return dst_data_idx;
}(); }();
// copy data
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector; typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type; typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
...@@ -205,37 +206,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -205,37 +206,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
type_convert<DstData>{}(src_buf[Number<src_offset>{}]); type_convert<DstData>{}(src_buf[Number<src_offset>{}]);
}); });
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_dst_valid =
dst_desc, dst_slice_origin_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Vgpr && // copy data from dst_vector into dst_buf
DstAddressSpace == AddressSpace::Global) dst_buf.template Set<dst_vector_t>(
{ dst_coord_.GetOffset(),
#if CK_USE_AMD_BUFFER_ADDRESSING is_dst_valid,
amd_buffer_store_v2<DstData, DstScalarPerVector>( dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
p_dst,
dst_slice_origin_coord_.GetOffset(),
is_dst_valid,
dst_desc.GetElementSpaceSize());
#else
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
}
#endif
}
else
{
if(is_dst_valid)
{
*reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
}
}
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -259,15 +237,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -259,15 +237,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate(dst_desc, move_dynamic_tensor_coordinate(
dst_slice_origin_coord_, dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
dst_forward_iterators[dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate(dst_desc, move_dynamic_tensor_coordinate(
dst_slice_origin_coord_, dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
dst_backward_iterators[dim_access_order[i]]);
} }
} }
}); });
...@@ -279,11 +255,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -279,11 +255,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
const auto dst_reset_iterator = const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator); move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
} }
} }
__device__ void Run(const SrcData* p_src, const DstDesc& dst_desc, DstData* p_dst) template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{ {
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
...@@ -293,7 +274,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -293,7 +274,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(p_src, dst_desc, p_dst, dst_iterator_hacks); Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
} }
__device__ static constexpr auto GetDstCoordinateResetStep() __device__ static constexpr auto GetDstCoordinateResetStep()
...@@ -371,18 +352,22 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -371,18 +352,22 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
const auto adjusted_step = const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step); move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
} }
private: private:
DstCoord dst_slice_origin_coord_; DstCoord dst_coord_;
}; // namespace ck }; // namespace ck
// Assume: // Assume:
// 1. src_desc is not known at compile-time // 1. src:
// 2. dst_desc is known at compile-time // 1. SrcDesc is not known at compile-time
// 3. src_slice_origin_idx is not known at compile-time // 2. SrcBuffer is DynamicBuffer
// 4. dst_slice_origin_idx is known at compile-time and it's 0 // 3. src_slice_origin_idx is not known at compile-time
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. dst_slice_origin_idx is known at compile-time
template <typename SrcData, template <typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
...@@ -391,8 +376,6 @@ template <typename SrcData, ...@@ -391,8 +376,6 @@ template <typename SrcData,
typename DimAccessOrder, typename DimAccessOrder,
index_t SrcVectorDim, index_t SrcVectorDim,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
...@@ -408,7 +391,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -408,7 +391,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc, __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx) const Index& src_slice_origin_idx)
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx)) : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
{ {
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -416,12 +399,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -416,12 +399,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{ {
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
} }
template <typename DstBuffer, typename DstSliceOriginIdx, typename SrcIteratorHacks> template <typename SrcBuffer,
typename DstBuffer,
typename DstSliceOriginIdx,
typename SrcIteratorHacks>
__device__ void Run(const SrcDesc& src_desc, __device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src, const SrcBuffer& src_buf,
const DstDesc&, const DstDesc&,
const DstSliceOriginIdx&, const DstSliceOriginIdx&,
DstBuffer& dst_buf, DstBuffer& dst_buf,
...@@ -525,41 +511,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -525,41 +511,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
return src_data_idx; return src_data_idx;
}(); }();
// copy data
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector; typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
using src_vector_t = using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type; typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_src_valid =
src_desc, src_slice_origin_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Global) // copy data from src_buf into src_vector
{ src_vector.template AsType<src_vector_t>()(Number<0>{}) =
#if CK_USE_AMD_BUFFER_ADDRESSING src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
...@@ -590,15 +554,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -590,15 +554,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
{ {
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate(src_desc, move_dynamic_tensor_coordinate(
src_slice_origin_coord_, src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
src_forward_iterators[dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate(src_desc, move_dynamic_tensor_coordinate(
src_slice_origin_coord_, src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
src_backward_iterators[dim_access_order[i]]);
} }
} }
}); });
...@@ -610,13 +572,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -610,13 +572,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
const auto src_reset_iterator = const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator); move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
} }
} }
template <typename DstBuffer, typename DstSliceOriginIdx> template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc, __device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src, const SrcBuffer& src_buf,
const DstDesc&, const DstDesc&,
const DstSliceOriginIdx&, const DstSliceOriginIdx&,
DstBuffer& dst_buf) DstBuffer& dst_buf)
...@@ -629,7 +591,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -629,7 +591,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks); Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
} }
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
...@@ -707,17 +669,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -707,17 +669,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
const auto adjusted_step = const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step); move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
private: private:
SrcCoord src_slice_origin_coord_; SrcCoord src_coord_;
}; // struct ThreadwiseDynamicTensorSliceTransfer_v2 }; // struct ThreadwiseDynamicTensorSliceTransfer_v2
// Assume: // Assume:
// 1. src_desc and dst_desc are not known at compile-time // 1. src_desc and dst_desc are not known at compile-time
// 2. src_slice_origin and dst_slice_origin are not known at compile-time, // 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. Use thread buffer // 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths, template <typename SliceLengths,
InMemoryDataOperation DstInMemOp, InMemoryDataOperation DstInMemOp,
typename SrcData, typename SrcData,
...@@ -732,8 +695,6 @@ template <typename SliceLengths, ...@@ -732,8 +695,6 @@ template <typename SliceLengths,
index_t DstScalarPerVector, index_t DstScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to // RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation // save addr computation
...@@ -755,16 +716,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -755,16 +716,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const Index& src_slice_origin, const Index& src_slice_origin,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_slice_origin) const Index& dst_slice_origin)
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
{ {
static_assert(SrcAddressSpace == AddressSpace::Global or
SrcAddressSpace == AddressSpace::Lds,
"wrong!");
static_assert(DstAddressSpace == AddressSpace::Global or
DstAddressSpace == AddressSpace::Lds,
"wrong!");
// TODO: fix this // TODO: fix this
static_assert(is_same<SrcData, DstData>::value, static_assert(is_same<SrcData, DstData>::value,
"wrong! current implementation assume SrcData and DstData are same type"); "wrong! current implementation assume SrcData and DstData are same type");
...@@ -772,19 +726,27 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -772,19 +726,27 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{ {
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
} }
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{ {
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
} }
template <typename SrcIteratorHacks> template <typename SrcBuffer, typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void RunRead(const SrcDesc& src_desc,
const SrcData* p_src, const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks) const SrcIteratorHacks& src_iterator_hacks)
{ {
static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
"wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
remove_cv_t<remove_reference_t<SrcData>>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -869,37 +831,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -869,37 +831,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
return src_data_idx; return src_data_idx;
}(); }();
// copy data from src_buf to src_tmp_vector
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector; vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type; using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_src_valid =
src_desc, src_slice_origin_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
if constexpr(SrcAddressSpace == AddressSpace::Global) // copy data from src_buf to src_tmp_vector
{ src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
#if CK_USE_AMD_BUFFER_ADDRESSING src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
is_src_valid,
src_desc.GetElementSpaceSize());
#else
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
#endif
}
else
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
// copy data from src_tmp_vector to buffer_ // copy data from src_tmp_vector to buffer_
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
...@@ -933,16 +874,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -933,16 +874,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_dynamic_tensor_coordinate(
src_desc, src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
src_slice_origin_coord_,
src_forward_iterators[src_dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_dynamic_tensor_coordinate(
src_desc, src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
src_slice_origin_coord_,
src_backward_iterators[src_dim_access_order[i]]);
} }
} }
}); });
...@@ -954,14 +891,23 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -954,14 +891,23 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const auto src_reset_iterator = const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator); move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
} }
} }
template <typename DstIteratorHacks> template <typename DstBuffer, typename DstIteratorHacks>
__device__ void __device__ void RunWrite(const DstDesc& dst_desc,
RunWrite(const DstDesc& dst_desc, DstData* p_dst, const DstIteratorHacks& dst_iterator_hacks) DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks)
{ {
static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
DstBuffer::GetAddressSpace() == AddressSpace::Lds,
"wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
remove_cv_t<remove_reference_t<DstData>>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -1050,13 +996,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1050,13 +996,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
return dst_data_idx; return dst_data_idx;
}(); }();
// copy data
// hardcoding for ds_write
// TODO refactor transfer_data() to encapsulate this
static_assert(DstAddressSpace == AddressSpace::Lds &&
DstInMemOp == InMemoryDataOperation::Set,
"wrong! hardcoded for ds_write");
vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector; vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
// copy data from buffer_ to dst_tmp_vector // copy data from buffer_ to dst_tmp_vector
...@@ -1070,8 +1009,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1070,8 +1009,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
using dst_vector_t = typename decltype(dst_tmp_vector)::type; using dst_vector_t = typename decltype(dst_tmp_vector)::type;
// copy data from dst_tmp_vector to dst_buf // copy data from dst_tmp_vector to dst_buf
*reinterpret_cast<dst_vector_t*>(p_dst + dst_slice_origin_coord_.GetOffset()) = const bool is_dst_valid =
dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]; coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]);
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -1097,16 +1041,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1097,16 +1041,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
if constexpr(forward_sweep[i]) if constexpr(forward_sweep[i])
{ {
move_dynamic_tensor_coordinate( move_dynamic_tensor_coordinate(
dst_desc, dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
dst_slice_origin_coord_,
dst_forward_iterators[dst_dim_access_order[i]]);
} }
else else
{ {
move_dynamic_tensor_coordinate( move_dynamic_tensor_coordinate(
dst_desc, dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
dst_slice_origin_coord_,
dst_backward_iterators[dst_dim_access_order[i]]);
} }
} }
}); });
...@@ -1118,11 +1058,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1118,11 +1058,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const auto dst_reset_iterator = const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator); move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
} }
} }
__device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src) template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{ {
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
...@@ -1132,10 +1073,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1132,10 +1073,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, p_src, src_iterator_hacks); RunRead(src_desc, src_buf, src_iterator_hacks);
} }
__device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst) template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{ {
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
...@@ -1145,7 +1087,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1145,7 +1087,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}), make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, p_dst, dst_iterator_hacks); RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
} }
__device__ static constexpr auto GetSrcCoordinateResetStep() __device__ static constexpr auto GetSrcCoordinateResetStep()
...@@ -1285,7 +1227,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1285,7 +1227,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const auto adjusted_step = const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step); move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
...@@ -1304,7 +1246,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1304,7 +1246,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step); move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
} }
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
...@@ -1319,7 +1261,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1319,7 +1261,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const auto adjusted_step = const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step); move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
} }
private: private:
...@@ -1328,10 +1270,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1328,10 +1270,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<SrcData, buffer_size_> buffer_; StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;
SrcCoord src_slice_origin_coord_; SrcCoord src_coord_;
DstCoord dst_slice_origin_coord_; DstCoord dst_coord_;
}; };
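
With this change the transfer object no longer receives raw pointers; RunRead and RunWrite take DynamicBuffer objects whose address space is checked at compile time. A minimal, hypothetical call-site sketch (the descriptors a_global_desc / a_block_desc, the pointers, and the pre-built threadwise_transfer object are invented names; only the buffer wrapping and the Run* calls reflect this diff):

    // hypothetical sketch: copy an A tile from global memory into LDS through the private buffer_
    const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
        p_a_global, a_global_desc.GetElementSpaceSize());

    auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
        p_a_block, a_block_desc.GetElementSpaceSize());

    threadwise_transfer.RunRead(a_global_desc, a_global_buf);  // global -> private buffer_
    threadwise_transfer.RunWrite(a_block_desc, a_block_buf);   // private buffer_ -> LDS

The static_asserts in RunRead/RunWrite reject any buffer that is not Global or Lds, so passing a buffer built with the wrong address space fails at compile time rather than producing wrong addressing code.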
// Assume: // Assume:
...@@ -1356,8 +1298,6 @@ template < ...@@ -1356,8 +1298,6 @@ template <
typename DimAccessOrder, typename DimAccessOrder,
index_t SrcVectorDim, index_t SrcVectorDim,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
...@@ -1480,7 +1420,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 ...@@ -1480,7 +1420,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
move_dynamic_tensor_coordinate( move_dynamic_tensor_coordinate(
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
// copy data from src_buf into src_tmp_buffer
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector; vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type; using src_vector_t = typename decltype(src_tmp_vector)::type;
...@@ -1488,9 +1427,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 ...@@ -1488,9 +1427,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord); src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) = src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? src_buf.template Get<src_vector_t>(src_data_coord.GetOffset()) src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
: src_vector_t{0};
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData) // DstData)
......
...@@ -323,7 +323,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -323,7 +323,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return __llvm_amdgcn_raw_buffer_load_i8x2( return __llvm_amdgcn_raw_buffer_load_i8x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else #else
...@@ -335,7 +335,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -335,7 +335,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return __llvm_amdgcn_raw_buffer_load_i8x4( return __llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else #else
...@@ -347,7 +347,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -347,7 +347,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 8> tmp; vector_type<int8_t, 8> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
...@@ -369,7 +369,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, ...@@ -369,7 +369,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
} }
else if constexpr(N == 16) else if constexpr(N == 16)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 16> tmp; vector_type<int8_t, 16> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
...@@ -483,7 +483,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -483,7 +483,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
__llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
...@@ -499,7 +499,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -499,7 +499,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
#if !CK_WORKAROUND_SWDEV_XXXXXX #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
__llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
......
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<T, N>{};
}
template <typename T>
struct DynamicBuffer
{
using type = T;
T* p_data_;
__host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr const auto Get(index_t i) const
{
return *reinterpret_cast<const X*>(&p_data_[i]);
}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, const X& x)
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
return DynamicBuffer<T>{p};
}
} // namespace ck
#endif
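
For orientation, a small hypothetical sketch of how a call site changes when this file is replaced by the static_buffer.hpp / dynamic_buffer.hpp pair further below (p_data and element_space_size are made-up names):

    // old interface (this deleted file): only the pointer is captured
    auto buf_old = make_dynamic_buffer(p_data);

    // new interface (dynamic_buffer.hpp below): address space and element space size are
    // explicit, and every Get/Set carries a validity flag
    auto buf_new = make_dynamic_buffer<AddressSpace::Global>(p_data, element_space_size);

Carrying the element space size is what lets the global-memory path set up the buffer resource for amd_buffer_load_v2 / amd_buffer_store_v2 instead of issuing plain pointer loads and stores.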
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "container_element_picker.hpp" #include "container_element_picker.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "float_type.hpp" #include "float_type.hpp"
#include "buffer.hpp"
#include "functional.hpp" #include "functional.hpp"
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
...@@ -25,6 +24,8 @@ ...@@ -25,6 +24,8 @@
#include "type.hpp" #include "type.hpp"
#include "utility.hpp" #include "utility.hpp"
#include "magic_division.hpp" #include "magic_division.hpp"
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
......
...@@ -143,8 +143,13 @@ ...@@ -143,8 +143,13 @@
#endif #endif
// workaround for compiler crash when using buffer load/store for i8 // workaround for compiler crash when using buffer load/store for i8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX #ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX 1 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
#endif
// workaround for inefficient ISA generated by the compiler when using ds_write for int8 vectors
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif #endif
namespace ck { namespace ck {
...@@ -154,6 +159,7 @@ enum AddressSpace ...@@ -154,6 +159,7 @@ enum AddressSpace
Generic, Generic,
Global, Global,
Lds, Lds,
Sgpr,
Vgpr Vgpr
}; };
......
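
Both workaround macros default to 1; a build that wants the direct int8 paths can define them to 0 before the configuration header is pulled in. A minimal sketch, assuming the macros live in a header reachable through the library's main include (the include name below is a placeholder):

    // hypothetical translation unit that opts out of the int8 workarounds
    #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 0
    #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 0
    #include "common_header.hpp" // placeholder name for the header that defines the defaults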
#ifndef CK_DYNAMIC_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP
#include "amd_buffer_addressing_v2.hpp"

namespace ck {
template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer
{
using type = T;
T* p_data_;
ElementSpaceSize element_space_size_;
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
}
__host__ __device__ static constexpr AddressSpace GetAddressSpace()
{
return BufferAddressSpace;
}
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector =
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
p_data_, i, is_valid_offset, element_space_size_);
#else
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
#endif
}
else
{
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
}
}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector =
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
x, p_data_, i, is_valid_offset, element_space_size_);
#else
if(is_valid_offset)
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
#endif
}
else if constexpr(GetAddressSpace() == AddressSpace::Lds)
{
if(is_valid_offset)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*reinterpret_cast<X*>(&p_data_[i]) = x;
#else
            // HACK: the compiler lowers the IR "store<i8, 16> address_space(3)" into
            // inefficient ISA, so we let it emit the IR "store<i32, 4>" instead, which is
            // lowered to ds_write_b128
// TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
int8_t>::value)
{
static_assert(
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
"wrong! not implemented for this combination, please add implementation");
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{
                    // HACK: reinterpreting the address of x like this is fragile
// TODO: remove this after compiler fix
*reinterpret_cast<int32_t*>(&p_data_[i]) =
*reinterpret_cast<const int32_t*>(&x);
}
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{
                    // HACK: reinterpreting the address of x like this is fragile
// TODO: remove this after compiler fix
*reinterpret_cast<int32x2_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x2_t*>(&x);
}
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{
                    // HACK: reinterpreting the address of x like this is fragile
// TODO: remove this after compiler fix
*reinterpret_cast<int32x4_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x4_t*>(&x);
}
}
else
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
#endif
}
}
else
{
if(is_valid_offset)
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
typename T,
typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
}
} // namespace ck
#endif
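
A minimal device-side sketch of the new Get/Set interface (the function copy_float4 and its parameters are invented for illustration; float4_t is assumed to be one of the library's vector types):

    // hypothetical example: vectorized copy of 4 floats through DynamicBuffer
    __device__ void copy_float4(const float* p_src, float* p_dst, index_t offset, index_t size)
    {
        const auto src_buf = make_dynamic_buffer<AddressSpace::Global>(p_src, size);
        auto dst_buf       = make_dynamic_buffer<AddressSpace::Global>(p_dst, size);

        // invalid offsets read back as 0 and are dropped on store instead of faulting
        const bool is_valid = (offset + 4 <= size);

        const auto v = src_buf.Get<float4_t>(offset, is_valid);
        dst_buf.Set<float4_t>(offset, is_valid, v);
    }

On the global-memory path the flag is forwarded to the buffer intrinsics; on the LDS and generic paths it simply guards the plain vector load/store, as the branches above show.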
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <AddressSpace BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr AddressSpace GetAddressSpace()
{
return BufferAddressSpace;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<BufferAddressSpace, T, N>{};
}
} // namespace ck
#endif
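
A corresponding sketch for the register-resident buffer; element access goes through the StaticallyIndexedArray base, so indices are compile-time Number constants (the function name is invented):

    // hypothetical example: a small per-thread accumulator held in VGPRs
    __device__ void init_accumulator_sketch()
    {
        auto acc_buf = make_static_buffer<AddressSpace::Vgpr, float>(Number<4>{});

        static_for<0, 4, 1>{}([&](auto i) {
            acc_buf(i) = 0.0f; // mutable access with a compile-time index
        });
    }

Tagging the address space on StaticBuffer mirrors the DynamicBuffer interface, so generic code can query GetAddressSpace() on either buffer type.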
...@@ -63,12 +63,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -63,12 +63,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
#if 0 #if 0
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c0_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
const auto wei_k_c_y_x_desc = const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
const auto out_n_k_ho_wo_desc = const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));
const auto conv_strides = to_multi_index(ConvStrides{}); const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_multi_index(ConvDilations{}); const auto conv_dilations = to_multi_index(ConvDilations{});
......
...@@ -48,8 +48,8 @@ int main(int argc, char* argv[]) ...@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
...@@ -62,8 +62,8 @@ int main(int argc, char* argv[]) ...@@ -62,8 +62,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
...@@ -92,7 +92,7 @@ int main(int argc, char* argv[]) ...@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
...@@ -630,7 +630,7 @@ int main(int argc, char* argv[]) ...@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 0 #if 1
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
...@@ -740,7 +740,7 @@ int main(int argc, char* argv[]) ...@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
...@@ -10,13 +10,14 @@ cmake ...@@ -10,13 +10,14 @@ cmake
-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D DEVICE_BACKEND="AMD" \ -D DEVICE_BACKEND="AMD" \
-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \ -D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH="/opt/rocm" \ -D CMAKE_PREFIX_PATH="/opt/rocm" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -save-temps=$CWD" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \
#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0 -mllvm -print-before=amdgpu-codegenprepare -mllvm -print-module-scope" \
#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -gline-tables-only -save-temps=$CWD" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -gline-tables-only -save-temps=$CWD" \
#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0" \
#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \
......