Unverified Commit 78b987fb authored by Chao Liu, committed by GitHub

Use DynamicBuffer instead of raw pointer (#32)

* Use DynamicBuffer to hold raw pointers (to global and LDS memory)

* Add a workaround for a compiler issue (inefficient ISA) with ds_write for int8x4, int8x8, int8x16
parent 01055d95
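The recurring pattern in the diff below: kernels stop threading raw `const FloatAB*` / `FloatC*` pointers through the copy and GEMM helpers and instead wrap them in the new buffer types, tagged with an address space and an element-space size. A minimal, hypothetical sketch of that pattern (the helpers `make_dynamic_buffer` / `make_static_buffer` and the `AddressSpace` enum are the ones defined in the new static_buffer.hpp / dynamic_buffer.hpp further down in this diff; the surrounding function is not part of the commit):

    // hypothetical illustration only, not part of the commit
    template <typename ADesc, typename T>
    __device__ void example(const ADesc& a_desc, const T* p_a_global, T* p_a_lds)
    {
        // global-memory view; the element space size lets buffer loads/stores range-check
        const auto a_global_buf =
            make_dynamic_buffer<AddressSpace::Global>(p_a_global, a_desc.GetElementSpaceSize());

        // LDS view over the staging area for the same logical tile
        auto a_lds_buf =
            make_dynamic_buffer<AddressSpace::Lds>(p_a_lds, a_desc.GetElementSpaceSize());

        // element access goes through the buffer instead of the raw pointer
        a_lds_buf(0) = a_global_buf[0];
    }

Per-thread register staging is handled the same way via `make_static_buffer<AddressSpace::Vgpr, T>(...)`, as the blockwise/gridwise GEMM hunks below show.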
@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, true>>;
+                                                   true,
+                                                   true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, true>{},
-                                          integral_constant<bool, true>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, false>>;
+                                                   true,
+                                                   false>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -192,28 +190,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, true>{},
-                                          integral_constant<bool, false>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, false>,
-                                                   integral_constant<bool, true>>;
+                                                   false,
+                                                   true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -221,28 +217,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, false>{},
-                                          integral_constant<bool, true>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }
     else
    {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, false>,
-                                                   integral_constant<bool, false>>;
+                                                   false,
+                                                   false>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -250,15 +244,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, false>{},
-                                          integral_constant<bool, false>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }

     return ave_time;
@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    true,
                                                    true>;
@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
         const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<BGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<CGlobalDesc>,
                                                    FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    true,
                                                    false>;
@@ -323,23 +315,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
         const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<BGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<CGlobalDesc>,
                                                    FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    false,
                                                    true>;
@@ -351,23 +343,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else
     {
         const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<BGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<CGlobalDesc>,
                                                    FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    false,
                                                    false>;
@@ -379,12 +371,12 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
......
@@ -29,8 +29,6 @@ template <index_t BlockSize,
           index_t DstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           index_t ThreadTransferSrcResetCoordinateAfterRun,
@@ -79,24 +77,25 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                             const SrcIteratorHacks& src_iterator_hacks)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
         }
     }

-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunWrite(dst_desc, p_dst);
+            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
         }
     }
@@ -152,8 +151,6 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
                                               DstScalarPerVector,
                                               SrcScalarStrideInVector,
                                               DstScalarStrideInVector,
-                                              SrcAddressSpace,
-                                              DstAddressSpace,
                                               ThreadTransferSrcResetCoordinateAfterRun,
                                               ThreadTransferDstResetCoordinateAfterRun>;
......
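Since RunRead/RunWrite are now templated on the buffer type instead of taking raw pointers (and the Src/Dst address spaces removed from the template parameter list travel with the buffers), a call site looks roughly like this. This is a hedged sketch taken out of context; the descriptors, copy object, and iterator hacks are the ones set up in the gridwise code later in this diff:

    // sketch: the buffers carry the address space, the copy object no longer needs it
    const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
        p_a_global, a_k_m_global_desc.GetElementSpaceSize());
    auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
        p_a_block, a_k_m_block_desc.GetElementSpaceSize());

    a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
    a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_buf);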
@@ -115,8 +115,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                    const BBlockBuffer& b_block_buf,
                    CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -176,8 +178,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                             Sequence<0, 1, 2>,
                                             2,
                                             AThreadCopyScalarPerVector_M1,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     using BThreadCopy =
@@ -189,8 +189,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                             Sequence<0, 1, 2>,
                                             2,
                                             BThreadCopyScalarPerVector_N1,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     CIndex c_thread_origin_data_idx_;
@@ -211,6 +209,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
 // 3. C:
 //   1. CThreadDesc is known at compile-time
 //   2. CThreadBuffer is StaticBuffer
+// Also assume:
+//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
@@ -312,8 +312,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                    const BBlockBuffer& b_block_buf,
                    CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -481,8 +483,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                                             Sequence<0, 1, 2>,
                                             2,
                                             AThreadCopyScalarPerVector_M1,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     using BThreadCopy =
@@ -494,8 +494,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                                             Sequence<0, 1, 2>,
                                             2,
                                             BThreadCopyScalarPerVector_N1,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     CIndex c_thread_origin_data_idx_;
......
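The newly added "2x2 pipelined read and fma (ABBA optimization)" comment refers to interleaving the per-thread A/B sub-tile reads with the FMAs so each loaded operand is reused before the next read is issued. A rough, purely illustrative outline of such a schedule (pseudocode only; these are not the actual member functions of the struct):

    // illustrative pseudocode of a 2x2 ABBA-style schedule, assuming M0 = N0 = 2
    //   read a[0], read b[0]
    //   for each k:
    //       read b[1];                   fma(c[0][0], a[0], b[0])
    //       read a[1];                   fma(c[0][1], a[0], b[1])
    //                                    fma(c[1][1], a[1], b[1])
    //       read a[0], b[0] for next k;  fma(c[1][0], a[1], b[0])
    // the B operands are consumed in b0, b1, b1, b0 order ("ABBA"), so adjacent FMAs share a register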
@@ -49,8 +49,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                             Sequence<0, 1>,
                                             1,
                                             ThreadGemmADataPerRead_K,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
@@ -140,7 +138,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

         // thread A buffer for GEMM
-        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
                                                                     FloatB,
......
@@ -14,54 +14,62 @@ namespace ck {

 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const BGlobalDesc b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const CBlockClusterDesc c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                           const FloatB* __restrict__ p_b_global,
+                           FloatC* __restrict__ p_c_global,
+                           const AGlobalDesc a_k_m_global_desc,
+                           const BGlobalDesc b_k_n_global_desc,
+                           const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                           const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 // pass tensor descriptor by __CONSTANT__ void pointer
 // __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
 // point to the non-modifiable parameter address space, so the compiler can enable the
 // corresponding optimization
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const void __CONSTANT__* p_b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const void __CONSTANT__* p_c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                           const FloatB* __restrict__ p_b_global,
+                           FloatC* __restrict__ p_c_global,
+                           const void __CONSTANT__* p_a_k_m_global_desc,
+                           const void __CONSTANT__* p_b_k_n_global_desc,
+                           const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                           const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast the __CONSTANT__ void* to a plain void*
     // second cast the void* to Desc*
@@ -76,15 +84,15 @@ __global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_d
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #endif
@@ -161,22 +169,29 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        FloatAB* __restrict__ p_shared_block,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               FloatAB* __restrict__ p_shared_block,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
         const auto N = b_k_n_global_desc.GetLength(I1);
@@ -226,8 +241,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_M,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
@@ -255,8 +268,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             1,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_N,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             BThreadTransferSrcResetCoordinateAfterRun,
@@ -331,8 +342,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf =
-            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),
@@ -353,25 +364,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        FloatAB* p_a_block_even = p_a_block_double;
-        FloatAB* p_b_block_even = p_b_block_double;
-
-        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
-
-        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
-        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
@@ -394,16 +403,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
@@ -417,16 +426,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
@@ -445,15 +454,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

             __syncthreads();
@@ -488,8 +497,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>{
@@ -502,32 +509,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 c_m0_m1_n0_n1_global_desc,
-                p_c_global,
+                c_global_buf,
                 c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

         __shared__ FloatAB p_shared_block[shared_block_size];

-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
-            p_b_global,
-            c_m0_m1_n0_n1_global_desc,
-            p_c_global,
+        Run(p_a_global,
+            p_b_global,
+            p_c_global,
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
            c_block_cluster_desc,
            p_shared_block,
            integral_constant<bool, HasMainKBlockLoop>{},
......
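To connect the two halves of the CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER path above: the host copies each descriptor into a device buffer and passes only its address as a `const void __CONSTANT__*`, and the kernel casts it back to the concrete descriptor type before use. A condensed, hedged sketch (the kernel-side cast for the A descriptor is not visible in the collapsed hunk and presumably mirrors the c_block_cluster_desc cast that is shown):

    // host side (from the launcher earlier in this diff): pass the device copy of the descriptor
    //     (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer()
    //
    // kernel side (inside kernel_dynamic_gemm_v1): recover the typed descriptor
    const auto a_k_m_global_desc =
        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);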
@@ -84,6 +84,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_e_k_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
+
         constexpr auto E = EPerBlock * 3 * 3;
         // const auto E = a_e_k_global_desc.GetLength(I0);
@@ -192,8 +199,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_K,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
@@ -216,19 +221,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BBlockTransferSrcAccessOrder,
             BBlockTransferSrcVectorDim,
             BBlockTransferSrcScalarPerVector,
-            AddressSpace::Global,
-            AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
             true>(b_e_n_ho_wo_global_desc,
                   make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

-        FloatAB* p_a_block = p_shared_block;
-
-        auto a_block_buf = make_dynamic_buffer(p_a_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block,
+                                                                  a_e_k_desc.GetElementSpaceSize());

         // register allocation for output
-        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            c_thread_buf;

         // initialize output thread tensor
         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
@@ -250,21 +253,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BGlobalMoveSliceWindowIteratorHacks{};

         // double register buffer for b
-        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
-            b_thread_odd_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            b_thread_even_buf, b_thread_odd_buf;

         // LDS double buffer: preload data
         {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_even_buf,
                                       b_e_n_ho_wo_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
+            a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
         }

         __syncthreads();
@@ -282,7 +285,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                           b_thread_slice_copy_step);

                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
@@ -298,7 +301,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                           b_thread_slice_copy_step);

                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_even_buf,
@@ -321,7 +324,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                           b_thread_slice_copy_step);

                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
@@ -358,8 +361,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>(
@@ -370,7 +371,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 c_k_n_ho_wo_global_desc,
-                p_c_global,
+                c_global_buf,
                 c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
     }
......
@@ -323,7 +323,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 2)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         return __llvm_amdgcn_raw_buffer_load_i8x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
@@ -335,7 +335,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 4)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         return __llvm_amdgcn_raw_buffer_load_i8x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
@@ -347,7 +347,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 8)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         vector_type<int8_t, 8> tmp;

         tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
@@ -369,7 +369,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 16)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         vector_type<int8_t, 16> tmp;

         tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
@@ -483,7 +483,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
     }
     else if constexpr(N == 2)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
@@ -499,7 +499,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
     }
     else if constexpr(N == 4)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
......
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<T, N>{};
}
template <typename T>
struct DynamicBuffer
{
using type = T;
T* p_data_;
__host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr const auto Get(index_t i) const
{
return *reinterpret_cast<const X*>(&p_data_[i]);
}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, const X& x)
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
return DynamicBuffer<T>{p};
}
} // namespace ck
#endif
@@ -8,7 +8,6 @@
 #include "container_element_picker.hpp"
 #include "data_type.hpp"
 #include "float_type.hpp"
-#include "buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
@@ -25,6 +24,8 @@
 #include "type.hpp"
 #include "utility.hpp"
 #include "magic_division.hpp"
+#include "static_buffer.hpp"
+#include "dynamic_buffer.hpp"

 #if CK_USE_AMD_INLINE_ASM
 #include "amd_inline_asm.hpp"
......
@@ -143,8 +143,13 @@
 #endif

 // workaround for compiler crash when using buffer load/store for i8
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX
-#define CK_WORKAROUND_SWDEV_XXXXXX 1
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
+#endif
+
+// workaround for compiler generating inefficient ds_write ISA for i8
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif

 namespace ck {
@@ -154,6 +159,7 @@ enum AddressSpace
     Generic,
     Global,
     Lds,
+    Sgpr,
     Vgpr
 };
......
#ifndef CK_DYNAMIC_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP
#include "amd_buffer_addressing_v2.hpp"
namespace ck {
template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer
{
using type = T;
T* p_data_;
ElementSpaceSize element_space_size_;
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
}
__host__ __device__ static constexpr AddressSpace GetAddressSpace()
{
return BufferAddressSpace;
}
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector =
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
p_data_, i, is_valid_offset, element_space_size_);
#else
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
#endif
}
else
{
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
}
}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector =
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(GetAddressSpace() == AddressSpace::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
x, p_data_, i, is_valid_offset, element_space_size_);
#else
if(is_valid_offset)
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
#endif
}
else if constexpr(GetAddressSpace() == AddressSpace::Lds)
{
if(is_valid_offset)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*reinterpret_cast<X*>(&p_data_[i]) = x;
#else
// HACK: the compiler lowers the IR "store<i8, 16> address_space(3)" into inefficient
// ISA, so try to make it emit the IR "store<i32, 4>" instead, which is lowered to
// ds_write_b128
// TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
int8_t>::value)
{
static_assert(
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
"wrong! not implemented for this combination, please add implementation");
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{
// HACK: casting the pointer to x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int32_t*>(&p_data_[i]) =
*reinterpret_cast<const int32_t*>(&x);
}
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{
// HACK: casting the pointer to x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int32x2_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x2_t*>(&x);
}
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{
// HACK: casting the pointer to x is bad
// TODO: remove this after compiler fix
*reinterpret_cast<int32x4_t*>(&p_data_[i]) =
*reinterpret_cast<const int32x4_t*>(&x);
}
}
else
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
#endif
}
}
else
{
if(is_valid_offset)
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
typename T,
typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
}
} // namespace ck
#endif
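A short, hypothetical usage sketch for the new DynamicBuffer: Get/Set move a whole X-sized vector at an element offset and take an is_valid_offset flag, so out-of-range accesses can be skipped on stores (or zero-filled on loads); on global memory this is what routes through amd_buffer_load_v2 / amd_buffer_store_v2 when CK_USE_AMD_BUFFER_ADDRESSING is enabled. The float4_t vector typedef is assumed to come from the existing data-type headers; the function itself is not part of the commit.

    // hypothetical usage sketch, not part of the commit
    __device__ void dynamic_buffer_example(float* p, index_t n)
    {
        auto buf = make_dynamic_buffer<AddressSpace::Global>(p, n);

        // scalar access through the raw-pointer-style operators
        float x = buf[0];
        buf(1)  = x;

        // vectorized access: read a float4-sized chunk at element 4, write one at element 8,
        // each guarded by a validity flag
        const float4_t v = buf.Get<float4_t>(4, (4 + 4) <= n);
        buf.Set<float4_t>(8, (8 + 4) <= n, v);
    }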
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <AddressSpace BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr AddressSpace GetAddressSpace()
{
return BufferAddressSpace;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<BufferAddressSpace, T, N>{};
}
} // namespace ck
#endif
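And the StaticBuffer counterpart, for completeness (again a hypothetical sketch): the size must be a compile-time Number, which is how the per-thread A/B/C register buffers in the blockwise and gridwise GEMMs above are created.

    // hypothetical usage sketch, not part of the commit
    constexpr auto thread_buf_size = Number<8>{};
    auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, float>(thread_buf_size);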
@@ -63,12 +63,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
 #if 0
     // run-time variables
-    const auto in_n_c_hi_wi_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
-    const auto wei_k_c_y_x_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
-    const auto out_n_k_ho_wo_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
+    const auto in_n_c0_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
+    const auto wei_k_c0_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
+    const auto out_n_k0_ho_wo_k1_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));

     const auto conv_strides   = to_multi_index(ConvStrides{});
     const auto conv_dilations = to_multi_index(ConvDilations{});
......
@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
     using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
@@ -62,8 +62,8 @@ int main(int argc, char* argv[])
     using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
     constexpr index_t N = 1;
     constexpr index_t C = 16;
@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
     constexpr index_t N  = 1;
     constexpr index_t C  = 16;
     constexpr index_t HI = 540;
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
     print_array("ConvStrides", to_multi_index(ConvStrides{}));
     print_array("ConvDilations", to_multi_index(ConvDilations{}));

-#if 0
+#if 1
     using in_data_t = float;
     constexpr index_t in_vector_size = 1;
     using acc_data_t = float;
@@ -740,7 +740,7 @@ int main(int argc, char* argv[])
                                                                           LeftPads{},
                                                                           RightPads{},
                                                                           nrepeat);
-#elif 0
+#elif 1
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                          in_vector_size,
                                                                          acc_data_t,
......
@@ -10,13 +10,14 @@ cmake
 -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
 -D CMAKE_BUILD_TYPE=Release \
 -D DEVICE_BACKEND="AMD" \
 -D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \
 -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
 -D CMAKE_PREFIX_PATH="/opt/rocm" \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 ${MY_PROJECT_SOURCE}

-#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -save-temps=$CWD" \
+#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \
+#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0 -mllvm -print-before=amdgpu-codegenprepare -mllvm -print-module-scope" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -gline-tables-only -save-temps=$CWD" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0" \
 #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \
......