Commit d85b89bb authored by Tejash Shah

Used the half4 datatype for blockwise GEMM in place of the half datatype

parent 2185affb
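For reference, a minimal host-side sketch of the packing scheme this commit introduces (the half4_t wrapper behind vector_type<half, 4>::MemoryType and its Pack helper). The half_stub/half2_stub types below are stand-ins so the sketch compiles as plain C++; the real code uses the GPU half and half2 types.

#include <cstdint>
#include <cstdio>

struct half_stub { std::uint16_t bits; };   // stand-in for the GPU half type
struct half2_stub { half_stub x, y; };      // stand-in for half2

// Mirrors the half4_t added by this commit: two half2 values back to back.
struct half4_stub { half2_stub vector[2]; };

// Mirrors vector_type<half, 4>::Pack: write four scalars through a union view.
half4_stub Pack(half_stub s0, half_stub s1, half_stub s2, half_stub s3)
{
    union { half4_stub vector; half_stub scalar[4]; } data;
    data.scalar[0] = s0;
    data.scalar[1] = s1;
    data.scalar[2] = s2;
    data.scalar[3] = s3;
    return data.vector;
}

int main()
{
    half4_stub v = Pack({1}, {2}, {3}, {4});
    std::printf("%d %d\n", v.vector[0].x.bits, v.vector[1].y.bits);   // prints "1 4"
}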
@@ -6,7 +6,7 @@
#include "threadwise_gemm.hpp"
#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 0
#endif
namespace ck {
@@ -136,9 +136,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
@@ -153,6 +150,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
std::is_same<FloatB, float>::value>{}([&](auto) {
using Float4 = vector_type<float, 4>::MemoryType;
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
@@ -183,33 +183,39 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
}).Else([&](auto) { // If A and B datatype is bfloat16/float16
using Half4x4 = vector_type<vector_type<half, 4>, 4>;
FloatA p_a_thread[a_thread_mtx.GetElementSpace() * 4];
FloatB p_b_thread[b_thread_mtx.GetElementSpace() * 4];
using Half4x4 = vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType;
using Float4 = vector_type<float, 4>::MemoryType;
Half4x4* reg_a = reinterpret_cast<Half4x4*>(p_a_thread);
Half4x4* reg_b = reinterpret_cast<Half4x4*>(p_b_thread);
Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA * 4]);
reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB * 4]);
reg_b[1] = *reinterpret_cast<const Half4x4*>(
&p_b_block[(mMyThreadOffsetB + NPerLevel1Cluster) * 4]);
reg_a[1] = *reinterpret_cast<const Half4x4*>(
&p_a_block[(mMyThreadOffsetA + MPerLevel1Cluster) * 4]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + k * M]);
reg_a[0] =
*reinterpret_cast<const Half4x4*>(&p_a_block[(mMyThreadOffsetA + k * M) * 4]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + k * N]);
reg_b[0] =
*reinterpret_cast<const Half4x4*>(&p_b_block[(mMyThreadOffsetB + k * N) * 4]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Half4x4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
&p_b_block[(mMyThreadOffsetB + k * N + NPerLevel1Cluster) * 4]);
reg_a[1] = *reinterpret_cast<const Half4x4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
&p_a_block[(mMyThreadOffsetA + k * M + MPerLevel1Cluster) * 4]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
@@ -447,6 +453,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatC* __restrict__ p_c_thread) const
{
// The inline-asm path does not support bfloat16, so bfloat16 falls back to the source path
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
static_if<std::is_same<FloatA, ushort>::value && std::is_same<FloatB, ushort>::value>{}(
[&](auto) { Run_source(p_a_block, p_b_block, p_c_thread); })
@@ -454,7 +462,25 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
Run_amd_asm(p_a_block, p_b_block, p_c_thread);
});
#else
static_if<std::is_same<FloatA, half>::value &&
std::is_same<FloatB, half>::value>{}([&](auto) {
// Vectorize the pointers to match how half/bfloat16 data is processed in the
// GEMM: half values are packed 4 at a time and bfloat16 values 2 at a time.
// The 2D indices into matrices A and B are computed per element (e.g. one
// float), so to keep those indices valid for half/bfloat16 the pointer type
// is recast from a single half to 4 packed halfs (or 2 packed bfloat16s).
const vector_type<half, 4>::MemoryType* p_a_block_vec =
reinterpret_cast<const vector_type<half, 4>::MemoryType*>(p_a_block);
const vector_type<half, 4>::MemoryType* p_b_block_vec =
reinterpret_cast<const vector_type<half, 4>::MemoryType*>(p_b_block);
Run_source(p_a_block_vec, p_b_block_vec, p_c_thread);
}).Else([&](auto) { // A and B datatype is float or bfloat16
Run_source(p_a_block, p_b_block, p_c_thread);
});
#endif
}
};
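The pointer recast described in the comment above can be illustrated with a small standalone sketch. Here float and a packed4 struct stand in for half and vector_type<half, 4>::MemoryType; only the index arithmetic is the point, and the reinterpret_cast mirrors the one used in the kernel.

#include <cstdio>

// Stand-in for vector_type<half, 4>::MemoryType: 4 contiguous scalars per element.
struct packed4 { float v[4]; };

int main()
{
    // 8 logical elements, each stored as 4 consecutive scalars.
    float storage[8 * 4] = {};
    storage[3 * 4 + 0] = 1.0f;   // scalar view: logical element 3, lane 0

    // After the recast, the per-element index 3 addresses the whole packed value,
    // so index arithmetic written for single-element types keeps working.
    const packed4* vec = reinterpret_cast<const packed4*>(storage);
    std::printf("%f\n", vec[3].v[0]);   // prints "1.000000"
}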
......
@@ -56,11 +56,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
}
}
}).Else([&](auto) {
static_if<std::is_same<Float, half>::value>{}([&](auto) {
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
using vector_t = typename vector_type<Float, 4>::MemoryType;
}).Else([&](auto) { // fp16/bfp16
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
@@ -68,26 +64,10 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index*4]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index*4]);
*reinterpret_cast<Float*>(&p_dst[dst_index]) =
*reinterpret_cast<const Float*>(&p_src[src_index]);
}
}
}).Else([&](auto) {
using vector_t = typename vector_type<Float, 2>::MemoryType;
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index*2]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index*2]);
}
}
});
});
}
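The hunk above collapses the fp16/bf16 copy specializations into a single per-element copy, presumably because the element type handed to threadwise_matrix_copy is already the packed vector type. A minimal host-side sketch of such a per-element 2D tile copy, with float standing in for the packed type and a hypothetical offset helper in place of GetOffsetFromMultiIndex:

#include <cstdio>

// Hypothetical row-major offset helper standing in for GetOffsetFromMultiIndex.
inline int offset(int i, int j, int stride) { return i * stride + j; }

// Copy an NRow x NCol tile element by element; T may itself be a packed vector type.
template <class T>
void tile_copy(const T* src, int src_stride, T* dst, int dst_stride, int NRow, int NCol)
{
    for(int i = 0; i < NRow; ++i)
        for(int j = 0; j < NCol; ++j)
            dst[offset(i, j, dst_stride)] = src[offset(i, j, src_stride)];
}

int main()
{
    float src[2 * 3] = {1, 2, 3, 4, 5, 6};
    float dst[2 * 3] = {};
    tile_copy(src, 3, dst, 3, 2, 3);
    std::printf("%f %f\n", dst[0], dst[5]);   // prints "1.000000 6.000000"
}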
@@ -129,32 +109,35 @@ __device__ void threadwise_gemm(MatrixA,
const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);
static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
p_c_thread[cindex] += CVT_FLOAT2ACCUM(p_a_thread[aindex]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex]);
}).Else([&](auto) {
static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
// respectively)
#if MIOPEN_USE_FP32 == 1
p_c_thread[cindex] +=
CVT_FLOAT2ACCUM(p_a_thread[aindex]) * CVT_FLOAT2ACCUM(p_b_thread[bindex]);
#elif MIOPEN_USE_FP16 == 1
const half* p_a_thread_half =
reinterpret_cast<const half*>(&p_a_thread[aindex]);
const half* p_b_thread_half =
reinterpret_cast<const half*>(&p_b_thread[bindex]);
float acc = 0.0;
for(index_t v = 0; v < 4; ++v)
{
acc += CVT_FLOAT2ACCUM(p_a_thread[aindex*4 + v]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex*4 + v]);
acc += CVT_FLOAT2ACCUM(p_a_thread_half[v]) *
CVT_FLOAT2ACCUM(p_b_thread_half[v]);
}
p_c_thread[cindex] = acc;
}).Else([&](auto) {
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
// respectively)
p_c_thread[cindex] += acc;
#elif MIOPEN_USE_BF16 == 1
const ushort* p_a_thread_ushort =
reinterpret_cast<const ushort*>(&p_a_thread[aindex]);
const ushort* p_b_thread_ushort =
reinterpret_cast<const ushort*>(&p_b_thread[bindex]);
float acc = 0.0;
for(index_t v = 0; v < 2; ++v)
{
acc += CVT_FLOAT2ACCUM(p_a_thread[aindex*2 + v]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex*2 + v]);
acc += CVT_FLOAT2ACCUM(p_a_thread_ushort[v]) *
CVT_FLOAT2ACCUM(p_b_thread_ushort[v]);
}
p_c_thread[cindex] += acc;
});
});
#else
#endif
}
}
}
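As a sanity check on the MIOPEN_USE_FP16 branch above, this is the 4-wide multiply-accumulate it performs per C element, sketched on the host with float standing in for half and plain multiplication in place of CVT_FLOAT2ACCUM (both are simplifications for illustration):

#include <cstdio>

// One C element accumulates the dot product of its 4 packed A and B scalars.
float accumulate4(const float* a4, const float* b4, float c)
{
    float acc = 0.0f;
    for(int v = 0; v < 4; ++v)
        acc += a4[v] * b4[v];   // CVT_FLOAT2ACCUM(a) * CVT_FLOAT2ACCUM(b) in the real code
    return c + acc;             // p_c_thread[cindex] += acc;
}

int main()
{
    const float a[4] = {1, 2, 3, 4};
    const float b[4] = {1, 1, 1, 1};
    std::printf("%f\n", accumulate4(a, b, 0.0f));   // prints "10.000000"
}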
......
@@ -112,7 +112,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
*reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
//printf("%f ", static_cast<float>(p_dst[dst_index]));
}).Else([&](auto) {
for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
{
......
@@ -147,8 +147,9 @@ __device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
"3"(c[3])); // 3rd Src Acc registers for 2 half2 registers
}
__device__ void outerProduct1x4Half(const vector_type<half, 4>& a,
const vector_type<vector_type<half, 4>, 4>& b,
__device__ void
outerProduct1x4Half(const vector_type<half, 4>::MemoryType& a,
const vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType& b,
vector_type<float, 4>::MemoryType& c)
{
outerProduct1x4(reinterpret_cast<const half2*>(&a),
@@ -156,14 +157,16 @@ __device__ void outerProduct1x4Half(const vector_type<half, 4>& a,
reinterpret_cast<float*>(&c));
}
__device__ void outerProduct4x4(const vector_type<vector_type<half, 4>, 4>& a,
const vector_type<vector_type<half, 4>, 4>& b,
__device__ void
outerProduct4x4(const vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType& a,
const vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType& b,
vector_type<float, 4>::MemoryType& c0,
vector_type<float, 4>::MemoryType& c1,
vector_type<float, 4>::MemoryType& c2,
vector_type<float, 4>::MemoryType& c3)
{
const vector_type<half, 4>* reg_a = reinterpret_cast<const vector_type<half, 4>*>(&a);
const vector_type<half, 4>::MemoryType* reg_a =
reinterpret_cast<const vector_type<half, 4>::MemoryType*>(&a);
outerProduct1x4Half(reg_a[0], b, c0);
outerProduct1x4Half(reg_a[1], b, c1);
outerProduct1x4Half(reg_a[2], b, c2);
......
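At the logical-element level, outerProduct4x4 accumulates a 4x4 rank-1 update, c[row][col] += a[row] * b[col]; in the fp16 path each such multiply is presumably a 4-wide dot over the packed halfs, handled by the inline assembly in outerProduct1x4. A host-side sketch of the float version of the update:

#include <cstdio>

// c[row][col] += a[row] * b[col]: a 4x4 rank-1 update of the accumulator tile.
void outer_product_4x4(const float (&a)[4], const float (&b)[4], float (&c)[4][4])
{
    for(int row = 0; row < 4; ++row)
        for(int col = 0; col < 4; ++col)
            c[row][col] += a[row] * b[col];
}

int main()
{
    float a[4] = {1, 2, 3, 4};
    float b[4] = {1, 0, 1, 0};
    float c[4][4] = {};
    outer_product_4x4(a, b, c);
    std::printf("%f %f\n", c[1][0], c[3][2]);   // prints "2.000000 4.000000"
}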
@@ -15,6 +15,19 @@ namespace ck {
// instruction
typedef float float2_t __attribute__((ext_vector_type(2)));
typedef float float4_t __attribute__((ext_vector_type(4)));
typedef half2 half2_t;
typedef struct
{
half2 vector[2];
} half4_t;
typedef struct
{
ushort vector[2];
} ushort2_t;
typedef struct
{
ushort2_t vector[2];
} ushort4_t;
using index_t = uint32_t;
......
@@ -19,6 +19,19 @@ namespace ck {
// instruction,
using float2_t = float2;
using float4_t = float4;
typedef half2 half2_t;
typedef struct
{
half2 vector[2];
} half4_t;
typedef struct
{
ushort vector[2];
} ushort2_t;
typedef struct
{
ushort2_t vector[2];
} ushort4_t;
using index_t = uint32_t;
......
@@ -10,7 +10,11 @@ namespace ck {
template <class T, index_t N>
struct vector_type
{
typedef struct
{
T vector[N];
} MemoryType;
MemoryType mData;
};
template <>
@@ -33,9 +37,7 @@ struct vector_type<float, 2>
{
using MemoryType = float2_t;
__host__ __device__ static constexpr index_t GetSize() { return 2; }
union Data
union DataType
{
MemoryType vector;
float scalar[2];
@@ -48,6 +50,13 @@ struct vector_type<float, 2>
*(reinterpret_cast<float*>(&v) + I) = s;
}
__host__ __device__ static MemoryType Pack(float s0, float s1)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
template <>
@@ -83,9 +92,9 @@ struct vector_type<half, 1>
template <>
struct vector_type<half, 2>
{
using MemoryType = half2;
using MemoryType = half2_t;
union Data
union DataType
{
MemoryType vector;
half scalar[2];
@@ -100,17 +109,25 @@ struct vector_type<half, 2>
*(reinterpret_cast<half*>(&v) + I) = s;
}
__host__ __device__ static MemoryType Pack(half s0, half s1)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
template <>
struct vector_type<half, 4>
{
typedef struct MemoryType
{
half2 vector[2];
} MemoryType;
using MemoryType = half4_t;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
union DataType
{
MemoryType vector;
half scalar[4];
};
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
@@ -118,6 +135,16 @@ struct vector_type<half, 4>
static_assert(I < 4, "wrong");
*(reinterpret_cast<half*>(&v) + I) = s;
}
__host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
data.scalar[2] = s2;
data.scalar[3] = s3;
return data.vector;
}
};
template <>
@@ -138,12 +165,12 @@ struct vector_type<ushort, 1>
template <>
struct vector_type<ushort, 2>
{
using MemoryType = ushort2;
using MemoryType = ushort2_t;
union Data
union DataType
{
MemoryType vector;
half scalar[2];
ushort scalar[2];
};
__host__ __device__ static constexpr index_t GetSize() { return 2; }
@@ -155,17 +182,25 @@ struct vector_type<ushort, 2>
*(reinterpret_cast<ushort*>(&v) + I) = s;
}
__host__ __device__ static MemoryType Pack(ushort s0, ushort s1)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
return data.vector;
}
};
template <>
struct vector_type<ushort, 4>
{
typedef struct MemoryType
{
ushort2 vector[2];
} MemoryType;
using MemoryType = ushort4_t;
__host__ __device__ static constexpr index_t GetSize() { return 4; }
union DataType
{
MemoryType vector;
ushort scalar[4];
};
template <index_t I>
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
@@ -173,6 +208,16 @@ struct vector_type<ushort, 4>
static_assert(I < 4, "wrong");
*(reinterpret_cast<ushort*>(&v) + I) = s;
}
__host__ __device__ static MemoryType Pack(ushort s0, ushort s1, ushort s2, ushort s3)
{
DataType data;
data.scalar[0] = s0;
data.scalar[1] = s1;
data.scalar[2] = s2;
data.scalar[3] = s3;
return data.vector;
}
};
} // namespace ck
......
@@ -801,6 +801,22 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 1x1 filter, 4x4 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
constexpr index_t N = 8;
constexpr index_t C = 64;
constexpr index_t HI = 4;
constexpr index_t WI = 4;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#endif
@@ -897,7 +913,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 1
#if 0
if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
{
......