Commit 02d72160 authored by Chao Liu's avatar Chao Liu
Browse files

Add int8 direct convolution that reads pre-vectorized data

parent 050a1a68
...@@ -133,14 +133,14 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, ...@@ -133,14 +133,14 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
// 3x3, 34x34, 128 thread, int8, vector = 4 // 3x3, 34x34, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 2; constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32; constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4; constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2; constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32; constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2; constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 4; constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2; constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2; constexpr unsigned InBlockCopyDataPerRead = 2;
...@@ -149,16 +149,16 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, ...@@ -149,16 +149,16 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 1 #elif 1
// 1x1, 32x32, 128 thread, int8, vector = 4 // 1x1, 32x32, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 2; constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 32; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4; constexpr unsigned CPerBlock = 16;
constexpr unsigned HoPerBlock = 2; constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32; constexpr unsigned WoPerBlock = 32;
constexpr unsigned NPerThread = 2; constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 4; constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2; constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2; constexpr unsigned WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2; constexpr unsigned InBlockCopyDataPerRead = 2;
......
...@@ -120,7 +120,7 @@ struct vector_type<char, 1> ...@@ -120,7 +120,7 @@ struct vector_type<char, 1>
template <> template <>
struct vector_type<char, 2> struct vector_type<char, 2>
{ {
using MemoryType = char2; using MemoryType = int16_t;
__host__ __device__ static MemoryType Pack(char s0, char s1) __host__ __device__ static MemoryType Pack(char s0, char s1)
{ {
...@@ -139,7 +139,7 @@ struct vector_type<char, 2> ...@@ -139,7 +139,7 @@ struct vector_type<char, 2>
template <> template <>
struct vector_type<char, 4> struct vector_type<char, 4>
{ {
using MemoryType = char4; using MemoryType = int32_t;
__host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3) __host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3)
{ {
...@@ -163,6 +163,13 @@ struct vector_type<char, 8> ...@@ -163,6 +163,13 @@ struct vector_type<char, 8>
using MemoryType = int64_t; using MemoryType = int64_t;
}; };
template <>
struct vector_type<int32_t, 2>
{
using MemoryType = int64_t;
};
#if 0
template <> template <>
struct vector_type<char2, 2> struct vector_type<char2, 2>
{ {
...@@ -176,33 +183,29 @@ struct vector_type<char2, 4> ...@@ -176,33 +183,29 @@ struct vector_type<char2, 4>
}; };
template <> template <>
struct vector_type<char4, 2> struct vector_type<char4, 1>
{ {
using MemoryType = int64_t; using MemoryType = int;
}; };
template <class TDst, class TSrc0, class TSrc1> template <>
__device__ void fused_multiply_accumulate(TDst& d, const TSrc0& s0, const TSrc1& s1) struct vector_type<char4, 2>
{ {
// static_assert(false, "should not call into base"); using MemoryType = int64_t;
printf("should not call into base"); };
assert(false); #endif
}
template <>
__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1) __device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
{ {
d += s0 * s1; d += s0 * s1;
} }
template <>
__device__ void fused_multiply_accumulate(float& d, const float2& s0, const float2& s1) __device__ void fused_multiply_accumulate(float& d, const float2& s0, const float2& s1)
{ {
d += s0.x * s1.x; d += s0.x * s1.x;
d += s0.y * s1.y; d += s0.y * s1.y;
} }
template <>
__device__ void fused_multiply_accumulate(float& d, const float4& s0, const float4& s1) __device__ void fused_multiply_accumulate(float& d, const float4& s0, const float4& s1)
{ {
d += s0.x * s1.x; d += s0.x * s1.x;
...@@ -211,13 +214,8 @@ __device__ void fused_multiply_accumulate(float& d, const float4& s0, const floa ...@@ -211,13 +214,8 @@ __device__ void fused_multiply_accumulate(float& d, const float4& s0, const floa
d += s0.w * s1.w; d += s0.w * s1.w;
} }
template <> __device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1)
{
d += s0 * s1;
}
template <>
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1) __device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
{ {
d += s0.x * s1.x; d += s0.x * s1.x;
...@@ -225,25 +223,25 @@ __device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& ...@@ -225,25 +223,25 @@ __device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2&
} }
#if 0 #if 0
template <>
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1) __device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
{ {
d += s0.x * s1.x + s0.y * s1.y; d += s0.x * s1.x + s0.y * s1.y;
} }
#endif #endif
template <> __device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1) { d += s0 * s1; }
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1)
{
d += s0 * s1;
}
template <> // TODO:: this interface is misleading, int32 is actually int8x4
__device__ void fused_multiply_accumulate(int32_t& d, const char4& s0, const char4& s1) // need to make a better interface
__device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const int32_t& s1)
{ {
#if DEVICE_BACKEND_CUDA #if DEVICE_BACKEND_CUDA
#if 1 // debug
d = __dp4a(s0, s1, d); d = __dp4a(s0, s1, d);
#else #elif 1
d += s0.x * s1.x + s0.y * s1.y + s0.z * s1.z + s0.w * s1.w; asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" : "=r"(d) : "r"(s0), "r"(s1), "r"(d));
#elif 0 // this is wrong! just for debugging
d += (*reinterpret_cast<const int32_t*>(&s0)) * (*reinterpret_cast<const int32_t*>(&s1));
#endif
#endif #endif
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment