Commit fa6d8037 authored by Tejash Shah's avatar Tejash Shah
Browse files

Changed scalar type to vector type for threadwise gemm for fp16 and bfloat16 data types

parent 2185affb
......@@ -308,8 +308,22 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
p_wei_register_clipboard);
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
static_if<std::is_same<Float, half>::value>{}([&](auto) {
using vector_t = typename vector_type<half, 4>::MemoryType;
vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_now);
vector_t* vec_in_block_now = reinterpret_cast<vector_t*>(p_in_block_now);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
}).Else([&](auto) {
using vector_t = typename vector_type<ushort, 2>::MemoryType;
vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_now);
vector_t* vec_in_block_now = reinterpret_cast<vector_t*>(p_in_block_now);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
});
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
......@@ -336,7 +350,22 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
p_wei_register_clipboard);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
static_if<std::is_same<Float, half>::value>{}([&](auto) {
using vector_t = typename vector_type<half, 4>::MemoryType;
vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double);
vector_t* vec_in_block_now = reinterpret_cast<vector_t*>(p_in_block_double);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
}).Else([&](auto) {
using vector_t = typename vector_type<ushort, 2>::MemoryType;
vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double);
vector_t* vec_in_block_now = reinterpret_cast<vector_t*>(p_in_block_double);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
});
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
......@@ -348,9 +377,22 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
static_if<std::is_same<Float, half>::value>{}([&](auto) {
using vector_t = typename vector_type<half, 4>::MemoryType;
vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double + wei_block_space);
vector_t* vec_in_block_now = reinterpret_cast<vector_t*>(p_in_block_double + in_block_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
}).Else([&](auto) {
using vector_t = typename vector_type<ushort, 2>::MemoryType;
vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double + wei_block_space);
vector_t* vec_in_block_now = reinterpret_cast<vector_t*>(p_in_block_double + in_block_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
});
}
// copy output: register to global memory
......
......@@ -57,25 +57,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
}
}).Else([&](auto) {
static_if<std::is_same<Float, half>::value>{}([&](auto) {
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
using vector_t = typename vector_type<Float, 4>::MemoryType;
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index*4]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index*4]);
}
}
}).Else([&](auto) {
using vector_t = typename vector_type<Float, 2>::MemoryType;
// For half/bfloat16, Float type is half4/bfloat2 respectively.
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
......@@ -83,12 +65,11 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index*2]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index*2]);
*reinterpret_cast<Float*>(&p_dst[dst_index]) =
*reinterpret_cast<const Float*>(&p_src[src_index]);
}
}
});
});
}
template <class MatrixA,
......@@ -129,32 +110,26 @@ __device__ void threadwise_gemm(MatrixA,
const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);
static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
p_c_thread[cindex] += CVT_FLOAT2ACCUM(p_a_thread[aindex]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex]);
}).Else([&](auto) {
static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
// respectively)
float acc = 0.0;
for(index_t v = 0; v < 4; ++v)
{
acc += CVT_FLOAT2ACCUM(p_a_thread[aindex*4 + v]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex*4 + v]);
}
p_c_thread[cindex] = acc;
//static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
// p_c_thread[cindex] +=
// CVT_FLOAT2ACCUM(p_a_thread[aindex]) * CVT_FLOAT2ACCUM(p_b_thread[bindex]);
//}).Else([&](auto) {
static_if<std::is_same<FloatA, ck::vector_type<half, 4>::MemoryType>::value>{}([&](auto) {
const half* s0_half = reinterpret_cast<const half*>(&p_a_thread[aindex]);
const half* s1_half = reinterpret_cast<const half*>(&p_b_thread[bindex]);
p_c_thread[cindex] +=
CVT_FLOAT2ACCUM(s0_half[0]) * CVT_FLOAT2ACCUM(s1_half[0]) +
CVT_FLOAT2ACCUM(s0_half[1]) * CVT_FLOAT2ACCUM(s1_half[1]) +
CVT_FLOAT2ACCUM(s0_half[2]) * CVT_FLOAT2ACCUM(s1_half[2]) +
CVT_FLOAT2ACCUM(s0_half[3]) * CVT_FLOAT2ACCUM(s1_half[3]);
}).Else([&](auto) {
// If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
// respectively)
float acc = 0.0;
for(index_t v = 0; v < 2; ++v)
{
acc += CVT_FLOAT2ACCUM(p_a_thread[aindex*2 + v]) *
CVT_FLOAT2ACCUM(p_b_thread[bindex*2 + v]);
}
p_c_thread[cindex] += acc;
});
const ushort* s0_ushort = reinterpret_cast<const ushort*>(&p_a_thread[aindex]);
const ushort* s1_ushort = reinterpret_cast<const ushort*>(&p_b_thread[bindex]);
p_c_thread[cindex] +=
CVT_FLOAT2ACCUM(s0_ushort[0]) * CVT_FLOAT2ACCUM(s1_ushort[0]) +
CVT_FLOAT2ACCUM(s0_ushort[1]) * CVT_FLOAT2ACCUM(s1_ushort[1]);
});
// });
}
}
}
......
......@@ -112,7 +112,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
*reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
//printf("%f ", static_cast<float>(p_dst[dst_index]));
}).Else([&](auto) {
for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
{
......
......@@ -138,11 +138,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
wi < in_nchw.mDesc.GetLengths()[3])
{
v += double(in_nchw(n, c, hi, wi)) * double(wei_kcyx(k, c, y, x));
if(n == 0 && k == 0 && ho == 0 && wo == 0)
{
//std::cout << "cpu " << c << "," << hi << "," << wi << " * " <<
// << c << "," << y << "," << x << " = "
// << in_nchw(n,c,hi,wi) << " * " << wei_kcyx(k, c, y, x) << std::endl;
// printf(" cpu %d,%d,%d * %d,%d,%d = %f * %f\n",
// c, hi, wi, c, y, x, double(in_nchw(n,c,hi,wi)), double(wei_kcyx(k, c, y, x)));
}
}
}
}
}
out_nkhw(n, k, ho, wo) = v;
if(n == 0 && k == 0 && ho == 0 && wo == 0)
printf("cpu %d,%d,%d,%d = %f", n,k, ho,wo,v);
};
auto f_par = make_ParallelTensorFunctor(f,
......@@ -787,9 +797,8 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 1x1 filter, 7x7 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
constexpr index_t N = 32;
constexpr index_t C = 128;
constexpr index_t HI = 28;
......@@ -801,6 +810,20 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
constexpr index_t N = 8;
constexpr index_t C = 64;
constexpr index_t HI = 4;
constexpr index_t WI = 4;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#endif
......@@ -897,7 +920,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 1
#if 0
if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
{
......@@ -915,6 +938,7 @@ int main(int argc, char* argv[])
upper_pads);
}
check_error(out_nkhw_host, out_nkhw_device);
printf("gpu value %f", double(out_nkhw_device.mData[0]));
#if 0
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment