Commit 61ac0866 authored by Chao Liu

tune for 1x1

parent abf75ac0
@@ -391,7 +391,7 @@ int main()
 constexpr unsigned HPad = 0;
 constexpr unsigned WPad = 0;
-#elif 1
+#elif 0
 // 3x3, 34x34
 constexpr unsigned N = 64;
 constexpr unsigned C = 256;
@@ -490,7 +490,7 @@ int main()
 constexpr unsigned HPad = 1;
 constexpr unsigned WPad = 1;
-#elif 0
+#elif 1
 // 1x1 filter, 28x28 image
 constexpr unsigned N = 16;
 constexpr unsigned C = 256;
@@ -582,7 +582,7 @@ int main()
 wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #endif
-unsigned nrepeat = 100;
+unsigned nrepeat = 200;
 #if 1
 #if 0
@@ -593,11 +593,11 @@ int main()
 device_implicit_gemm_convolution_1_nchw_kcsr
 #elif 0
 device_implicit_gemm_convolution_1_nchw_srck_nkhw
-#elif 1
+#elif 0
 device_implicit_gemm_convolution_1_chwn_csrk_khwn
 #elif 0
 device_implicit_gemm_convolution_2_cnhw_srck_knhw
-#elif 0
+#elif 1
 device_implicit_gemm_convolution_2_cnhw_csrk_knhw
 #endif
 (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
...
@@ -67,7 +67,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
-#if 1
+#if 0
 // 3x3, 34x34
 constexpr unsigned BPerBlock = 128;
 constexpr unsigned KPerBlock = 64;
@@ -78,9 +78,9 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 constexpr unsigned GemmMPerThreadSubC = 4;
 constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 8;
+constexpr unsigned GemmMLevel0Cluster = 4;
 constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 1;
+constexpr unsigned GemmMLevel1Cluster = 2;
 constexpr unsigned GemmNLevel1Cluster = 8;
 constexpr unsigned GemmKPerThreadLoop = 1;
@@ -98,7 +98,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 constexpr unsigned BlockSize = 128;
 #elif 0
-// 1x1, 28x28
+// 1x1, 28x28, 64 threads
 constexpr unsigned BPerBlock = 64;
 constexpr unsigned KPerBlock = 64;
 constexpr unsigned CPerBlock = 8;
@@ -108,9 +108,9 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 constexpr unsigned GemmMPerThreadSubC = 4;
 constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 8;
+constexpr unsigned GemmMLevel0Cluster = 4;
 constexpr unsigned GemmNLevel0Cluster = 2;
-constexpr unsigned GemmMLevel1Cluster = 1;
+constexpr unsigned GemmMLevel1Cluster = 2;
 constexpr unsigned GemmNLevel1Cluster = 4;
 constexpr unsigned GemmKPerThreadLoop = 1;
@@ -128,7 +128,37 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 constexpr unsigned BlockSize = 64;
 #elif 1
-// 1x1, 28x28 try
+// 1x1, 28x28, 128 threads
+constexpr unsigned BPerBlock = 64;
+constexpr unsigned KPerBlock = 128;
+constexpr unsigned CPerBlock = 8;
+constexpr unsigned BPerThread = 8;
+constexpr unsigned KPerThread = 8;
+constexpr unsigned GemmMPerThreadSubC = 4;
+constexpr unsigned GemmNPerThreadSubC = 4;
+constexpr unsigned GemmMLevel0Cluster = 4;
+constexpr unsigned GemmNLevel0Cluster = 2;
+constexpr unsigned GemmMLevel1Cluster = 4;
+constexpr unsigned GemmNLevel1Cluster = 4;
+constexpr unsigned GemmKPerThreadLoop = 1;
+constexpr unsigned GemmThreadPerColumnPerCluster = 8;
+constexpr unsigned GemmThreadPerRowPerCluster = 8;
+constexpr unsigned InBlockCopyThreadPerDim0 = 4;
+constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
+constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+constexpr unsigned InBlockCopyDataPerRead = 4;
+constexpr unsigned WeiBlockCopyDataPerRead = 4;
+constexpr unsigned BlockSize = 128;
+#elif 1
+// 1x1, 28x28, 256 thread
 constexpr unsigned BPerBlock = 128;
 constexpr unsigned KPerBlock = 128;
 constexpr unsigned CPerBlock = 8;
@@ -138,9 +168,9 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
 constexpr unsigned GemmMPerThreadSubC = 4;
 constexpr unsigned GemmNPerThreadSubC = 4;
-constexpr unsigned GemmMLevel0Cluster = 8;
+constexpr unsigned GemmMLevel0Cluster = 4;
 constexpr unsigned GemmNLevel0Cluster = 4;
-constexpr unsigned GemmMLevel1Cluster = 2;
+constexpr unsigned GemmMLevel1Cluster = 4;
 constexpr unsigned GemmNLevel1Cluster = 4;
 constexpr unsigned GemmKPerThreadLoop = 1;
...
@@ -70,7 +70,7 @@ template <unsigned BlockSize,
 class F>
 __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
 SrcDesc,
-Float* const __restrict__ p_src,
+const Float* __restrict__ p_src,
 DstDesc,
 Float* __restrict__ p_dst,
 SrcOpLengths,
@@ -149,7 +149,7 @@ template <unsigned BlockSize,
 class DstFromSrcReorder>
 __device__ void
 blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
-Float* const __restrict__ p_src,
+const Float* __restrict__ p_src,
 DstDesc,
 Float* __restrict__ p_dst,
 SrcOpLengths,
@@ -164,7 +164,7 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
 template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
 struct Blockwise2dTensorCopy1
 {
-__device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
+__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
 {
 constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
@@ -199,7 +199,7 @@ struct Blockwise2dTensorCopy2
 mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
 }
-__device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
+__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
 {
 static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
@@ -253,7 +253,7 @@ struct Blockwise2dTensorCopy2
 const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
 *(reinterpret_cast<Float4*>(p_dst + dindex)) =
-*(reinterpret_cast<Float4*>(p_src + sindex));
+*(reinterpret_cast<const Float4*>(p_src + sindex));
 }
 // v2
@@ -266,7 +266,7 @@ struct Blockwise2dTensorCopy2
 const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
 *(reinterpret_cast<Float2*>(p_dst + dindex)) =
-*(reinterpret_cast<Float2*>(p_src + sindex));
+*(reinterpret_cast<const Float2*>(p_src + sindex));
 }
 // v1
@@ -314,7 +314,7 @@ struct Blockwise2dTensorCopy2
 const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
 *(reinterpret_cast<Float4*>(p_dst + dindex)) =
-*(reinterpret_cast<Float4*>(p_src + sindex));
+*(reinterpret_cast<const Float4*>(p_src + sindex));
 }
 // v2
@@ -327,7 +327,7 @@ struct Blockwise2dTensorCopy2
 const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
 *(reinterpret_cast<Float2*>(p_dst + dindex)) =
-*(reinterpret_cast<Float2*>(p_src + sindex));
+*(reinterpret_cast<const Float2*>(p_src + sindex));
 }
 // v1
@@ -422,7 +422,7 @@ struct Blockwise2dTensorCopy3
 mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
 }
-__device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
+__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
 {
 static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
@@ -463,13 +463,13 @@ struct Blockwise2dTensorCopy3
 else if(DataPerRead == 2)
 {
 *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
-*(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset +
+*(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
 iloop * src_loop_stride));
 }
 else if(DataPerRead == 4)
 {
 *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
-*(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset +
+*(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
 iloop * src_loop_stride));
 }
 else
@@ -495,14 +495,14 @@ struct Blockwise2dTensorCopy3
 {
 *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
 nloop_d0 * dst_loop_stride)) =
-*(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset +
+*(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
 nloop_d0 * src_loop_stride));
 }
 else if(DataPerRead == 4)
 {
 *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
 nloop_d0 * dst_loop_stride)) =
-*(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset +
+*(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
 nloop_d0 * src_loop_stride));
 }
 else
...
@@ -29,8 +29,8 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
 {
-const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
-const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
+constexpr auto a_block_mtx = BlockMatrixA{};
+constexpr auto b_block_mtx = BlockMatrixB{};
 const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
@@ -66,8 +66,8 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 if(TransA && (!TransB) && (!TransC))
 {
-const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
-const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
+constexpr auto a_block_mtx = BlockMatrixA{};
+constexpr auto b_block_mtx = BlockMatrixB{};
 static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
 "wrong! k dimension not consistent!");
@@ -75,7 +75,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 constexpr unsigned MPerBlock = a_block_mtx.NCol();
 constexpr unsigned NPerBlock = b_block_mtx.NCol();
-const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
+constexpr auto c_thread_mtx = ThreadMatrixC{};
 // divide thread work
 constexpr unsigned MPerThread = c_thread_mtx.NRow();
@@ -117,9 +117,9 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 }
 template <class FloatA, class FloatB, class FloatC, class Accumulator>
-__device__ void Run(FloatA* const p_a_block,
-FloatB* const p_b_block,
-FloatC* p_c_thread,
+__device__ void Run(const FloatA* __restrict__ p_a_block,
+const FloatB* __restrict__ p_b_block,
+FloatC* __restrict__ p_c_thread,
 Accumulator f_accum) const
 {
 if(TransA && (!TransB) && (!TransC))
@@ -243,8 +243,8 @@ struct BlockwiseGemmBlockABlockBThreadC
 __device__ BlockwiseGemmBlockABlockBThreadC()
 {
-const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
-const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
+constexpr auto a_block_mtx = BlockMatrixA{};
+constexpr auto b_block_mtx = BlockMatrixB{};
 const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
@@ -278,8 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadC
 if(TransA && (!TransB) && (!TransC))
 {
-constexpr auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
-constexpr auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
+constexpr auto a_block_mtx = BlockMatrixA{};
+constexpr auto b_block_mtx = BlockMatrixB{};
 static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
 "wrong! k dimension not consistent!");
@@ -287,7 +287,7 @@ struct BlockwiseGemmBlockABlockBThreadC
 constexpr unsigned MPerBlock = a_block_mtx.NCol();
 constexpr unsigned NPerBlock = b_block_mtx.NCol();
-constexpr auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
+constexpr auto c_thread_mtx = ThreadMatrixC{};
 // divide thread work
 constexpr unsigned MPerThread = c_thread_mtx.NRow();
@@ -367,9 +367,9 @@ struct BlockwiseGemmBlockABlockBThreadC
 }
 template <class FloatA, class FloatB, class FloatC, class Accumulator>
-__device__ void Run(FloatA* const p_a_block,
-FloatB* const p_b_block,
-FloatC* p_c_thread,
+__device__ void Run(const FloatA* __restrict__ p_a_block,
+const FloatB* __restrict__ p_b_block,
+FloatC* __restrict__ p_c_thread,
 Accumulator f_accum) const
 {
 if(TransA && (!TransB) && (!TransC))
@@ -459,9 +459,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
-const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
-const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
-const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
+constexpr auto a_block_mtx = BlockMatrixA{};
+constexpr auto b_block_mtx = BlockMatrixB{};
+constexpr auto c_thread_mtx = ThreadMatrixC{};
 static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
 "wrong! K dimension not consistent\n");
@@ -529,7 +529,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
 unsigned n_in_c)
 {
-const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
+constexpr auto c_thread_mtx = ThreadMatrixC{};
 constexpr unsigned MPerThread = c_thread_mtx.NRow();
 constexpr unsigned NPerThread = c_thread_mtx.NCol();
@@ -551,17 +551,17 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 }
 template <class FloatA, class FloatB, class FloatC, class Accumulator>
-__device__ void Run(FloatA* const p_a_block,
-FloatB* const p_b_block,
-FloatC* p_c_thread,
+__device__ void Run(const FloatA* __restrict__ p_a_block,
+const FloatB* __restrict__ p_b_block,
+FloatC* __restrict__ p_c_thread,
 Accumulator f_accum) const
 {
 constexpr auto True = integral_constant<bool, true>{};
 constexpr auto False = integral_constant<bool, false>{};
-const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
-const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
-const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
+constexpr auto a_block_mtx = BlockMatrixA{};
+constexpr auto b_block_mtx = BlockMatrixB{};
+constexpr auto c_thread_mtx = ThreadMatrixC{};
 constexpr unsigned M = a_block_mtx.NCol();
 constexpr unsigned N = b_block_mtx.NCol();
@@ -571,22 +571,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 constexpr unsigned NPerThread = c_thread_mtx.NCol();
 // thread A, B for GEMM
-const auto a_thread_mtx = make_ConstantMatrixDescriptor(
-Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
-const auto b_thread_mtx = make_ConstantMatrixDescriptor(
-Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile
+constexpr auto a_thread_mtx =
+make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+constexpr auto b_thread_mtx =
+make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
 // thread A-sub, B-sub for copy
-const auto a_thread_sub_mtx =
-make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
-Number<MPerThreadSubC>{},
-Number<MPerThread>{}); // constexpr doesn't compile
+constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
+Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
-const auto b_thread_sub_mtx =
-make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
-Number<NPerThreadSubC>{},
-Number<NPerThread>{}); // constexpr doesn't compile
+constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
+Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
 FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
 FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
@@ -606,12 +602,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 // copy A-sub to form A
 for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
 {
-threadwise_matrix_copy(a_block_mtx,
-p_a_block + mMyThreadOffsetA +
-k_begin * a_block_mtx.RowStride() +
-m_repeat * MPerLevel1Cluster,
-a_thread_sub_mtx,
-p_a_thread + m_repeat * MPerThreadSubC,
+threadwise_matrix_copy(
+a_block_mtx,
+p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
+mMyThreadOffsetA,
+a_thread_mtx,
+p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
 a_thread_sub_mtx.GetLengths());
 }
@@ -619,12 +615,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 // copy B-sub to form B
 for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
 {
-threadwise_matrix_copy(b_block_mtx,
-p_b_block + mMyThreadOffsetB +
-k_begin * b_block_mtx.RowStride() +
-n_repeat * NPerLevel1Cluster,
-b_thread_sub_mtx,
-p_b_thread + n_repeat * NPerThreadSubC,
+threadwise_matrix_copy(
+b_block_mtx,
+p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
+mMyThreadOffsetB,
+b_thread_mtx,
+p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
 b_thread_sub_mtx.GetLengths());
 }
@@ -778,4 +774,144 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 f_accum);
 }
 }
+template <class FloatA, class FloatB, class FloatC, class Accumulator>
+__device__ void Run_v2(const FloatA* __restrict__ p_a_block,
+const FloatB* __restrict__ p_b_block,
+FloatC* __restrict__ p_c_thread,
+Accumulator f_accum) const
+{
+constexpr auto True = integral_constant<bool, true>{};
+constexpr auto False = integral_constant<bool, false>{};
+constexpr auto a_block_mtx = BlockMatrixA{};
+constexpr auto b_block_mtx = BlockMatrixB{};
+constexpr auto c_thread_mtx = ThreadMatrixC{};
+constexpr unsigned M = a_block_mtx.NCol();
+constexpr unsigned N = b_block_mtx.NCol();
+constexpr unsigned K = a_block_mtx.NRow();
+constexpr unsigned MPerThread = c_thread_mtx.NRow();
+constexpr unsigned NPerThread = c_thread_mtx.NCol();
+// thread A-sub, B-sub, C-sub
+constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
+Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
+constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
+Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
+Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+// thread A, B
+constexpr auto a_thread_mtx =
+make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+constexpr auto b_thread_mtx =
+make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+#pragma unroll
+// loop over k
+for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
+{
+// C-sub(s) in first row-wise subblock of C
+{
+// copy first A-sub
+threadwise_matrix_copy(a_block_mtx,
+p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) +
+mMyThreadOffsetA,
+a_thread_mtx,
+p_a_thread,
+a_thread_sub_mtx.GetLengths());
+// copy first B-sub
+threadwise_matrix_copy(b_block_mtx,
+p_b_block + b_block_mtx.Get1dIndex(k_begin, 0) +
+mMyThreadOffsetB,
+b_thread_mtx,
+p_b_thread,
+b_thread_sub_mtx.GetLengths());
+// do first sub GEMM
+threadwise_gemm(a_thread_sub_mtx,
+True,
+p_a_thread,
+b_thread_sub_mtx,
+False,
+p_b_thread,
+c_thread_sub_mtx,
+False,
+p_c_thread,
+f_accum);
+#pragma unroll
+// copy next B-sub, and do GEMM
+for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
+{
+threadwise_matrix_copy(
+b_block_mtx,
+p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
+mMyThreadOffsetB,
+b_thread_mtx,
+p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
+b_thread_sub_mtx.GetLengths());
+threadwise_gemm(
+a_thread_sub_mtx,
+True,
+p_a_thread,
+b_thread_sub_mtx,
+False,
+p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
+c_thread_sub_mtx,
+False,
+p_c_thread + c_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
+f_accum);
+}
+#pragma unroll
+// loop over rest of row-wise subblock
+// all B-sub(s) has been copied, so only A-sub(s) need to be copied
+for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
+{
+// copy a A-sub
+threadwise_matrix_copy(
+a_block_mtx,
+p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
+mMyThreadOffsetA,
+a_thread_mtx,
+p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
+a_thread_sub_mtx.GetLengths());
+// do some GEMMs
+for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+{
+threadwise_gemm(
+a_thread_sub_mtx,
+True,
+p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
+b_thread_sub_mtx,
+False,
+p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
+c_thread_sub_mtx,
+False,
+p_c_thread + c_thread_mtx.Get1dIndex(m_repeat * MPerThreadSubC,
+n_repeat * NPerThreadSubC),
+f_accum);
+}
+}
+}
+}
+}
 };
@@ -201,11 +201,11 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
 // set threadwise output tensor to 0
 threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
-Float* p_in_global_block_begin =
+const Float* p_in_global_block_begin =
 p_in_global + in_chwn_global_desc.Get1dIndex(
 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
-Float* p_wei_global_block_begin =
+const Float* p_wei_global_block_begin =
 p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
 for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
@@ -213,15 +213,11 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
 p_wei_global_block_begin += CPerBlock * wei_csrk_global_desc.GetStride(I0),
 __syncthreads())
 {
-#if 1
 // input: global mem to LDS,
 blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
-#endif
-#if 1
 // weight: global mem to LDS,
 blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
-#endif
 __syncthreads();
...
@@ -36,11 +36,11 @@ template <unsigned GridSize,
 unsigned WeiBlockCopyDataPerRead>
 __global__ void
 gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
-Float* const __restrict__ p_in_global,
+const Float* const __restrict__ p_in_global,
 WeiGlobalDesc,
-Float* const __restrict__ p_wei_global,
+const Float* const __restrict__ p_wei_global,
 OutGlobalDesc,
-Float* __restrict__ p_out_global)
+Float* const __restrict__ p_out_global)
 {
 constexpr auto I0 = Number<0>{};
 constexpr auto I1 = Number<1>{};
@@ -228,10 +228,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
 // set threadwise output tensor to 0
 threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
-Float* p_in_global_block_offset =
+const Float* p_in_global_block_offset =
 p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
-Float* p_wei_global_block_offset =
+const Float* p_wei_global_block_offset =
 p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
 for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
@@ -256,7 +256,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
 #if 1
 blockwise_gemm.Run
-#else
+#elif 0
+blockwise_gemm.Run_v2
+#elif 0
 blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
 (p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
...
@@ -123,7 +123,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
 decltype(in_cb_global_desc),
 decltype(in_cb_block_desc),
 decltype(in_cb_block_desc.GetLengths())>{};
-#elif 1
+#elif 0
 const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
 Float,
 decltype(in_cb_global_desc),
@@ -149,7 +149,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
 decltype(wei_ek_global_desc),
 decltype(wei_ek_block_desc),
 decltype(wei_ek_block_desc.GetLengths())>{};
-#elif 1
+#elif 0
 const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
 Float,
 decltype(wei_ek_global_desc),
@@ -226,10 +226,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
 __shared__ Float p_in_block_1[max_align * ((in_block_size + max_align - 1) / max_align)];
 __shared__ Float p_wei_block_1[max_align * ((wei_block_size + max_align - 1) / max_align)];
-Float* p_in_global_block_offset =
+const Float* p_in_global_block_offset =
 p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
-Float* p_wei_global_block_offset =
+const Float* p_wei_global_block_offset =
 p_wei_global + wei_csrk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
 // preload data into LDS
@@ -272,7 +272,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
 for(unsigned r = 0; r < R; ++r)
 {
 auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
-#if 0
+#if 1
 blockwise_gemm.Run
 #else
 blockwise_gemm.Run_RegisterDoubleBuffer
...
 #pragma once
 template <class Float, class SrcMatrix, class DstMatrix, unsigned NRow, unsigned NCol>
-__device__ void
-threadwise_matrix_copy(SrcMatrix, Float* const p_src, DstMatrix, Float* p_dst, Sequence<NRow, NCol>)
+__device__ void threadwise_matrix_copy(SrcMatrix,
+const Float* __restrict__ p_src,
+DstMatrix,
+Float* __restrict__ p_dst,
+Sequence<NRow, NCol>)
 {
-const auto src_mtx = SrcMatrix{}; // constexpr doesn't compile
-const auto dst_mtx = DstMatrix{}; // constexpr doesn't compile
+constexpr auto src_mtx = SrcMatrix{};
+constexpr auto dst_mtx = DstMatrix{};
 for(unsigned i = 0; i < NRow; ++i)
 {
@@ -31,30 +34,30 @@ template <class MatrixA,
 class Accumulator>
 __device__ void threadwise_gemm(MatrixA,
 integral_constant<bool, TransA>,
-FloatA* const p_a_thread,
+const FloatA* __restrict__ p_a_thread,
 MatrixB,
 integral_constant<bool, TransB>,
-FloatB* const p_b_thread,
+const FloatB* __restrict__ p_b_thread,
 MatrixC,
 integral_constant<bool, TransC>,
-FloatC* p_c_thread,
+FloatC* __restrict__ p_c_thread,
 Accumulator f_accum)
 {
 if(TransA && (!TransB) && (!TransC))
 {
-const auto a_mtx = MatrixA{}; // constexpr doesn't compile
-const auto b_mtx = MatrixB{}; // constexpr doesn't compile
-const auto c_mtx = MatrixC{}; // constexpr doesn't compile
+constexpr auto a_mtx = MatrixA{};
+constexpr auto b_mtx = MatrixB{};
+constexpr auto c_mtx = MatrixC{};
 constexpr unsigned M = c_mtx.NRow();
 constexpr unsigned N = c_mtx.NCol();
 constexpr unsigned K = a_mtx.NRow(); // A is transposed
+for(unsigned k = 0; k < K; ++k)
+{
 for(unsigned i = 0; i < M; ++i)
 {
 for(unsigned j = 0; j < N; ++j)
 {
-for(unsigned k = 0; k < K; ++k)
-{
 const unsigned aindex = a_mtx.Get1dIndex(k, i); // A is transposed
 const unsigned bindex = b_mtx.Get1dIndex(k, j);
...