Commit 050a1a68 authored by Chao Liu's avatar Chao Liu

Add int8 direct convolution that reads pre-vectorized data

parent 18ffbd68
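
In outline: the vectorized direct-convolution path, previously fp32/fp16 only, now also packs four consecutive int8 input channels into one char4 lane and accumulates products into an int32 register (via dp4a where available). A minimal standalone sketch of that idea, with illustrative names rather than the repo's API:

    #include <cstdint>
    #include <cstdio>

    // Stand-in for vector_type<char, 4>::Pack
    struct char4_t
    {
        int8_t x, y, z, w;
    };

    // Stand-in for fused_multiply_accumulate(int32_t&, char4, char4):
    // a 4-way int8 dot product accumulated into int32 (what __dp4a does).
    void dot4_accumulate(int32_t& acc, char4_t s0, char4_t s1)
    {
        acc += s0.x * s1.x + s0.y * s1.y + s0.z * s1.z + s0.w * s1.w;
    }

    int main()
    {
        int8_t in[8]  = {1, -2, 3, -4, 5, -6, 7, -8};
        int8_t wei[8] = {1, 1, 1, 1, 2, 2, 2, 2};

        int32_t acc = 0;
        for(int c = 0; c < 8; c += 4) // walk the C dimension 4 channels at a time
        {
            char4_t a = {in[c], in[c + 1], in[c + 2], in[c + 3]};
            char4_t b = {wei[c], wei[c + 1], wei[c + 2], wei[c + 3]};
            dot4_accumulate(acc, a, b);
        }

        printf("acc = %d\n", acc); // (1-2+3-4) + 2*(5-6+7-8) = -6
        return 0;
    }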
@@ -3,17 +3,18 @@
 #include "device.hpp"
 #include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
 
-template <class T, class InDesc, class WeiDesc, class OutDesc>
+template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
 void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
-                                                           const Tensor<T>& in_nchw,
+                                                           const Tensor<TInWei>& in_nchw,
                                                            WeiDesc,
-                                                           const Tensor<T>& wei_kcyx,
+                                                           const Tensor<TInWei>& wei_kcyx,
                                                            OutDesc,
-                                                           Tensor<T>& out_nkhw,
+                                                           Tensor<TOut>& out_nkhw,
                                                            unsigned nrepeat)
 {
-    constexpr unsigned NVector = 1;
+    constexpr unsigned NVector = 4;
 
-    using vector_t = vector_type<T, NVector>;
+    using accum_t  = int32_t;
+    using vector_t = vector_type<TInWei, NVector>;
     using vector_mem_t = typename vector_t::MemoryType;
 
     constexpr auto I0 = Number<0>{};
@@ -44,11 +45,16 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     Tensor<vector_mem_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));
 
     auto f_vectorized_nchw = [&](auto n, auto c, auto h, auto w) {
-#if 1
+#if 0
         in_nchw_vec(n, c, h, w) = in_nchw(n, c, h, w);
-#else
+#elif 0
         in_nchw_vec(n, c, h, w) =
             vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
+#elif 1
+        in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
+                                                 in_nchw(n, 4 * c + 1, h, w),
+                                                 in_nchw(n, 4 * c + 2, h, w),
+                                                 in_nchw(n, 4 * c + 3, h, w));
 #endif
     };
@@ -62,11 +68,16 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     Tensor<vector_mem_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));
 
     auto f_vectorized_kcyx = [&](auto k, auto c, auto y, auto x) {
-#if 1
+#if 0
         wei_kcyx_vec(k, c, y, x) = wei_kcyx(k, c, y, x);
-#else
+#elif 0
         wei_kcyx_vec(k, c, y, x) =
             vector_t::Pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
+#elif 1
+        wei_kcyx_vec(k, c, y, x) = vector_t::Pack(wei_kcyx(k, 4 * c, y, x),
+                                                  wei_kcyx(k, 4 * c + 1, y, x),
+                                                  wei_kcyx(k, 4 * c + 2, y, x),
+                                                  wei_kcyx(k, 4 * c + 3, y, x));
 #endif
     };
@@ -76,13 +87,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     //
     DeviceMem in_nchw_vec_device_buf(sizeof(vector_mem_t) * in_nchw_vec.mDesc.GetElementSpace());
     DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_mem_t) * wei_kcyx_vec.mDesc.GetElementSpace());
-    DeviceMem out_nkhw_device_buf(sizeof(T) * out_nkhw.mDesc.GetElementSpace());
+    DeviceMem out_nkhw_device_buf(sizeof(TOut) * out_nkhw.mDesc.GetElementSpace());
 
     in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data());
     wei_kcyx_vec_device_buf.ToDevice(wei_kcyx_vec.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
 
-#if 1
+#if 0
     // 3x3, 34x34, 128 thread, fp32, vector = 1
     constexpr unsigned NPerBlock = 2;
     constexpr unsigned KPerBlock = 32;
@@ -100,7 +111,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     constexpr unsigned WeiBlockCopyDataPerRead = 2;
 
     constexpr unsigned BlockSize = 128;
-#elif 1
+#elif 0
     // 3x3, 34x34, 128 thread, fp32, vector = 2
     constexpr unsigned NPerBlock = 2;
     constexpr unsigned KPerBlock = 32;
@@ -117,9 +128,27 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     constexpr unsigned InBlockCopyDataPerRead = 2;
     constexpr unsigned WeiBlockCopyDataPerRead = 2;
 
+    constexpr unsigned BlockSize = 128;
+#elif 0
+    // 3x3, 34x34, 128 thread, int8, vector = 4
+    constexpr unsigned NPerBlock = 2;
+    constexpr unsigned KPerBlock = 32;
+    constexpr unsigned CPerBlock = 4;
+    constexpr unsigned HoPerBlock = 2;
+    constexpr unsigned WoPerBlock = 32;
+
+    constexpr unsigned NPerThread = 2;
+    constexpr unsigned KPerThread = 4;
+    constexpr unsigned CPerThread = 1;
+    constexpr unsigned HoPerThread = 2;
+    constexpr unsigned WoPerThread = 2;
+
+    constexpr unsigned InBlockCopyDataPerRead = 2;
+    constexpr unsigned WeiBlockCopyDataPerRead = 2;
+
     constexpr unsigned BlockSize = 128;
 #elif 1
-    // 3x3, 34x34, 128 thread, fp16
+    // 1x1, 32x32, 128 thread, int8, vector = 4
     constexpr unsigned NPerBlock = 2;
     constexpr unsigned KPerBlock = 32;
     constexpr unsigned CPerBlock = 4;
@@ -128,12 +157,12 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     constexpr unsigned NPerThread = 2;
     constexpr unsigned KPerThread = 4;
-    constexpr unsigned CPerThread = 2;
+    constexpr unsigned CPerThread = 1;
     constexpr unsigned HoPerThread = 2;
     constexpr unsigned WoPerThread = 2;
 
     constexpr unsigned InBlockCopyDataPerRead = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr unsigned WeiBlockCopyDataPerRead = 2;
 
     constexpr unsigned BlockSize = 128;
 #endif
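
As a quick consistency check on the tuning blocks above (a hedged sketch, not part of the commit): each 128-thread block owns an NPerBlock x KPerBlock x HoPerBlock x WoPerBlock output tile and each thread an NPerThread x KPerThread x HoPerThread x WoPerThread sub-tile, so the tile sizes must divide out to BlockSize:

    // int8, vector = 4 configuration:
    //   block tile  = 2 * 32 * 2 * 32 = 4096 outputs
    //   thread tile = 2 *  4 * 2 *  2 =   32 outputs
    //   4096 / 32   = 128 threads = BlockSize
    static_assert(2 * 32 * 2 * 32 == 128 * (2 * 4 * 2 * 2),
                  "block tile must equal BlockSize * thread tile");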
@@ -146,7 +175,9 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     for(unsigned i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(
-            gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<T,
+            gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
+                                                                    TOut,
+                                                                    accum_t,
                                                                     decltype(in_nchw_vec_desc),
                                                                     decltype(wei_kcyx_vec_desc),
                                                                     decltype(out_nkhw_desc),
@@ -167,9 +198,9 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
                                                                     GridSize>,
             dim3(GridSize),
             dim3(BlockSize),
-            static_cast<T*>(in_nchw_vec_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(wei_kcyx_vec_device_buf.GetDeviceBuffer()),
-            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+            static_cast<TInWei*>(in_nchw_vec_device_buf.GetDeviceBuffer()),
+            static_cast<TInWei*>(wei_kcyx_vec_device_buf.GetDeviceBuffer()),
+            static_cast<TOut*>(out_nkhw_device_buf.GetDeviceBuffer()));
 
         printf("Elapsed time : %f ms\n", time);
         usleep(std::min(time * 1000, float(10000)));
......
@@ -88,9 +88,12 @@ auto make_TensorDescriptor(TConstTensorDesc)
     return TensorDescriptor(lengths, strides);
 }
 
-template <class T, class LowerPads, class UpperPads>
-void host_direct_convolution(
-    const Tensor<T>& in_nchw, const Tensor<T>& wei_kcyx, Tensor<T>& out, LowerPads, UpperPads)
+template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
+void host_direct_convolution(const Tensor<TIn>& in_nchw,
+                             const Tensor<TWei>& wei_kcyx,
+                             Tensor<TOut>& out_nkhw,
+                             LowerPads,
+                             UpperPads)
 {
     unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
     unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
@@ -116,21 +119,24 @@ void host_direct_convolution(
                 }
             }
         }
-        out(n, k, ho, wo) = v;
+        out_nkhw(n, k, ho, wo) = v;
     };
 
     auto f_par = make_ParallelTensorFunctor(f,
-                                            out.mDesc.GetLengths()[0],
-                                            out.mDesc.GetLengths()[1],
-                                            out.mDesc.GetLengths()[2],
-                                            out.mDesc.GetLengths()[3]);
+                                            out_nkhw.mDesc.GetLengths()[0],
+                                            out_nkhw.mDesc.GetLengths()[1],
+                                            out_nkhw.mDesc.GetLengths()[2],
+                                            out_nkhw.mDesc.GetLengths()[3]);
 
     f_par(std::thread::hardware_concurrency());
 }
 
-template <class T, class LowerPads, class UpperPads>
-void host_winograd_3x3_convolution(
-    const Tensor<T>& in_nchw, const Tensor<T>& wei_kcyx, Tensor<T>& out, LowerPads, UpperPads)
+template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
+void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
+                                   const Tensor<TWei>& wei_kcyx,
+                                   Tensor<TOut>& out_nkhw,
+                                   LowerPads,
+                                   UpperPads)
 {
     constexpr std::size_t HoPerTile = 2;
     constexpr std::size_t WoPerTile = 2;
@@ -144,8 +150,8 @@ void host_winograd_3x3_convolution(
     std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
     std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
 
-    std::size_t HO = out.mDesc.GetLengths()[2];
-    std::size_t WO = out.mDesc.GetLengths()[3];
+    std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
+    std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
 
     unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
     unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
@@ -180,7 +186,7 @@ void host_winograd_3x3_convolution(
                     }
                     else
                     {
-                        in_hold(n, c, htile, wtile, j, i) = T(0);
+                        in_hold(n, c, htile, wtile, j, i) = TIn(0);
                     }
                 }
             }
@@ -347,8 +353,8 @@ void host_winograd_3x3_convolution(
             std::size_t ho = HoPerTile * htile + j;
             for(int i = 0; i < WoPerTile; ++i)
             {
                 std::size_t wo = WoPerTile * wtile + i;
-                out(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
+                out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
             }
         }
     };
@@ -403,7 +409,7 @@ int main(int argc, char* argv[])
     constexpr unsigned HPad = 0;
     constexpr unsigned WPad = 0;
-#elif 1
+#elif 0
     // 3x3, 34x34
     constexpr unsigned N = 64;
    constexpr unsigned C = 256;
@@ -502,7 +508,7 @@ int main(int argc, char* argv[])
     constexpr unsigned HPad = 1;
     constexpr unsigned WPad = 1;
-#elif 1
+#elif 0
     // 1x1 filter, 28x28 image
     constexpr unsigned N = 16;
     constexpr unsigned C = 256;
@@ -562,6 +568,18 @@ int main(int argc, char* argv[])
     constexpr unsigned HPad = 2;
     constexpr unsigned WPad = 2;
+#elif 1
+    // 1x1 filter, 32x32 image
+    constexpr unsigned N = 64;
+    constexpr unsigned C = 256;
+    constexpr unsigned HI = 32;
+    constexpr unsigned WI = 32;
+    constexpr unsigned K = 512;
+    constexpr unsigned Y = 1;
+    constexpr unsigned X = 1;
+    constexpr unsigned HPad = 0;
+    constexpr unsigned WPad = 0;
 #endif
 
     auto lower_pads = Sequence<HPad, WPad>{};
@@ -576,11 +594,12 @@ int main(int argc, char* argv[])
     ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
     ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
 
-    using Float = float;
-    Tensor<Float> in_nchw(make_TensorDescriptor(in_nchw_desc));
-    Tensor<Float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
-    Tensor<Float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
-    Tensor<Float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
+    using in_data_t  = char;
+    using out_data_t = int32_t;
+
+    Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
+    Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
+    Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
+    Tensor<out_data_t> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
 
     std::size_t num_thread = std::thread::hardware_concurrency();
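
With the tensors typed this way, the host reference path runs at the same mixed precision. A hedged initialization sketch, not part of the commit (it assumes Tensor::mData is an iterable container, as its use with .data() above suggests, and <cstdlib> for rand): keep the int8 values small so the int32 accumulator stays far from overflow even with C * Y * X = 256 products per output.

    // Hypothetical init: values in [-2, 2]
    for(auto& v : in_nchw.mData)
        v = static_cast<char>(rand() % 5 - 2);
    for(auto& v : wei_kcyx.mData)
        v = static_cast<char>(rand() % 5 - 2);

    // Reference result on the host, then compare against the device output.
    host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);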
......
@@ -10,16 +10,6 @@ namespace CUDA {
 using half = CUDA::half;
 using half2 = CUDA::half2;
 
-struct half4
-{
-    half data[4];
-};
-
-struct half8
-{
-    half data[8];
-};
-
 template <class T, unsigned N>
 struct vector_type
 {
@@ -119,39 +109,141 @@ struct vector_type<half2, 4>
     using MemoryType = float4;
 };
 
-template <class TDst, class TSrc0, class TSrc1, class TSrc2>
-__device__ void fused_multiply_add(TDst& d, TSrc0 s0, TSrc1 s1, TSrc2 s2)
-{
-    printf("should not call into base");
-    assert(false);
-}
-
-template <>
-__device__ void fused_multiply_add(float& d, float s0, float s1, float s2)
-{
-    d = s0 * s1 + s2;
-}
-
-template <>
-__device__ void fused_multiply_add(float& d, float2 s0, float2 s1, float s2)
-{
-    d = s0.x * s1.x + s0.y * s1.y + s2;
-}
-
-template <>
-__device__ void fused_multiply_add(float& d, float4 s0, float4 s1, float s2)
-{
-    d = s0.x * s1.x + s0.y * s1.y + s0.z * s1.z + s0.w * s1.w + s2;
-}
-
-template <>
-__device__ void fused_multiply_add(half& d, half s0, half s1, half s2)
-{
-    d = s0 * s1 + s2;
-}
-
-template <>
-__device__ void fused_multiply_add(half& d, half2 s0, half2 s1, half s2)
-{
-    d = s0.x * s1.x + s0.y * s1.y + s2;
-}
\ No newline at end of file
+template <>
+struct vector_type<char, 1>
+{
+    using MemoryType = char;
+
+    __host__ __device__ static MemoryType Pack(char s) { return s; }
+};
+
+template <>
+struct vector_type<char, 2>
+{
+    using MemoryType = char2;
+
+    __host__ __device__ static MemoryType Pack(char s0, char s1)
+    {
+        union
+        {
+            MemoryType vector;
+            char scalar[2];
+        } data;
+
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        return data.vector;
+    }
+};
+
+template <>
+struct vector_type<char, 4>
+{
+    using MemoryType = char4;
+
+    __host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3)
+    {
+        union
+        {
+            MemoryType vector;
+            char scalar[4];
+        } data;
+
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        data.scalar[2] = s2;
+        data.scalar[3] = s3;
+        return data.vector;
+    }
+};
+
+template <>
+struct vector_type<char, 8>
+{
+    using MemoryType = int64_t;
+};
+
+template <>
+struct vector_type<char2, 2>
+{
+    using MemoryType = char4;
+};
+
+template <>
+struct vector_type<char2, 4>
+{
+    using MemoryType = int64_t;
+};
+
+template <>
+struct vector_type<char4, 2>
+{
+    using MemoryType = int64_t;
+};
+
+template <class TDst, class TSrc0, class TSrc1>
+__device__ void fused_multiply_accumulate(TDst& d, const TSrc0& s0, const TSrc1& s1)
+{
+    // static_assert(false, "should not call into base");
+    printf("should not call into base");
+    assert(false);
+}
+
+template <>
+__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
+{
+    d += s0 * s1;
+}
+
+template <>
+__device__ void fused_multiply_accumulate(float& d, const float2& s0, const float2& s1)
+{
+    d += s0.x * s1.x;
+    d += s0.y * s1.y;
+}
+
+template <>
+__device__ void fused_multiply_accumulate(float& d, const float4& s0, const float4& s1)
+{
+    d += s0.x * s1.x;
+    d += s0.y * s1.y;
+    d += s0.z * s1.z;
+    d += s0.w * s1.w;
+}
+
+template <>
+__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1)
+{
+    d += s0 * s1;
+}
+
+template <>
+__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
+{
+    d += s0.x * s1.x;
+    d += s0.y * s1.y;
+}
+
+#if 0
+template <>
+__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
+{
+    d += s0.x * s1.x + s0.y * s1.y;
+}
+#endif
+
+template <>
+__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1)
+{
+    d += s0 * s1;
+}
+
+template <>
+__device__ void fused_multiply_accumulate(int32_t& d, const char4& s0, const char4& s1)
+{
+#if DEVICE_BACKEND_CUDA
+    d = __dp4a(s0, s1, d);
+#else
+    d += s0.x * s1.x + s0.y * s1.y + s0.z * s1.z + s0.w * s1.w;
+#endif
+}
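
A small smoke-test sketch for the new char4 path (hypothetical kernel, not part of the commit; it exercises vector_type<char, 4>::Pack and the char4 fused_multiply_accumulate, which lowers to __dp4a on the CUDA backend and to the scalar fallback elsewhere):

    __global__ void dp4a_smoke_test(const char* in, const char* wei, int32_t* out)
    {
        char4 a = vector_type<char, 4>::Pack(in[0], in[1], in[2], in[3]);
        char4 b = vector_type<char, 4>::Pack(wei[0], wei[1], wei[2], wei[3]);

        int32_t acc = 0;
        fused_multiply_accumulate(acc, a, b);
        *out = acc; // equals the plain int32 dot product of the four int8 pairs
    }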
@@ -7,7 +7,9 @@
 #include "threadwise_4d_tensor_op.hip.hpp"
 #include "threadwise_direct_convolution.hip.hpp"
 
-template <class Float,
+template <class TInWei,
+          class TOut,
+          class TAccum,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
@@ -27,14 +29,16 @@ template <class Float,
           unsigned BlockSize,
           unsigned GridSize>
 __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
-    const typename vector_type<Float,
-                               ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
-    const typename vector_type<Float,
-                               ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
-    Float* const __restrict__ p_out_global)
+    const typename vector_type<TInWei,
+                               ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
+    const typename vector_type<TInWei,
+                               ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
+    TOut* const __restrict__ p_out_global)
 {
-    using scalar_t = Float;
-    using vector_mem_t = typename vector_type<scalar_t, ScalarPerVector>::MemoryType;
+    using in_scalar_t = TInWei;
+    using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;
+    using out_scalar_t = TOut;
+    using accum_t = TAccum;
 
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -79,9 +83,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
                                       ? InBlockCopyDataPerRead
                                       : WeiBlockCopyDataPerRead;
 
-    __shared__ vector_mem_t
+    __shared__ in_vector_mem_t
         p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
-    __shared__ vector_mem_t
+    __shared__ in_vector_mem_t
         p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
 
     // threadwise tensors
@@ -99,7 +103,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);
 
     // register
-    scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
+    out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
 
     // divide block work
     constexpr unsigned NBlockWork =
@@ -155,7 +159,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
     constexpr auto blockwise_in_copy =
         Blockwise4dTensorCopy1<BlockSize,
-                               vector_mem_t,
+                               in_vector_mem_t,
                                decltype(in_nchw_vec_global_desc),
                                decltype(in_nchw_vec_block_desc),
                                decltype(in_nchw_vec_block_desc.GetLengths()),
@@ -164,7 +168,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #if 0
     constexpr auto blockwise_wei_copy =
         Blockwise4dTensorCopy1<BlockSize,
-                               vector_mem_t,
+                               in_vector_mem_t,
                                decltype(wei_kcyx_vec_global_desc),
                                decltype(wei_kcyx_vec_block_desc),
                                decltype(wei_kcyx_vec_block_desc.GetLengths()),
@@ -172,15 +176,17 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
 #elif 1
     const auto blockwise_wei_copy =
         Blockwise2dTensorCopy3<BlockSize,
-                               vector_mem_t,
+                               in_vector_mem_t,
                                decltype(wei_ke_vec_global_desc),
                                decltype(wei_ke_vec_block_desc),
                                decltype(wei_ke_vec_block_desc.GetLengths()),
                                WeiBlockCopyDataPerRead>{};
 #endif
 
+#if 1 // debug
     // set threadwise output tensor to 0
     threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
+#endif
 
     for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
         c_block_data_begin += CPerBlock, __syncthreads())
......
@@ -37,7 +37,8 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
 
 // TODO: in order to optimize mem access for different mem type,
 // need to write specialized version
-template <class Float,
+template <class SrcData,
+          class DstData,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
@@ -45,9 +46,9 @@ template <class Float,
           class F>
 __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
-    const Float* __restrict__ p_src,
+    const SrcData* __restrict__ p_src,
     DstDesc,
-    Float* __restrict__ p_dst,
+    DstData* __restrict__ p_dst,
     SrcOpLengths,
     DstFromSrcReorder,
     F f)
@@ -88,33 +89,38 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
     }
 }
 
-template <class Float, class Desc>
-__device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p)
+template <class Data, class Desc>
+__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
 {
-    auto f_set_zero = [](Float& v) { v = Float(0); };
+    auto f_set_zero = [](Data& v) { v = Data(0); };
 
-    threadwise_4d_tensor_pointwise_operation_unary<Float, Desc, decltype(f_set_zero)>(
+    threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
         Desc{}, p, f_set_zero);
 }
 
-template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
+template <class SrcData,
+          class DstData,
+          class SrcDesc,
+          class DstDesc,
+          class SrcOpLengths,
+          class DstFromSrcReorder>
 __device__ void
 threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
-                                                      const Float* __restrict__ p_src,
+                                                      const SrcData* __restrict__ p_src,
                                                       DstDesc,
-                                                      Float* __restrict__ p_dst,
+                                                      DstData* __restrict__ p_dst,
                                                       SrcOpLengths,
                                                       DstFromSrcReorder)
 {
-    auto f_copy = [](const Float& src, Float& dst) { dst = src; };
+    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
 
     threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
         SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
 }
 
-template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
+template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
 __device__ void threadwise_4d_tensor_copy(
-    SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
+    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
 {
     auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
......
@@ -51,10 +51,8 @@ __device__ void threadwise_direct_convolution_1(InDesc,
                     const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
 
-                    fused_multiply_add(p_out[out_index],
-                                       p_wei[wei_index],
-                                       p_in[in_index],
-                                       p_out[out_index]);
+                    fused_multiply_accumulate(
+                        p_out[out_index], p_wei[wei_index], p_in[in_index]);
                 }
             }
         }
......