Commit a0584426 authored by Chao Liu's avatar Chao Liu
Browse files

refactoring ConstantTensorDescriptor

parent fd8de384
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp" #include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc> template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc, void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
...@@ -57,7 +58,13 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc, ...@@ -57,7 +58,13 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
for(unsigned i = 0; i < nrepeat; ++i) for(unsigned i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T, float time = launch_kernel(
#if 0
gridwise_direct_convolution_2_nchw_kcyx_nkhw
#else
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#endif
<T,
InDesc, InDesc,
WeiDesc, WeiDesc,
OutDesc, OutDesc,
......
...@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
for(unsigned i = 0; i < nrepeat; ++i) for(unsigned i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel( float time = launch_kernel(
#if 1 #if 0
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
#else #else
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
......
#pragma once
template <class TData, unsigned NSize>
struct Array
{
using Type = Array<TData, NSize>;
static constexpr unsigned nSize = NSize;
unsigned mData[nSize];
template <class... Xs>
__host__ __device__ Array(Xs... xs) : mData({static_cast<TData>(xs)...})
{
}
__host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
};
...@@ -65,8 +65,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0 ...@@ -65,8 +65,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
template <class Lengths, class Strides> template <class Lengths, class Strides>
struct ConstantTensorDescriptor struct ConstantTensorDescriptor
{ {
using Type = ConstantTensorDescriptor<Lengths, Strides>;
static constexpr unsigned nDim = Lengths::nDim; static constexpr unsigned nDim = Lengths::nDim;
using NDimConstant = Number<nDim>;
__host__ __device__ constexpr ConstantTensorDescriptor() __host__ __device__ constexpr ConstantTensorDescriptor()
{ {
...@@ -91,293 +91,70 @@ struct ConstantTensorDescriptor ...@@ -91,293 +91,70 @@ struct ConstantTensorDescriptor
return Strides{}.Get(Number<I>{}); return Strides{}.Get(Number<I>{});
} }
__host__ __device__ constexpr unsigned GetElementSize() const // c++14 doesn't support constexpr lambdas, has to use this trick instead
struct GetElementSize_f
{ {
static_assert(nDim >= 2 && nDim <= 8, "nDim"); template <class IDim>
__host__ __device__ constexpr unsigned operator()(IDim idim) const
if(nDim == 2)
{ {
constexpr auto I0 = Number<0>{}; return Type{}.GetLength(idim);
constexpr auto I1 = Number<1>{};
return GetLength(I0) * GetLength(I1);
} }
else if(nDim == 3) };
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
return GetLength(I0) * GetLength(I1) * GetLength(I2); __host__ __device__ constexpr unsigned GetElementSize() const
}
else if(nDim == 4)
{ {
constexpr auto I0 = Number<0>{}; // c++14 doesn't support constexpr lambdas, has to use this trick instead
constexpr auto I1 = Number<1>{}; struct multiply
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
}
else if(nDim == 5)
{ {
constexpr auto I0 = Number<0>{}; __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4);
}
else if(nDim == 6)
{ {
constexpr auto I0 = Number<0>{}; return a * b;
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
GetLength(I5);
} }
else if(nDim == 7) };
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) * return static_const_reduce_n<nDim>{}(GetElementSize_f{}, multiply{});
GetLength(I5) * GetLength(I6);
} }
else if(nDim == 8)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) * // c++14 doesn't support constexpr lambdas, has to use this trick instead
GetLength(I5) * GetLength(I6) * GetLength(I7); struct GetElementSpace_f
}
else
{ {
assert(false); template <class IDim>
} __host__ __device__ constexpr unsigned operator()(IDim idim) const
{
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
} }
};
template <class Align = Number<1>> template <class Align = Number<1>>
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const __host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
{ {
static_assert(nDim >= 2 && nDim <= 8, "nDim"); // c++14 doesn't support constexpr lambdas, has to use this trick instead
struct add
constexpr unsigned align_size = align.Get();
if(nDim == 2)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
align_size;
}
else if(nDim == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
(GetLength(I2) - 1) * GetStride(I2) + align_size;
}
else if(nDim == 4)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
align_size;
}
else if(nDim == 5)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
(GetLength(I4) - 1) * GetStride(I4) + align_size;
}
else if(nDim == 6)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
align_size;
}
else if(nDim == 7)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
(GetLength(I6) - 1) * GetStride(I6) + align_size;
}
else if(nDim == 8)
{ {
constexpr auto I0 = Number<0>{}; __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
(GetLength(I6) - 1) * GetStride(I6) + (GetLength(I7) - 1) * GetStride(I7) +
align_size;
}
}
// this is ugly, only for 2d
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
static_assert(nDim == 2, "nDim is not 2");
return i0 * GetStride(I0) + i1 * GetStride(I1);
}
// this is ugly, only for 3d
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1, unsigned i2) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
static_assert(nDim == 3, "nDim is not 3");
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2);
}
// this is ugly, only for 4d
__host__ __device__ unsigned
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const
{ {
constexpr auto I0 = Number<0>{}; return a + b;
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(nDim == 4, "nDim is not 4");
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
} }
};
// this is ugly, only for 5d return static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + align.Get();
__host__ __device__ unsigned
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
static_assert(nDim == 5, "nDim is not 5");
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
i4 * GetStride(I4);
} }
// this is ugly, only for 6d template <class... Is>
__host__ __device__ unsigned __host__ __device__ unsigned Get1dIndex(Is... is) const
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4, unsigned i5) const
{ {
constexpr auto I0 = Number<0>{}; static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
static_assert(nDim == 6, "nDim is not 6"); const auto multi_id = Array<unsigned, nDim>(is...);
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
i4 * GetStride(I4) + i5 * GetStride(I5);
}
// this is ugly, only for 7d unsigned id = 0;
__host__ __device__ unsigned Get1dIndex(unsigned i0,
unsigned i1,
unsigned i2,
unsigned i3,
unsigned i4,
unsigned i5,
unsigned i6) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
static_assert(nDim == 7, "nDim is not 7");
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6);
}
// this is ugly, only for 8d static_loop_n<nDim>{}([&](auto IDim) {
__host__ __device__ unsigned Get1dIndex(unsigned i0, constexpr unsigned idim = IDim.Get();
unsigned i1, id += multi_id[idim] * GetStride(IDim);
unsigned i2, });
unsigned i3,
unsigned i4,
unsigned i5,
unsigned i6,
unsigned i7) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
static_assert(nDim == 8, "nDim is not 8"); return id;
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6) + i7 * GetStride(I7);
} }
__host__ __device__ constexpr auto Condense() const __host__ __device__ constexpr auto Condense() const
...@@ -385,6 +162,12 @@ struct ConstantTensorDescriptor ...@@ -385,6 +162,12 @@ struct ConstantTensorDescriptor
constexpr auto default_strides = calculate_default_strides(Lengths{}); constexpr auto default_strides = calculate_default_strides(Lengths{});
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{}; return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
} }
template <unsigned IDim, unsigned NVector>
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
{
assert(false); // not implemented
}
}; };
template <class Lengths> template <class Lengths>
......
#pragma once
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
template <unsigned... Is>
struct Sequence
{
using Type = Sequence<Is...>;
static constexpr unsigned nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
{
return mData[I];
}
// this is ugly, only for nDIm = 4
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{
static_assert(nDim == 4, "nDim != 4");
constexpr auto old_sequence = Type{};
constexpr unsigned NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3];
return Sequence<NR0, NR1, NR2, NR3>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{
// don't know how to implement this
printf("Sequence::ReorderByPutOldToNew not implemented");
assert(false);
}
template <unsigned I>
__host__ __device__ constexpr auto PushBack(Number<I>) const
{
return Sequence<Is..., I>{};
}
__host__ __device__ constexpr auto PopBack() const;
template <class F>
__host__ __device__ constexpr auto Transform(F f) const
{
return Sequence<f(Is)...>{};
}
};
template <unsigned... Is, unsigned I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
static_assert(sizeof...(Is) >= 1, "empty Sequence!");
return Sequence<Is...>{};
}
template <class F, unsigned... Xs, unsigned... Ys>
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
{
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
return Sequence<f(Xs, Ys)...>{};
}
template <unsigned... Xs, unsigned... Ys>
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
{
struct add
{
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
{
return x + y;
}
};
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
}
template <unsigned... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
return sequence_pop_back(Type{});
}
#pragma once #pragma once
#include "constant_integral.hip.hpp"
#include "Sequence.hip.hpp"
#include "Array.hip.hpp"
#include "functional.hip.hpp"
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } __device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
...@@ -91,54 +95,6 @@ struct vector_type<half, 8> ...@@ -91,54 +95,6 @@ struct vector_type<half, 8>
}; };
#endif #endif
template <class T, T N>
struct integral_constant
{
static const T value = N;
__host__ __device__ constexpr T Get() const { return value; }
};
template <unsigned N>
using Number = integral_constant<unsigned, N>;
template <unsigned... Is>
struct Sequence
{
using Type = Sequence<Is...>;
static constexpr unsigned nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Number<I>) const
{
return mData[I];
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
{
constexpr auto old_sequence = Type{};
constexpr unsigned NR0 = old_sequence.mData[I0];
constexpr unsigned NR1 = old_sequence.mData[I1];
constexpr unsigned NR2 = old_sequence.mData[I2];
constexpr unsigned NR3 = old_sequence.mData[I3];
return Sequence<NR0, NR1, NR2, NR3>{};
}
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
{
// don't know how to implement this
printf("Sequence::ReorderByPutOldToNew not implemented");
assert(false);
}
};
template <typename T> template <typename T>
__host__ __device__ constexpr T max(T a, T b) __host__ __device__ constexpr T max(T a, T b)
{ {
......
#pragma once
template <class T, T N>
struct integral_constant
{
static const T value = N;
__host__ __device__ constexpr T Get() const { return value; }
};
template <unsigned N>
using Number = integral_constant<unsigned, N>;
#pragma once
#include "constant_integral.hip.hpp"
template <unsigned NLoop>
struct static_loop_n
{
template <class F>
__host__ __device__ void operator()(F f) const
{
static_assert(NLoop > 1, "out-of-range");
f(Number<NLoop - 1>{});
static_loop_n<NLoop - 1>{}(f);
}
};
template <>
struct static_loop_n<1>
{
template <class F>
__host__ __device__ void operator()(F f) const
{
f(Number<0>{});
}
};
template <unsigned NLoop>
struct static_const_reduce_n
{
template <class F, class Reduce>
__host__ __device__ constexpr auto operator()(F f, Reduce r) const
{
static_assert(NLoop > 1, "out-of-range");
constexpr auto a = f(Number<NLoop - 1>{});
auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // cannot use constexpr here, weird
return r(a, b);
}
};
template <>
struct static_const_reduce_n<1>
{
template <class F, class Reduce>
__host__ __device__ constexpr auto operator()(F f, Reduce) const
{
return f(Number<0>{});
}
};
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_direct_convolution.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"
template <class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
unsigned WoPerThread,
unsigned BlockSize,
unsigned GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_global_desc = InGlobalDesc{};
constexpr auto wei_global_desc = WeiGlobalDesc{};
constexpr auto out_global_desc = OutGlobalDesc{};
constexpr unsigned Y = wei_global_desc.GetLength(I2);
constexpr unsigned X = wei_global_desc.GetLength(I3);
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
constexpr auto in_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
constexpr auto wei_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{});
// shared mem
constexpr unsigned in_block_size = in_block_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace();
__shared__ Float p_in_block[in_block_size];
__shared__ Float p_wei_block[wei_block_size];
// threadwise tensors
constexpr unsigned HiPerThread = HoPerThread + Y - 1;
constexpr unsigned WiPerThread = WoPerThread + X - 1;
constexpr auto in_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{}, in_block_desc.GetStrides());
constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, CPerThread, Y, X>{}, wei_block_desc.GetStrides());
constexpr auto out_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
in_thread_block_desc, wei_thread_block_desc);
// register
Float p_out_thread[out_thread_desc.GetElementSpace()];
// divide block work
constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
const unsigned block_id = blockIdx.x;
unsigned itmp = block_id;
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
const unsigned h_block_work_id = itmp / WBlockWork;
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
// divide thread work
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
const unsigned thread_id = threadIdx.x;
itmp = thread_id;
const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
const unsigned h_thread_work_id = itmp / WThreadWork;
const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
const unsigned hi_thread_data_begin = ho_thread_data_begin;
const unsigned wi_thread_data_begin = wo_thread_data_begin;
constexpr auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_global_desc),
decltype(in_block_desc),
decltype(in_block_desc.GetLengths())>{};
constexpr auto blockwise_wei_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(wei_global_desc),
decltype(wei_block_desc),
decltype(wei_block_desc.GetLengths())>{};
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_thread_desc, p_out_thread);
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1);
c_block_data_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin,
c_block_data_begin,
hi_block_data_begin,
wi_block_data_begin),
p_in_block);
// copy weight tensor to LDS
blockwise_wei_copy.Run(
p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_block);
__syncthreads();
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise convolution
#if 1
threadwise_direct_convolution_2(
in_thread_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data,
hi_thread_data_begin,
wi_thread_data_begin),
wei_thread_block_desc,
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
out_thread_desc,
p_out_thread);
#elif 0
threadwise_direct_convolution_3(
in_thread_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data,
hi_thread_data_begin,
wi_thread_data_begin),
wei_thread_block_desc,
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
out_thread_desc,
p_out_thread);
#endif
}
}
// copy output tensor from register to global mem
threadwise_4d_tensor_copy(
out_thread_desc,
p_out_thread,
out_global_desc,
p_out_global + out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_thread_desc.GetLengths());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment