"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "22114959da03cab6aedfdf43bdb88e8a9bba75a1"
Commit 17f3d2d4 authored by Chao Liu

refactor ConstantTensorDescriptor and functional

parent a2cf803c
@@ -11,6 +11,7 @@
 #include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
 //#include "device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
+#include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp"
 //#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
 #include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
@@ -48,13 +49,10 @@ struct GeneratorTensor_3
 #if 0
         auto f_acc = std::plus<index_t>{};
 #else
-        auto f_acc = [](auto a, auto b){ return 10*a + b;};
+        auto f_acc = [](auto a, auto b) { return 10 * a + b; };
 #endif

-        return std::accumulate(dims.begin(),
-                               dims.end(),
-                               index_t(0),
-                               f_acc);
+        return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
     }
 };
@@ -505,7 +503,7 @@ int main(int argc, char* argv[])
     constexpr index_t C = 256;
     constexpr index_t HI = 28;
     constexpr index_t WI = 28;
-    constexpr index_t K = 512;
+    constexpr index_t K = 128;
     constexpr index_t Y = 3;
     constexpr index_t X = 3;
@@ -666,6 +664,8 @@ int main(int argc, char* argv[])
     device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
 #elif 1
     device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
+#elif 0
+    device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
 #elif 0
     device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
 #endif
@@ -14,5 +14,7 @@ struct Array
     {
     }

-    __host__ __device__ TData operator[](index_t i) const { return mData[i]; }
+    __host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }
+    __host__ __device__ TData& operator[](index_t i) { return mData[i]; }
 };
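Note: Array's element accessor now returns references instead of a value: a const overload for reads plus a new non-const overload for writes. The write path is what `GetMultiIndex` below relies on when it assigns `multi_id[idim]`. A minimal standalone sketch of the resulting interface (the `TData`/`NSize` members and HIP macros are assumed from the surrounding struct):

```cpp
// Sketch only; mirrors the two overloads added in this commit.
template <class TData, index_t NSize>
struct Array
{
    TData mData[NSize];

    // read access on const objects
    __host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }

    // write access, e.g. multi_id[idim] = ...
    __host__ __device__ TData& operator[](index_t i) { return mData[i]; }
};
```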
@@ -115,46 +115,27 @@ struct ConstantTensorDescriptor
         static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
     }

-    __host__ __device__ constexpr index_t GetDimension() const { return nDim; }
+    __host__ __device__ static constexpr index_t GetDimension() { return nDim; }

-    __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
+    __host__ __device__ static constexpr Lengths GetLengths() { return Lengths{}; }

-    __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
+    __host__ __device__ static constexpr Strides GetStrides() { return Strides{}; }

     template <index_t I>
-    __host__ __device__ constexpr index_t GetLength(Number<I>) const
+    __host__ __device__ static constexpr index_t GetLength(Number<I>)
     {
         return Lengths{}.Get(Number<I>{});
     }

     template <index_t I>
-    __host__ __device__ constexpr index_t GetStride(Number<I>) const
+    __host__ __device__ static constexpr index_t GetStride(Number<I>)
     {
         return Strides{}.Get(Number<I>{});
     }

-    // c++14 doesn't support constexpr lambdas, has to use this trick instead
-    struct GetElementSize_f
-    {
-        template <class IDim>
-        __host__ __device__ constexpr index_t operator()(IDim idim) const
-        {
-            return Type{}.GetLength(idim);
-        }
-    };
-
-    __host__ __device__ constexpr index_t GetElementSize() const
-    {
-        // c++14 doesn't support constexpr lambdas, has to use this trick instead
-        struct multiply
-        {
-            __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
-            {
-                return a * b;
-            }
-        };
-
-        return static_const_reduce_n<nDim>{}(GetElementSize_f{}, multiply{});
-    }
+    __host__ __device__ static constexpr index_t GetElementSize()
+    {
+        return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
+    }

     // c++14 doesn't support constexpr lambdas, has to use this trick instead
@@ -168,25 +149,16 @@ struct ConstantTensorDescriptor
     };

     template <class Align = Number<1>>
-    __host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
+    __host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
     {
-        // c++14 doesn't support constexpr lambdas, has to use this trick instead
-        struct add
-        {
-            __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
-            {
-                return a + b;
-            }
-        };
-
         index_t element_space_unaligned =
-            static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + 1;
+            static_const_reduce_n<nDim>{}(GetElementSpace_f{}, mod_conv::plus<index_t>{}) + 1;

         return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
     }

     template <class... Is>
-    __host__ __device__ index_t Get1dIndex(Is... is) const
+    __host__ __device__ static index_t Get1dIndex(Is... is)
     {
         static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
@@ -194,7 +166,7 @@ struct ConstantTensorDescriptor
         index_t id = 0;

-        static_loop_n<nDim>{}([&](auto IDim) {
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             constexpr index_t idim = IDim.Get();
 #if DEVICE_BACKEND_HIP
             id += __mul24(multi_id[idim], GetStride(IDim));
@@ -206,16 +178,25 @@ struct ConstantTensorDescriptor
         return id;
     }
-    __host__ __device__ constexpr auto Condense() const
+    __host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
     {
-        constexpr auto default_strides = calculate_default_strides(Lengths{});
-        return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
+        Array<index_t, nDim> multi_id;
+
+        static_for<0, nDim - 1, 1>{}([&](auto IDim) {
+            constexpr index_t idim = IDim.Get();
+            multi_id[idim] = id / GetStride(IDim);
+            id -= multi_id[idim] * GetStride(IDim);
+        });
+
+        multi_id[nDim - 1] = id / GetStride(Number<nDim - 1>{});
+
+        return multi_id;
     }

-    template <index_t IDim, index_t NVector>
-    __host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
+    __host__ __device__ static constexpr auto Condense()
     {
-        assert(false); // not implemented
+        constexpr auto default_strides = calculate_default_strides(Lengths{});
+        return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
     }
 };
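The new `GetMultiIndex` is the inverse of `Get1dIndex`: it peels one coordinate off per dimension by dividing the linear id by that dimension's stride and carrying the remainder to the next dimension. A self-contained sketch of the same arithmetic for a hypothetical packed 3-d layout (plain C++, library types replaced by locals):

```cpp
#include <cstdio>

using index_t = int;

int main()
{
    // Assumed example layout: packed row-major lengths {4, 3, 5},
    // so the strides are {3*5, 5, 1} = {15, 5, 1}.
    const index_t strides[3] = {15, 5, 1};

    index_t id = 37; // linear offset to invert
    index_t multi_id[3];

    // Same scheme as GetMultiIndex: divide by the stride, keep the remainder.
    for(index_t idim = 0; idim < 2; ++idim)
    {
        multi_id[idim] = id / strides[idim];
        id -= multi_id[idim] * strides[idim];
    }
    multi_id[2] = id / strides[2];

    // 37 = 2*15 + 1*5 + 2*1  ->  prints {2, 1, 2}
    std::printf("{%d, %d, %d}\n", multi_id[0], multi_id[1], multi_id[2]);
    return 0;
}
```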
@@ -17,6 +17,8 @@ struct Sequence
         return mData[I];
     }

+    __host__ __device__ index_t operator[](index_t i) const { return mData[i]; }
+
     // this is ugly, only for nDIm = 4
     template <index_t I0, index_t I1, index_t I2, index_t I3>
     __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
@@ -90,3 +92,21 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
 {
     return sequence_pop_back(Type{});
 }
+
+template <class Seq>
+struct accumulate_on_sequence_f
+{
+    template <class IDim>
+    __host__ __device__ constexpr index_t operator()(IDim) const
+    {
+        return Seq{}.Get(IDim{});
+    }
+};
+
+template <class Seq, class Reduce, index_t I>
+__host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number<I>)
+{
+    constexpr index_t a =
+        static_const_reduce_n<Seq::nDim>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});
+    return Reduce{}(a, I);
+}
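`accumulate_on_sequence` folds the `Reduce` functor across the elements of a compile-time `Sequence`, then folds in the `Number<I>` seed; `GetElementSize` above uses it with `mod_conv::multiplies` and `Number<1>{}` to get the product of all lengths. A rough standalone equivalent of that fold, with the `Sequence` replaced by a plain parameter pack:

```cpp
using index_t = int;

struct multiplies_t
{
    constexpr index_t operator()(index_t a, index_t b) const { return a * b; }
};

// Base case: nothing left to fold, return the accumulated value.
template <class Reduce>
constexpr index_t accumulate(Reduce, index_t init)
{
    return init;
}

// Recursive case: fold the head into the accumulator, recurse on the rest.
template <class Reduce, class... Rest>
constexpr index_t accumulate(Reduce r, index_t init, index_t head, Rest... rest)
{
    return accumulate(r, r(init, head), rest...);
}

// Same value as accumulate_on_sequence(Sequence<2, 3, 4>{}, multiplies, Number<1>{})
static_assert(accumulate(multiplies_t{}, 1, 2, 3, 4) == 24, "2 * 3 * 4");
```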
@@ -211,8 +211,7 @@ struct Blockwise2dTensorCopy1
         constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);

-        constexpr auto ref_desc =
-            make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
+        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});

         constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
@@ -225,10 +224,8 @@ struct Blockwise2dTensorCopy1
             did[1] = is / ref_desc.GetStride(I1);

-            const index_t src_index =
-                src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
-            const index_t dst_index =
-                dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);
+            const index_t src_index = src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
+            const index_t dst_index = dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);

             *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                 *(reinterpret_cast<const vector_t*>(p_src + src_index));
@@ -54,8 +54,7 @@ struct Blockwise3dTensorCopy1
         constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);

-        constexpr auto ref_desc =
-            make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
+        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});

         constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
@@ -72,10 +71,8 @@ struct Blockwise3dTensorCopy1
             did[2] = is / ref_desc.GetStride(I2);

-            const index_t src_index =
-                src_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
-            const index_t dst_index =
-                dst_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
+            const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
+            const index_t dst_index = dst_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);

             *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                 *(reinterpret_cast<const vector_t*>(p_src + src_index));
@@ -340,8 +340,7 @@ struct BlockwiseChwnTensorCopyPadded
         constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

         const Float* p_src_tmp =
-            p_src +
-            src_desc.Get1dIndex(c_block_data_begin,
-                                (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
-                                (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
-                                n_block_data_begin);
+            p_src + src_desc.Get1dIndex(c_block_data_begin,
+                                        (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
+                                        (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
+                                        n_block_data_begin);
@@ -494,7 +493,7 @@ struct Blockwise4dTensorCopy3
                       "wrrong! BlockSize is not big enough for ThreadPerDims!");

         constexpr index_t num_active_thread =
-            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
+            accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});

         if(BlockSize > num_active_thread)
         {
@@ -504,19 +503,18 @@ struct Blockwise4dTensorCopy3
             }
         }

-        const index_t thread_id_d0 =
-            get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        index_t itmp = get_thread_local_1d_id() -
-                       thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
-        itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
-        const index_t thread_id_d2 = itmp / thread_per_d3;
-        const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
+        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
+        const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());

-        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
-            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
-        mDstMyThreadOffset = DstDesc{}.Get1dIndex(
-            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
+        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_multi_id[0],
+                                                  thread_multi_id[1],
+                                                  thread_multi_id[2],
+                                                  thread_multi_id[3] * DataPerRead);
+        mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_multi_id[0],
+                                                  thread_multi_id[1],
+                                                  thread_multi_id[2],
+                                                  thread_multi_id[3] * DataPerRead);
     }
     __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
@@ -745,3 +743,113 @@ struct Blockwise4dTensorCopy3
         }
     }
 };
+
+template <index_t BlockSize,
+          class Float,
+          class SrcDesc,
+          class DstDesc,
+          class SrcOpLengths,
+          class DstFromSrcReorder>
+struct Blockwise4dTensorCopyReorder1
+{
+    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
+    {
+        auto f_copy = [](const Float& src, Float& dst) { dst = src; };
+
+        blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
+            SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
+    }
+};
+
+#if 0
+template <index_t BlockSize,
+          class Float,
+          class SrcDesc,
+          class DstDesc,
+          class SrcLengths,
+          class SrcSubLengths,
+          class SrcThreadPerDims,
+          class DstFromSrcReorder,
+          index_t DataPerRead,
+          index_t DataPerWrite>
+struct Blockwise4dTensorCopyReorder3
+{
+    index_t mSrcMyThreadOffset;
+    index_t mDstMyThreadOffset;
+
+    __device__ Blockwise4dTensorCopyReorder3()
+    {
+        constexpr index_t nDim = SrcDesc{}.GetDimension();
+
+        static_assert(DstDesc{}.GetDimension() == nDim && SrcOpLengths::nDim == nDim &&
+                          SrcOpThreadPerDims::nDim == nDim && DstFromSrcReorder::nDim == nDim,
+                      "wrong! nDim is not consistent\n");
+
+        // Src
+        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                      "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+        static_assert(DataPerRead == 1 || SrcDesc{}.GetStride(Number<nDim-1>{}) == 1,
+                      "wrong! only support src.stride(nDim-1) == 1 if DataPerRead > 1!\n");
+
+        static_assert(
+            SrcDesc{}.GetStride(Number<nDim-2>{}) % DataPerRead == 0,
+            "wrong! src.stride(nDim-2) should be multiple of DataPerRead to keep alignment");
+
+        static_assert(SrcSubLengths{}.Get(Number<nDim-1>{}) % DataPerRead == 0,
+                      "wrong! SrcSubLengths[nDim-1] % DataPerRead != 0\n");
+
+        static_loop<nDim-1>([](auto I){
+            constexpr index_t src_len = SrcLengths{}.Get(I);
+            constexpr index_t src_sub_len = SrcSubLengths{}.Get(I);
+            constexpr index_t thread_per_dim = SrcThreadPerDims{}.Get(I);
+            static_assert(src_len % (src_sub_len * thread_per_dim) == 0,
+                          "wrong! cannot evenly divide tensor lengths");
+        });
+
+        constexpr index_t num_active_thread =
+            accumulate_on_sequence(SrcOpThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
+
+        static_assert(BlockSize >= num_active_thread,
+                      "wrong! BlockSize is not big enough for ThreadPerDims!");
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        const auto thread_multi_id = SrcOpThreadPerDims::GetMultiIndex(get_thread_local_1d_id());
+
+        const index_t thread_id_d0 =
+            get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
+        index_t itmp = get_thread_local_1d_id() -
+                       thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
+        const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
+        itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
+        const index_t thread_id_d2 = itmp / thread_per_d3;
+        const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
+
+        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
+            thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
+    }
+
+    __device__ static constexpr index_t GetRegisterClipboardSize()
+    {
+        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
+    }
+
+    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
+                                             Float* __restrict__ p_clipboard) const
+    {
+    }
+
+    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
+                                              Float* __restrict__ p_dst) const
+    {
+    }
+};
+#endif
@@ -393,8 +393,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
     {
         threadwise_matrix_copy(
             c_thread_sub_mtx,
-            p_c_thread +
-                c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
-                                            n_repeat * NPerLevel1Cluster),
+            p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
+                                                     n_repeat * NPerLevel1Cluster),
             c_block_mtx,
             p_c_block +
@@ -93,8 +93,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     Float p_out_thread[out_thread_desc.GetElementSpace()];

     threadwise_4d_tensor_copy(out_block_desc,
-                              p_out_block +
-                                  out_block_desc.Get1dIndex(n_thread_data_begin,
+                              p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                             k_thread_data_begin,
                                                             ho_thread_data_begin,
                                                             wo_thread_data_begin),
@@ -108,8 +107,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
         // threadwise convolution
         threadwise_direct_convolution_2(
             in_thread_block_desc,
-            p_in_block +
-                in_block_desc.Get1dIndex(n_thread_data_begin,
+            p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
                                          c_thread_data_begin,
                                          hi_thread_data_begin,
                                          wi_thread_data_begin),
@@ -124,8 +122,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     threadwise_4d_tensor_copy(out_thread_desc,
                               p_out_thread,
                               out_block_desc,
-                              p_out_block +
-                                  out_block_desc.Get1dIndex(n_thread_data_begin,
+                              p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
                                                             k_thread_data_begin,
                                                             ho_thread_data_begin,
                                                             wo_thread_data_begin),
 #pragma once
 #include "constant_integral.hip.hpp"

-template <index_t NLoop>
-struct static_loop_n
+template <index_t Iter, index_t Remaining, index_t Increment>
+struct static_for_impl
 {
     template <class F>
     __host__ __device__ void operator()(F f) const
     {
-        static_assert(NLoop > 1, "out-of-range");
-
-        f(Number<NLoop - 1>{});
-        static_loop_n<NLoop - 1>{}(f);
+        static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
+        static_assert(Increment <= Remaining, "will go out-of-range");
+
+        f(Number<Iter>{});
+        static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
     }
 };

-template <>
-struct static_loop_n<1>
+template <index_t Iter, index_t Increment>
+struct static_for_impl<Iter, 0, Increment>
+{
+    template <class F>
+    __host__ __device__ void operator()(F) const
+    {
+        // do nothing
+        return;
+    }
+};
+
+template <index_t NBegin, index_t NEnd, index_t Increment>
+struct static_for
 {
     template <class F>
     __host__ __device__ void operator()(F f) const
     {
-        f(Number<0>{});
+        static_assert(NBegin < NEnd, "Wrong! we should have NBegin < NEnd");
+        static_assert((NEnd - NBegin) % Increment == 0,
+                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
+
+        static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
     }
 };
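`static_for<NBegin, NEnd, Increment>` replaces `static_loop_n`: it takes an explicit range and stride, and hands the functor a `Number<I>` at each step so the body can use the index as a compile-time constant. A minimal runnable sketch of the same pattern (the `Number<>` here is a stand-in for the library's constant_integral type, and the HIP qualifiers are dropped):

```cpp
#include <cstdio>

using index_t = int;

// Stand-in for the library's Number<> integral-constant wrapper.
template <index_t N>
struct Number
{
    constexpr index_t Get() const { return N; }
};

template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
{
    template <class F>
    void operator()(F f) const
    {
        f(Number<Iter>{}); // the body sees the index as a compile-time constant
        static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
    }
};

template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
    template <class F>
    void operator()(F) const {} // recursion ends at Remaining == 0
};

template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    template <class F>
    void operator()(F f) const
    {
        static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
    }
};

int main()
{
    // Prints "0 2 4 6 ": one functor instantiation per step.
    static_for<0, 8, 2>{}([](auto I) { std::printf("%d ", I.Get()); });
    return 0;
}
```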
@@ -55,3 +70,18 @@ __host__ __device__ constexpr auto unpacker(F f)
     return [=](auto xs_array){ f(xs...); };
 }
 #endif
+
+namespace mod_conv {
+
+template <class T>
+struct multiplies
+{
+    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
+};
+
+template <class T>
+struct plus
+{
+    __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
+};
+
+} // namespace mod_conv
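The `mod_conv` functors exist for the reason the in-code comments keep repeating: C++14 has no constexpr lambdas, so a reduction that must run in a constant expression needs a named functor with a `constexpr operator()`. A small standalone illustration of the distinction (not library code):

```cpp
// A named functor type has a constexpr operator(), so it can drive a
// compile-time reduction in C++14:
struct multiplies_t
{
    constexpr int operator()(int a, int b) const { return a * b; }
};

constexpr int six = multiplies_t{}(2, 3);
static_assert(six == 6, "evaluated at compile time");

// A lambda cannot: `auto f = [](int a, int b) { return a * b; };` has a
// non-constexpr call operator in C++14, so `constexpr int x = f(2, 3);`
// does not compile until C++17, which makes lambda call operators
// implicitly constexpr where possible.
```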
@@ -99,8 +99,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
         // tensor view of blockwise input and weight in LDS
         // be careful of alignment
-        constexpr index_t max_align =
-            mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
+        constexpr index_t max_align = mod_conv::max(
+            InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);

         constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
@@ -135,7 +135,6 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
                                     InBlockCopyDataPerRead>{};
 #endif

-        // blockwise wei copy
         // format is [CPerBlock*Y*X,KPerBlock]
         const auto blockwise_wei_copy =
@@ -202,8 +201,7 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
         threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);

         const Float* p_in_global_block_offset =
-            p_in_global +
-            in_chwn_global_desc.Get1dIndex(
+            p_in_global + in_chwn_global_desc.Get1dIndex(
                 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

         const Float* p_wei_global_block_offset =
@@ -323,12 +321,11 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
         }
 #endif

-        threadwise_10d_tensor_copy(
-            out_10d_thread_desc,
+        threadwise_10d_tensor_copy(out_10d_thread_desc,
                                    p_out_thread,
                                    out_10d_global_desc,
-                                   p_out_global +
-                                       out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                   p_out_global + out_khwn_global_desc.Get1dIndex(
+                                                      k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin,
                                                       n_block_data_begin + n_thread_data_begin),
@@ -190,8 +190,7 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
     __shared__ Float p_wei_block_double[2 * wei_block_space];

     const Float* p_in_global_block_offset =
-        p_in_global +
-        in_chwn_global_desc.Get1dIndex(
+        p_in_global + in_chwn_global_desc.Get1dIndex(
             0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

     const Float* p_wei_global_block_offset =
@@ -393,12 +392,11 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
         }
 #endif

-        threadwise_10d_tensor_copy(
-            out_10d_thread_desc,
+        threadwise_10d_tensor_copy(out_10d_thread_desc,
                                    p_out_thread,
                                    out_10d_global_desc,
-                                   p_out_global +
-                                       out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                   p_out_global + out_khwn_global_desc.Get1dIndex(
+                                                      k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin,
                                                       n_block_data_begin + n_thread_data_begin),
@@ -101,8 +101,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
         // LDS tensor view
         // be careful of alignment
-        constexpr index_t max_align =
-            mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
+        constexpr index_t max_align = mod_conv::max(
+            InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);

         constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
@@ -116,8 +116,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
         // blockwise copy
         // input: format is [C, Hi, Wi, N]
-        const auto blockwise_in_copy =
 #if 0
+        const auto blockwise_in_copy =
             Blockwise4dTensorCopy1<BlockSize,
                                    Float,
                                    decltype(in_c_h_w_n_global_desc),
@@ -125,6 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                                    decltype(in_c_h_w_n_block_desc.GetLengths()),
                                    InBlockCopyDataPerRead>{};
 #else
+        const auto blockwise_in_copy =
             Blockwise4dTensorCopy3<BlockSize,
                                    Float,
                                    decltype(in_c_h_w_n_global_desc),
@@ -150,10 +151,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
         // A_matrix[C,K] is a sub-matrix of wei_block[C,K]
         // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
         // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
-        constexpr auto a_c_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
-                                          Number<KPerBlock>{},
-                                          Number<wei_c_x_k_block_desc.GetStride(I0)>{});
+        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_x_k_block_desc.GetStride(I0)>{});

         constexpr auto b_c_wn_block_mtx_desc =
             make_ConstantMatrixDescriptor(Number<CPerBlock>{},
@@ -187,8 +186,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                                GemmDataPerReadB>{};

         // LDS: be careful of alignment
-        constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
-        constexpr index_t wei_block_space = wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});
+        constexpr index_t in_block_space =
+            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
+        constexpr index_t wei_block_space =
+            wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});

         __shared__ Float p_in_block[in_block_space];
         __shared__ Float p_wei_block[wei_block_space];
@@ -213,8 +214,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
         threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

         const Float* p_in_global_block_offset =
-            p_in_global +
-            in_c_h_w_n_global_desc.Get1dIndex(
+            p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
                 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

         const Float* p_wei_global_block_offset =
@@ -239,9 +239,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
             for(index_t x = 0; x < X; ++x)
             {
                 blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
-                                         p_in_block + in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
+                                         p_in_block +
+                                             in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
                                          p_out_thread);
             }

             __syncthreads();
@@ -321,12 +321,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
         }
 #endif

-        threadwise_10d_tensor_copy(
-            out_10d_thread_desc,
+        threadwise_10d_tensor_copy(out_10d_thread_desc,
                                    p_out_thread,
                                    out_10d_global_desc,
-                                   p_out_global +
-                                       out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                   p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
+                                                      k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin,
                                                       n_block_data_begin + n_thread_data_begin),
@@ -365,12 +365,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
         constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});

-        threadwise_6d_tensor_copy(
-            out_6d_thread_desc,
+        threadwise_6d_tensor_copy(out_6d_thread_desc,
                                   p_out_thread,
                                   out_6d_global_desc,
-                                  p_out_global +
-                                      out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin),
+                                  p_out_global + out_kb_global_desc.Get1dIndex(
+                                                     k_thread_data_begin, b_thread_data_begin),
                                   out_6d_thread_desc.GetLengths(),
                                   Number<OutThreadCopyDataPerWrite>{});
     }
@@ -113,8 +113,7 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
          c_block_work_begin += CPerBlock)
     {
         // copy input tensor to LDS
-        blockwise_in_copy.Run(p_in_global +
-                                  in_global_desc.Get1dIndex(n_block_work_begin,
+        blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
                                                             c_block_work_begin,
                                                             hi_block_work_begin,
                                                             wi_block_work_begin),
@@ -144,9 +143,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
     }

     // copy output tensor from LDS to device mem
-    blockwise_out_copy.Run(
-        p_out_block,
-        p_out_global +
-            out_global_desc.Get1dIndex(
-                n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
+    blockwise_out_copy.Run(p_out_block,
+                           p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
+                                                                     k_block_work_begin,
+                                                                     ho_block_work_begin,
+                                                                     wo_block_work_begin));
 }
@@ -175,17 +175,15 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
          c_block_data_begin += CPerBlock, __syncthreads())
     {
         // copy input tensor to LDS
-        blockwise_in_copy.Run(p_in_global +
-                                  in_nchw_global_desc.Get1dIndex(n_block_data_begin,
+        blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
                                                                  c_block_data_begin,
                                                                  hi_block_data_begin,
                                                                  wi_block_data_begin),
                               p_in_block);

         // copy weight tensor to LDS
-        blockwise_wei_copy.Run(
-            p_wei_global +
-            wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
+        blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex(
+                                                  k_block_data_begin, c_block_data_begin, 0, 0),
                                p_wei_block);

         __syncthreads();
@@ -196,8 +194,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
 #if 1
             threadwise_direct_convolution_2(
                 in_nchw_thread_block_desc,
-                p_in_block +
-                    in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
+                p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
                                                   c_thread_data,
                                                   hi_thread_data_begin,
                                                   wi_thread_data_begin),
@@ -209,8 +206,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
 #elif 0
             threadwise_direct_convolution_3(
                 in_nchw_thread_block_desc,
-                p_in_block +
-                    in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
+                p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
                                                   c_thread_data,
                                                   hi_thread_data_begin,
                                                   wi_thread_data_begin),
@@ -228,8 +224,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
             out_nkhw_thread_desc,
             p_out_thread,
             out_nkhw_global_desc,
-            p_out_global +
-                out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+            p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin),
@@ -198,9 +198,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
             p_in_vec_block);

         // copy weight tensor to LDS
-        blockwise_wei_copy.Run(
-            p_wei_vec_global +
-            wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
+        blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex(
+                                                      k_block_data_begin, c_block_data_begin, 0, 0),
                                p_wei_vec_block);

         __syncthreads();
@@ -211,8 +210,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #if 1
             threadwise_direct_convolution_2(
                 in_nchw_vec_thread_block_desc,
-                p_in_vec_block +
-                    in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
+                p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                       c_thread_data,
                                                       hi_thread_data_begin,
                                                       wi_thread_data_begin),
@@ -224,8 +222,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #elif 0
             threadwise_direct_convolution_3(
                 in_nchw_vec_thread_block_desc,
-                p_in_vec_block +
-                    in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
+                p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
                                                       c_thread_data,
                                                       hi_thread_data_begin,
                                                       wi_thread_data_begin),
@@ -243,8 +240,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
             out_nkhw_thread_desc,
             p_out_thread,
             out_nkhw_global_desc,
-            p_out_global +
-                out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
+            p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin),