Commit 1cc683a3 authored by Chao Liu
Browse files

adding implicit gemm v3

parent 8a4b5978
...@@ -47,8 +47,8 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc, ...@@ -47,8 +47,8 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopyDataPerRead = 2; constexpr index_t InBlockCopyDataPerRead = 1;
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 1;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#endif #endif
......
...@@ -92,7 +92,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, ...@@ -92,7 +92,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2; constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 1 #elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32 // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
......
...@@ -85,6 +85,9 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc, ...@@ -85,6 +85,9 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
constexpr index_t InBlockCopySrcDataPerRead_B = 1; constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_C_K = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_C_K = Sequence<8, 32>;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4; constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
#endif #endif
...@@ -123,8 +126,11 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc, ...@@ -123,8 +126,11 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
InBlockCopyClusterLengths_N1_N2_C_B, InBlockCopyClusterLengths_N1_N2_C_B,
InBlockCopySrcDataPerRead_B, InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2, InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_C_K,
WeiBlockCopyClusterLengths_C_K,
WeiBlockCopyDataPerAccess_K>{}; WeiBlockCopyDataPerAccess_K>{};
#if 1
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>, float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
...@@ -138,6 +144,7 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc, ...@@ -138,6 +144,7 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time); (std::size_t(1000) * 1000 * 1000) / time);
usleep(std::min(time * 1000, float(10000))); usleep(std::min(time * 1000, float(10000)));
#endif
} }
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
......
...@@ -411,7 +411,18 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -411,7 +411,18 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
#if 1 #if 0
constexpr index_t N = 8;
constexpr index_t C = 8;
constexpr index_t HI = 3;
constexpr index_t WI = 18;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -635,11 +646,13 @@ int main(int argc, char* argv[]) ...@@ -635,11 +646,13 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
#if 1
if(Y == 3 && X == 3) if(Y == 3 && X == 3)
{ {
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
} }
else else
#endif
{ {
host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
} }
......
...@@ -57,7 +57,7 @@ __host__ __device__ constexpr auto make_zero_array() ...@@ -57,7 +57,7 @@ __host__ __device__ constexpr auto make_zero_array()
} }
template <class TData, index_t NSize, index_t... IRs> template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Sequence<IRs...> new2old) Sequence<IRs...> new2old)
{ {
Array<TData, NSize> new_array; Array<TData, NSize> new_array;
...@@ -73,7 +73,7 @@ __host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>& ...@@ -73,7 +73,7 @@ __host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>&
} }
template <class TData, index_t NSize, index_t... IRs> template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> old2new) Sequence<IRs...> old2new)
{ {
Array<TData, NSize> new_array; Array<TData, NSize> new_array;
...@@ -89,7 +89,7 @@ __host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>& ...@@ -89,7 +89,7 @@ __host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>&
} }
template <class TData, index_t NSize, class ExtractSeq> template <class TData, index_t NSize, class ExtractSeq>
__host__ __device__ auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq) __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{ {
Array<TData, ExtractSeq::GetSize()> new_array; Array<TData, ExtractSeq::GetSize()> new_array;
...@@ -112,6 +112,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, ...@@ -112,6 +112,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
static_for<0, NSize, 1>{}([&](auto I) { static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get(); constexpr index_t i = I.Get();
result[i] = a[i] + b[i]; result[i] = a[i] + b[i];
}); });
...@@ -129,7 +130,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is. ...@@ -129,7 +130,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
static_for<0, NSize, 1>{}([&](auto I) { static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get(); constexpr index_t i = I.Get();
result[i] = a[i] + b.Get(I); result[i] = a[i] * b.Get(I);
}); });
return result; return result;
......
...@@ -26,6 +26,11 @@ struct ConstantMergedTensorDescriptor ...@@ -26,6 +26,11 @@ struct ConstantMergedTensorDescriptor
// TODO: check there is no duplication in OriginalDimMergeSeqs // TODO: check there is no duplication in OriginalDimMergeSeqs
} }
// Returns a value-initialized object of the OriginalTensorDesc type, i.e. the
// descriptor this merged descriptor refers to as its "original" tensor descriptor.
// Usable in constant expressions on both host and device.
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
{
return OriginalTensorDesc{};
}
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; } __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension() __host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
...@@ -120,3 +125,9 @@ __host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalT ...@@ -120,3 +125,9 @@ __host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalT
{ {
return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{}; return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
} }
// Debug helper: prints a merged tensor descriptor by forwarding its original
// descriptor (TDesc::GetOriginalTensorDescriptor()) to print_ConstantTensorDescriptor.
// Only the original descriptor is passed on; nothing about the merge itself is printed here.
//   TDesc - descriptor type; must provide a static GetOriginalTensorDescriptor()
//   s     - label string handed unchanged to print_ConstantTensorDescriptor
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(TDesc, const char* s)
{
print_ConstantTensorDescriptor(TDesc::GetOriginalTensorDescriptor(), s);
}
...@@ -396,31 +396,35 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_al ...@@ -396,31 +396,35 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_al
template <class TDesc> template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{ {
constexpr auto desc = TDesc{}; constexpr index_t ndim = TDesc::GetNumOfDimension();
constexpr index_t ndim = desc.GetNumOfDimension();
static_assert(ndim >= 2 && ndim <= 10, "wrong!"); static_assert(ndim >= 2 && ndim <= 10, "wrong!");
if(ndim == 2) static_if<ndim == 2>{}([&](auto fwd) {
{
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u}, strides {%u %u}, ranks {%u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
desc.GetLength(I1), desc.GetLength(I1),
desc.GetStride(I0), desc.GetStride(I0),
desc.GetStride(I1)); desc.GetStride(I1),
} desc.GetMemoryRank(I0),
else if(ndim == 3) desc.GetMemoryRank(I1));
{ });
static_if<ndim == 3>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}, ranks {%u %u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -428,16 +432,21 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -428,16 +432,21 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetLength(I2), desc.GetLength(I2),
desc.GetStride(I0), desc.GetStride(I0),
desc.GetStride(I1), desc.GetStride(I1),
desc.GetStride(I2)); desc.GetStride(I2),
} desc.GetMemoryRank(I0),
else if(ndim == 4) desc.GetMemoryRank(I1),
{ desc.GetMemoryRank(I2));
});
static_if<ndim == 4>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n", constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}, ranks {%u %u %u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -447,17 +456,24 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -447,17 +456,24 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I0), desc.GetStride(I0),
desc.GetStride(I1), desc.GetStride(I1),
desc.GetStride(I2), desc.GetStride(I2),
desc.GetStride(I3)); desc.GetStride(I3),
} desc.GetMemoryRank(I0),
else if(ndim == 5) desc.GetMemoryRank(I1),
{ desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3));
});
static_if<ndim == 5>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{}; constexpr auto I4 = Number<4>{};
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n", constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}, ranks {%u %u %u %u "
"%u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -469,10 +485,15 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -469,10 +485,15 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I1), desc.GetStride(I1),
desc.GetStride(I2), desc.GetStride(I2),
desc.GetStride(I3), desc.GetStride(I3),
desc.GetStride(I4)); desc.GetStride(I4),
} desc.GetMemoryRank(I0),
else if(ndim == 6) desc.GetMemoryRank(I1),
{ desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4));
});
static_if<ndim == 6>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -480,7 +501,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -480,7 +501,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I4 = Number<4>{}; constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{}; constexpr auto I5 = Number<5>{};
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n", constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}, ranks {%u %u "
"%u %u %u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -494,10 +518,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -494,10 +518,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I2), desc.GetStride(I2),
desc.GetStride(I3), desc.GetStride(I3),
desc.GetStride(I4), desc.GetStride(I4),
desc.GetStride(I5)); desc.GetStride(I5),
} desc.GetMemoryRank(I0),
else if(ndim == 7) desc.GetMemoryRank(I1),
{ desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5));
});
static_if<ndim == 7>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -506,7 +536,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -506,7 +536,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I5 = Number<5>{}; constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{}; constexpr auto I6 = Number<6>{};
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n", constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}, ranks "
"{%u %u %u %u %u %u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -522,10 +555,17 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -522,10 +555,17 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I3), desc.GetStride(I3),
desc.GetStride(I4), desc.GetStride(I4),
desc.GetStride(I5), desc.GetStride(I5),
desc.GetStride(I6)); desc.GetStride(I6),
} desc.GetMemoryRank(I0),
else if(ndim == 8) desc.GetMemoryRank(I1),
{ desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6));
});
static_if<ndim == 8>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -535,7 +575,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -535,7 +575,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I6 = Number<6>{}; constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{}; constexpr auto I7 = Number<7>{};
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n", constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}, "
"ranks {%u %u %u %u %u %u %u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -553,10 +596,18 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -553,10 +596,18 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I4), desc.GetStride(I4),
desc.GetStride(I5), desc.GetStride(I5),
desc.GetStride(I6), desc.GetStride(I6),
desc.GetStride(I7)); desc.GetStride(I7),
} desc.GetMemoryRank(I0),
else if(ndim == 9) desc.GetMemoryRank(I1),
{ desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7));
});
static_if<ndim == 9>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -567,8 +618,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -567,8 +618,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I7 = Number<7>{}; constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{}; constexpr auto I8 = Number<8>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u " printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}\n", "%u}, ranks {%u %u %u %u %u %u %u %u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -588,10 +641,19 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -588,10 +641,19 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I5), desc.GetStride(I5),
desc.GetStride(I6), desc.GetStride(I6),
desc.GetStride(I7), desc.GetStride(I7),
desc.GetStride(I8)); desc.GetStride(I8),
} desc.GetMemoryRank(I0),
else if(ndim == 10) desc.GetMemoryRank(I1),
{ desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7),
desc.GetMemoryRank(I8));
});
static_if<ndim == 10>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -603,8 +665,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -603,8 +665,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I8 = Number<8>{}; constexpr auto I8 = Number<8>{};
constexpr auto I9 = Number<9>{}; constexpr auto I9 = Number<9>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u " printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}\n", "%u %u %u}, ranks {%u %u %u %u %u %u %u %u %u %u}\n",
s, s,
desc.GetNumOfDimension(), desc.GetNumOfDimension(),
desc.GetLength(I0), desc.GetLength(I0),
...@@ -626,6 +690,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -626,6 +690,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I6), desc.GetStride(I6),
desc.GetStride(I7), desc.GetStride(I7),
desc.GetStride(I8), desc.GetStride(I8),
desc.GetStride(I9)); desc.GetStride(I9),
} desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7),
desc.GetMemoryRank(I8),
desc.GetMemoryRank(I9));
});
} }
...@@ -263,7 +263,19 @@ struct sequence_map_inverse<Sequence<Is...>> ...@@ -263,7 +263,19 @@ struct sequence_map_inverse<Sequence<Is...>>
using SeqMapType = using SeqMapType =
typename sequence_map_inverse_impl<Sequence<Is...>, is_valid_map>::SeqMapType; typename sequence_map_inverse_impl<Sequence<Is...>, is_valid_map>::SeqMapType;
}; };
#endif
// Compile-time predicate intended to tell whether Seq is a valid sequence map.
// The actual check (comparing sequence_sort<Seq>::SortedSeqType against
// arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType) is disabled with "#if 0"
// because sequence_sort is not implemented yet, so `value` is currently always true
// and static_asserts that rely on it accept any Seq.
template <class Seq>
struct is_valid_sequence_map
{
static constexpr bool value =
#if 0 // sequence_sort is not implemented yet
is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
typename sequence_sort<Seq>::SortedSeqType>::value;
#else
true;
#endif #endif
};
template <index_t... Xs, index_t... Ys> template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>) __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
......
...@@ -26,7 +26,18 @@ struct BlockwiseTensorSliceCopy_generic_v1 ...@@ -26,7 +26,18 @@ struct BlockwiseTensorSliceCopy_generic_v1
Array<index_t, nDim> dst_block_data_multi_id_begin) Array<index_t, nDim> dst_block_data_multi_id_begin)
{ {
// check NDim consistent // check NDim consistent
static_assert(SrcDesc::GetNumOfDimension() == DstDesc::GetNumOfDimension(), "wrong"); static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SubLengths::GetSize() && nDim == DataClusterLengths::GetSize() &&
nDim == ThreadClusterArrangeOrder::GetSize() &&
nDim == SrcAccessOrder::GetSize() && nDim == DstAccessOrder::GetSize(),
"wrong");
// check
static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
is_valid_sequence_map<SrcAccessOrder>::value &&
is_valid_sequence_map<DstAccessOrder>::value,
"wrong!");
// thread cluster // thread cluster
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_default_rank_packed( constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_default_rank_packed(
...@@ -73,8 +84,38 @@ struct BlockwiseTensorSliceCopy_generic_v1 ...@@ -73,8 +84,38 @@ struct BlockwiseTensorSliceCopy_generic_v1
mSrcMyThreadOffset = SrcDesc::GetOffsetFromMultiIndex(src_block_data_multi_id_begin + mSrcMyThreadOffset = SrcDesc::GetOffsetFromMultiIndex(src_block_data_multi_id_begin +
thread_data_multi_id_begin); thread_data_multi_id_begin);
mSrcMyThreadOffset = DstDesc::GetOffsetFromMultiIndex(dst_block_data_multi_id_begin + mDstMyThreadOffset = DstDesc::GetOffsetFromMultiIndex(dst_block_data_multi_id_begin +
thread_data_multi_id_begin); thread_data_multi_id_begin);
#if 0
{
printf("id %5u %5u: "
"src_block_data_multi_id_begin: %u %u %u %u, "
"thread_cluster_multi_id: %u %u %u %u, "
"data_cluster_multi_id: %u %u %u %u, "
"thread_data_multi_id_begin: %u %u %u %u, "
"mSrcMyThreadOffset %u, mDstMyThreadOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
src_block_data_multi_id_begin[0],
src_block_data_multi_id_begin[1],
src_block_data_multi_id_begin[2],
src_block_data_multi_id_begin[3],
thread_cluster_multi_id[0],
thread_cluster_multi_id[1],
thread_cluster_multi_id[2],
thread_cluster_multi_id[3],
data_cluster_multi_id[0],
data_cluster_multi_id[1],
data_cluster_multi_id[2],
data_cluster_multi_id[3],
thread_data_multi_id_begin[0],
thread_data_multi_id_begin[1],
thread_data_multi_id_begin[2],
thread_data_multi_id_begin[3],
mSrcMyThreadOffset,
mDstMyThreadOffset);
}
#endif
} }
__device__ static constexpr index_t GetRegisterClipboardSize() __device__ static constexpr index_t GetRegisterClipboardSize()
......
...@@ -130,6 +130,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 ...@@ -130,6 +130,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
mSrcMyThreadOffset = mSrcMyThreadOffset =
src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin); src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);
mDstMyThreadOffset = mDstMyThreadOffset =
dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin); dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
} }
......
...@@ -45,22 +45,22 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw ...@@ -45,22 +45,22 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2); constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_global_desc.GetLength(I3); constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);
constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor( constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor_default_rank_packed(
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
constexpr index_t HiPerBlock = HoPerBlock + Y - 1; constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1; constexpr index_t WiPerBlock = WoPerBlock + X - 1;
constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{},
Number<InBlockCopyDataPerRead>{}); Number<InBlockCopyDataPerRead>{});
constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_default_rank_aligned(
Sequence<KPerBlock, CPerBlock * Y * X>{}, Sequence<KPerBlock, CPerBlock * Y * X>{},
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
constexpr auto wei_kcyx_block_desc = constexpr auto wei_kcyx_block_desc = make_ConstantTensorDescriptor_default_rank(
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{}, Sequence<KPerBlock, CPerBlock, Y, X>{},
Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{}); Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});
// shared mem // shared mem
...@@ -82,11 +82,11 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw ...@@ -82,11 +82,11 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
constexpr index_t HiPerThread = HoPerThread + Y - 1; constexpr index_t HiPerThread = HoPerThread + Y - 1;
constexpr index_t WiPerThread = WoPerThread + X - 1; constexpr index_t WiPerThread = WoPerThread + X - 1;
constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor( constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor_default_rank(
Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{}, Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
in_nchw_block_desc.GetStrides()); in_nchw_block_desc.GetStrides());
constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor( constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor_default_rank(
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides()); Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());
constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor( constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
...@@ -170,7 +170,7 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw ...@@ -170,7 +170,7 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
decltype(wei_ke_global_desc), decltype(wei_ke_global_desc),
decltype(wei_ke_block_desc), decltype(wei_ke_block_desc),
decltype(wei_ke_block_desc.GetLengths()), decltype(wei_ke_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{}; WeiBlockCopyDataPerRead>({0, 0}, {0, 0});
#endif #endif
// set threadwise output tensor to 0 // set threadwise output tensor to 0
......
...@@ -459,7 +459,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -459,7 +459,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{}; constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
threadwise_tensor_slice_copy_reorder_given_dst2src_v2( #if 0
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
out_10d_thread_desc, out_10d_thread_desc,
p_out_thread, p_out_thread,
out_10d_global_desc, out_10d_global_desc,
...@@ -470,8 +471,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -470,8 +471,24 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
map_out_global2thread); map_out_global2thread,
// Number<OutThreadCopyDataPerWrite_W>{}); Number<OutThreadCopyDataPerWrite_W>{});
#else
threadwise_tensor_slice_copy_generic(
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
p_out_thread,
make_zero_array<index_t, 10>(),
out_10d_global_desc,
p_out_global +
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
make_zero_array<index_t, 10>(),
out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
arithmetic_sequence_gen<0, 10, 1>::SeqType{});
#endif
}); });
} }
}; };
...@@ -32,6 +32,8 @@ template <index_t GridSize, ...@@ -32,6 +32,8 @@ template <index_t GridSize,
class InBlockCopyClusterLengths_N1_N2_C_B, class InBlockCopyClusterLengths_N1_N2_C_B,
index_t InBlockCopySrcDataPerRead_B, index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2, index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_C_K,
class WeiBlockCopyClusterLengths_C_K,
index_t WeiBlockCopyDataPerAccess_K> index_t WeiBlockCopyDataPerAccess_K>
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
{ {
...@@ -40,7 +42,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -40,7 +42,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_out_global) const
{ {
// this is a mess // this is a mess
// TODO: more elegant way of specifying (or calculating) performance variables // TODO: find more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!"); static_assert(N2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * N2 * BPerBlock) % static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
...@@ -132,7 +134,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -132,7 +134,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
Float, Float,
decltype(in_n1_n2_c_b_global_merged_desc), decltype(in_n1_n2_c_b_global_merged_desc),
decltype(in_n1_n2_c_b_block_desc), decltype(in_n1_n2_c_b_block_desc),
Sequence<N1, N2, CPerBlock, BPerBlock>, decltype(in_n1_n2_c_b_block_desc.GetLengths()),
InBlockCopySubLengths_N1_N2_C_B, InBlockCopySubLengths_N1_N2_C_B,
InBlockCopyClusterLengths_N1_N2_C_B, InBlockCopyClusterLengths_N1_N2_C_B,
Sequence<2, 0, 1, 3>, // thread_arrange_order [C, N1, N2, B] Sequence<2, 0, 1, 3>, // thread_arrange_order [C, N1, N2, B]
...@@ -153,15 +155,21 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -153,15 +155,21 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// operator for blockwise copy of weight into LDS // operator for blockwise copy of weight into LDS
// slicing a tensor // slicing a tensor
// this copy operator already has tensor offset built-in // this copy operator already has blockwise offset built-in
const auto blockwise_wei_copy = const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize, BlockwiseTensorSliceCopy_generic_v1<BlockSize,
Float, Float,
decltype(wei_c_k_global_desc), decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc), decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()), decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, WeiBlockCopySubLengths_C_K,
{0, 0}); WeiBlockCopyClusterLengths_C_K,
Sequence<0, 1>, // thread_arrange_order [C, K]
Sequence<0, 1>, // src_access_order [C, K]
Sequence<0, 1>, // dst_access_order [C, K]
WeiBlockCopyDataPerAccess_K,
WeiBlockCopyDataPerAccess_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -244,12 +252,16 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -244,12 +252,16 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{ {
#if 1 // debug
blockwise_in_copy.Run(p_in_block_on_global, p_in_block); blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block); blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);
#endif
__syncthreads(); __syncthreads();
#if 1 // debug
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread); blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
#endif
__syncthreads(); __syncthreads();
} }
...@@ -296,7 +308,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -296,7 +308,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
k_block_data_on_global + c_thread_mtx_on_block.row; k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col / N2;
// output merged global tensor descriptor, for calculating origin of thread tensor // output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory // in global memory
...@@ -320,7 +332,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -320,7 +332,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, 0, 0); // dst origin on merged global tensor k_thread_data_on_global, 0, 0, 0); // dst origin on merged global tensor
// copy
threadwise_tensor_slice_copy_generic( threadwise_tensor_slice_copy_generic(
out_k0_k1_k2_n1_b_n2_thread_mem_desc, // src thread tensor (in register) descriptor out_k0_k1_k2_n1_b_n2_thread_mem_desc, // src thread tensor (in register) descriptor
p_out_thread, // origin of src p_out_thread, // origin of src
...@@ -335,8 +346,33 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -335,8 +346,33 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
b_thread_data_on_global, b_thread_data_on_global,
0}, // starting point of slice w.r.t. origin of dst 0}, // starting point of slice w.r.t. origin of dst
out_k0_k1_k2_n1_b_n2_thread_mem_desc.GetLengths(), // slice lengths out_k0_k1_k2_n1_b_n2_thread_mem_desc.GetLengths(), // slice lengths
Sequence<2, 3, 4, 0, 5, 1>{} // order of dimension access Sequence<3, 5, 0, 1, 2, 4>{} // dimension access order [n1, n2, k0, k1, k2, b]
); );
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_n0_n1_n2_c_h_w_global_mem_desc,
"in_n0_n1_n2_c_h_w_global_mem_desc");
print_ConstantMergedTensorDescriptor(in_n1_n2_c_b_global_merged_desc,
"in_n1_n2_c_b_global_merged_desc");
print_ConstantTensorDescriptor(in_c_n1_b_n2_block_mem_desc,
"in_c_n1_b_n2_block_mem_desc");
print_ConstantTensorDescriptor(in_n1_n2_c_b_block_desc, "in_n1_n2_c_b_block_desc");
print_ConstantTensorDescriptor(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
"out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc");
print_ConstantMergedTensorDescriptor(out_k_n1_b_n2_global_merged_desc,
"out_k_n1_b_n2_global_merged_desc");
print_ConstantTensorDescriptor(out_k0_k1_k2_n1_b_n2_thread_mem_desc,
"out_k0_k1_k2_n1_b_n2_thread_mem_desc");
}
#endif
} }
} }
}; };
...@@ -80,8 +80,10 @@ __device__ void threadwise_direct_convolution_2(InDesc, ...@@ -80,8 +80,10 @@ __device__ void threadwise_direct_convolution_2(InDesc,
constexpr auto wei_desc = WeiDesc{}; constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{}; constexpr auto out_desc = OutDesc{};
constexpr auto in_reg_desc = make_ConstantTensorDescriptor(in_desc.GetLengths()); constexpr auto in_reg_desc =
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(wei_desc.GetLengths()); make_ConstantTensorDescriptor_default_rank_packed(in_desc.GetLengths());
constexpr auto wei_reg_desc =
make_ConstantTensorDescriptor_default_rank_packed(wei_desc.GetLengths());
// register // register
TInWei p_in_reg[in_reg_desc.GetElementSpace()]; TInWei p_in_reg[in_reg_desc.GetElementSpace()];
......
...@@ -67,6 +67,22 @@ __device__ void threadwise_gemm(MatrixA, ...@@ -67,6 +67,22 @@ __device__ void threadwise_gemm(MatrixA,
integral_constant<bool, TransC>, integral_constant<bool, TransC>,
FloatC* __restrict__ p_c_thread) FloatC* __restrict__ p_c_thread)
{ {
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("p_a_thread: %f %f %f %f\n",
p_a_thread[0],
p_a_thread[1],
p_a_thread[2],
p_a_thread[3]);
printf("p_b_thread: %f %f %f %f\n",
p_b_thread[0],
p_b_thread[1],
p_b_thread[2],
p_b_thread[3]);
}
#endif
if(TransA && (!TransB) && (!TransC)) if(TransA && (!TransB) && (!TransC))
{ {
constexpr auto a_mtx = MatrixA{}; constexpr auto a_mtx = MatrixA{};
......
...@@ -204,25 +204,46 @@ __device__ void threadwise_tensor_slice_copy_generic( ...@@ -204,25 +204,46 @@ __device__ void threadwise_tensor_slice_copy_generic(
SliceLengths, SliceLengths,
DimAccessOrder) DimAccessOrder)
{ {
static_assert(SrcDesc::GetNumOfDimension() == DstDesc::GetNumOfDimension(), constexpr index_t nDim = SrcDesc::GetNumOfDimension();
static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
nDim == SliceLengths::GetSize() && nDim == DimAccessOrder::GetSize(),
"wrong! # of dimensions not the same"); "wrong! # of dimensions not the same");
constexpr auto src_desc = SrcDesc{}; static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
constexpr auto dst_desc = DstDesc{};
constexpr auto slice_lengths_in_access_order = constexpr auto slice_lengths_in_access_order =
SliceLengths{}.ReorderGivenNew2Old(DimAccessOrder{}); SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});
#if 1
ford<decltype(slice_lengths_in_access_order)>{}([&](auto data_multi_id_in_access_order) { ford<decltype(slice_lengths_in_access_order)>{}([&](auto data_multi_id_in_access_order) {
const auto data_multi_id = const auto data_multi_id =
reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{}); reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
const index_t src_index =
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
const index_t dst_index = const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id); DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
p_dst[dst_index] = p_src[src_index];
});
#else
static_ford<decltype(slice_lengths_in_access_order)>{}(
[&](auto data_multi_id_in_access_order_) {
constexpr auto data_multi_id_in_access_order =
sequence2array(decltype(data_multi_id_in_access_order_){});
const auto data_multi_id =
reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
const index_t src_index = const index_t src_index =
src_desc.GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id); SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
const index_t dst_index =
DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
p_dst[dst_index] = p_src[src_index]; p_dst[dst_index] = p_src[src_index];
}); });
#endif
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment