"composable_kernel/include/utility/math.hpp" did not exist on "e7b8705b913c1bb7d216255f1f233ea03c096f1e"
Commit fa6d8037 authored by Tejash Shah

Changed scalar type to vector type for threadwise gemm for fp16 and bfloat16 data types

parent 2185affb
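
The hunks below route fp16/bfloat16 GEMM operands through a packed vector type instead of element-by-element scalars. As a minimal standalone sketch of the idea — assuming nothing about the repository's actual vector_type definition (all names here are illustrative):

    // Sketch: map (element type, lane count) to a naturally aligned packed type,
    // so a pointer into LDS can be reinterpreted for one wide load per access.
    template <typename T, int N>
    struct vector_type_sketch
    {
        struct alignas(sizeof(T) * N) MemoryType
        {
            T lanes[N]; // e.g. 4 x fp16 = one 64-bit load, 2 x bf16 (ushort) = one 32-bit load
        };
    };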
@@ -308,9 +308,23 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
             blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                         p_wei_register_clipboard);

-            // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
+            // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+            static_if<std::is_same<Float, half>::value>{}([&](auto) {
+                using vector_t = typename vector_type<half, 4>::MemoryType;
+                vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_now);
+                vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_now);
+
+                // LDS double buffer: GEMM on current data
+                blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+            }).Else([&](auto) {
+                using vector_t = typename vector_type<ushort, 2>::MemoryType;
+                vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_now);
+                vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_now);
+
+                // LDS double buffer: GEMM on current data
+                blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+            });

             // LDS double buffer: store next data to LDS
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                         p_in_block_next);
@@ -336,7 +350,22 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
                                                         p_wei_register_clipboard);

-            // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
+            // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+            static_if<std::is_same<Float, half>::value>{}([&](auto) {
+                using vector_t = typename vector_type<half, 4>::MemoryType;
+                vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double);
+                vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_double);
+
+                // LDS double buffer: GEMM on current data
+                blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+            }).Else([&](auto) {
+                using vector_t = typename vector_type<ushort, 2>::MemoryType;
+                vector_t* vec_wei_block_now = reinterpret_cast<vector_t*>(p_wei_block_double);
+                vector_t* vec_in_block_now  = reinterpret_cast<vector_t*>(p_in_block_double);
+
+                // LDS double buffer: GEMM on current data
+                blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+            });

             // LDS double buffer: store next data to LDS
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
@@ -348,9 +377,22 @@ struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_b
            __syncthreads();

-           // LDS double buffer: GEMM on current data
-           blockwise_gemm.Run(p_wei_block_double + wei_block_space,
-                              p_in_block_double + in_block_space,
-                              p_out_thread);
+           // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
+           static_if<std::is_same<Float, half>::value>{}([&](auto) {
+               using vector_t = typename vector_type<half, 4>::MemoryType;
+               vector_t* vec_wei_block_now =
+                   reinterpret_cast<vector_t*>(p_wei_block_double + wei_block_space);
+               vector_t* vec_in_block_now =
+                   reinterpret_cast<vector_t*>(p_in_block_double + in_block_space);
+
+               // LDS double buffer: GEMM on current data
+               blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+           }).Else([&](auto) {
+               using vector_t = typename vector_type<ushort, 2>::MemoryType;
+               vector_t* vec_wei_block_now =
+                   reinterpret_cast<vector_t*>(p_wei_block_double + wei_block_space);
+               vector_t* vec_in_block_now =
+                   reinterpret_cast<vector_t*>(p_in_block_double + in_block_space);
+
+               // LDS double buffer: GEMM on current data
+               blockwise_gemm.Run(vec_wei_block_now, vec_in_block_now, p_out_thread);
+           });
        }

        // copy output: register to global memory
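All three hunks above perform the same compile-time dispatch on the element type before calling into the blockwise GEMM. A standalone sketch of that dispatch, written with C++17 if constexpr instead of the repository's static_if utility — vector_type, half, and ushort are the names used in the diff; the wrapper function itself is illustrative, not part of the library:

    #include <type_traits>

    template <typename Float, typename BlockwiseGemm, typename FloatC>
    __device__ void run_blockwise_gemm_vectorized(BlockwiseGemm& blockwise_gemm,
                                                  Float* p_wei_block,
                                                  Float* p_in_block,
                                                  FloatC* p_out_thread)
    {
        if constexpr(std::is_same<Float, half>::value)
        {
            // fp16: view LDS as 4-wide half vectors
            using vector_t = typename vector_type<half, 4>::MemoryType;
            blockwise_gemm.Run(reinterpret_cast<vector_t*>(p_wei_block),
                               reinterpret_cast<vector_t*>(p_in_block),
                               p_out_thread);
        }
        else
        {
            // bfloat16 (stored as ushort): view LDS as 2-wide vectors
            using vector_t = typename vector_type<ushort, 2>::MemoryType;
            blockwise_gemm.Run(reinterpret_cast<vector_t*>(p_wei_block),
                               reinterpret_cast<vector_t*>(p_in_block),
                               p_out_thread);
        }
    }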
@@ -57,37 +57,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
            }
        }).Else([&](auto) {
-           static_if<std::is_same<Float, half>::value>{}([&](auto) {
-               // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
-               using vector_t = typename vector_type<Float, 4>::MemoryType;
-               for(index_t i = 0; i < NRow; ++i)
-               {
-                   for(index_t j = 0; j < NCol; ++j)
-                   {
-                       const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                       const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                       *reinterpret_cast<vector_t*>(&p_dst[dst_index*4]) =
-                           *reinterpret_cast<const vector_t*>(&p_src[src_index*4]);
-                   }
-               }
-           }).Else([&](auto) {
-               using vector_t = typename vector_type<Float, 2>::MemoryType;
-               for(index_t i = 0; i < NRow; ++i)
-               {
-                   for(index_t j = 0; j < NCol; ++j)
-                   {
-                       const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                       const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                       *reinterpret_cast<vector_t*>(&p_dst[dst_index*2]) =
-                           *reinterpret_cast<const vector_t*>(&p_src[src_index*2]);
-                   }
-               }
-           });
+           // For half/bfloat16, Float type is half4/bfloat2 respectively.
+           for(index_t i = 0; i < NRow; ++i)
+           {
+               for(index_t j = 0; j < NCol; ++j)
+               {
+                   const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
+                   const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
+
+                   *reinterpret_cast<Float*>(&p_dst[dst_index]) =
+                       *reinterpret_cast<const Float*>(&p_src[src_index]);
+               }
+           }
        });
    }
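
The hunk above shows why the *4/*2 index scaling disappears: once the element type Float is itself the packed vector (e.g. 4 x fp16), a plain assignment already moves a whole vector. A compact sketch of that copy under the same assumption — the vector struct and function names here are illustrative:

    struct alignas(8) half4_sketch { unsigned short lanes[4]; }; // 4 fp16 lanes, 64 bits

    template <typename Vec, typename Mtx>
    void copy_tile(const Mtx& src_mtx, const Vec* p_src,
                   const Mtx& dst_mtx, Vec* p_dst, int nrow, int ncol)
    {
        for(int i = 0; i < nrow; ++i)
            for(int j = 0; j < ncol; ++j)
                p_dst[dst_mtx.GetOffsetFromMultiIndex(i, j)] =
                    p_src[src_mtx.GetOffsetFromMultiIndex(i, j)]; // one vector per assignment
    }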
@@ -129,32 +110,26 @@ __device__ void threadwise_gemm(MatrixA,
                const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

-               static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
-                   p_c_thread[cindex] += CVT_FLOAT2ACCUM(p_a_thread[aindex]) *
-                                         CVT_FLOAT2ACCUM(p_b_thread[bindex]);
-               }).Else([&](auto) {
-                   static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
-                       // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
-                       // respectively)
-                       float acc = 0.0;
-                       for(index_t v = 0; v < 4; ++v)
-                       {
-                           acc += CVT_FLOAT2ACCUM(p_a_thread[aindex*4 + v]) *
-                                  CVT_FLOAT2ACCUM(p_b_thread[bindex*4 + v]);
-                       }
-                       p_c_thread[cindex] = acc;
-                   }).Else([&](auto) {
-                       // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
-                       // respectively)
-                       float acc = 0.0;
-                       for(index_t v = 0; v < 2; ++v)
-                       {
-                           acc += CVT_FLOAT2ACCUM(p_a_thread[aindex*2 + v]) *
-                                  CVT_FLOAT2ACCUM(p_b_thread[bindex*2 + v]);
-                       }
-                       p_c_thread[cindex] += acc;
-                   });
-               });
+               //static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
+               //    p_c_thread[cindex] +=
+               //        CVT_FLOAT2ACCUM(p_a_thread[aindex]) * CVT_FLOAT2ACCUM(p_b_thread[bindex]);
+               //}).Else([&](auto) {
+               static_if<std::is_same<FloatA, ck::vector_type<half, 4>::MemoryType>::value>{}([&](auto) {
+                   const half* s0_half = reinterpret_cast<const half*>(&p_a_thread[aindex]);
+                   const half* s1_half = reinterpret_cast<const half*>(&p_b_thread[bindex]);
+
+                   p_c_thread[cindex] +=
+                       CVT_FLOAT2ACCUM(s0_half[0]) * CVT_FLOAT2ACCUM(s1_half[0]) +
+                       CVT_FLOAT2ACCUM(s0_half[1]) * CVT_FLOAT2ACCUM(s1_half[1]) +
+                       CVT_FLOAT2ACCUM(s0_half[2]) * CVT_FLOAT2ACCUM(s1_half[2]) +
+                       CVT_FLOAT2ACCUM(s0_half[3]) * CVT_FLOAT2ACCUM(s1_half[3]);
+               }).Else([&](auto) {
+                   const ushort* s0_ushort = reinterpret_cast<const ushort*>(&p_a_thread[aindex]);
+                   const ushort* s1_ushort = reinterpret_cast<const ushort*>(&p_b_thread[bindex]);
+
+                   p_c_thread[cindex] +=
+                       CVT_FLOAT2ACCUM(s0_ushort[0]) * CVT_FLOAT2ACCUM(s1_ushort[0]) +
+                       CVT_FLOAT2ACCUM(s0_ushort[1]) * CVT_FLOAT2ACCUM(s1_ushort[1]);
+               });
+               // });
            }
        }
    }
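
The unrolled lane products above widen each 16-bit lane to float before the multiply-accumulate. A self-contained host-side sketch of the bfloat16 case — the conversion helper stands in for CVT_FLOAT2ACCUM, and all names here are illustrative:

    #include <cstdint>
    #include <cstring>

    // bfloat16 is the high 16 bits of an IEEE-754 float, so widening is a shift.
    static float bf16_to_float(std::uint16_t x)
    {
        std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    // Accumulate one 2-lane bf16 dot product into a float accumulator,
    // mirroring the ushort branch of the hunk above.
    static void fma_bf16x2(float& c, const std::uint16_t* a, const std::uint16_t* b)
    {
        c += bf16_to_float(a[0]) * bf16_to_float(b[0]) +
             bf16_to_float(a[1]) * bf16_to_float(b[1]);
    }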
@@ -112,7 +112,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
        static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
            *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
                *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
-           //printf("%f ", static_cast<float>(p_dst[dst_index]));
        }).Else([&](auto) {
            for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
            {
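The copy above takes its fast path only when the source and destination vector types agree, and otherwise falls back to a scalar loop. A compact sketch of that two-path dispatch using C++17 if constexpr — the function name is illustrative:

    #include <type_traits>

    template <typename SrcVec, typename DstVec, typename T>
    void copy_one_access(const T* p_src, T* p_dst, unsigned int data_per_access)
    {
        if constexpr(std::is_same<SrcVec, DstVec>::value)
        {
            // types agree: a single reinterpreted vector assignment
            *reinterpret_cast<DstVec*>(p_dst) = *reinterpret_cast<const SrcVec*>(p_src);
        }
        else
        {
            // fallback: element-by-element copy
            for(unsigned int i = 0; i < data_per_access; ++i)
                p_dst[i] = p_src[i];
        }
    }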
@@ -138,11 +138,21 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
                       wi < in_nchw.mDesc.GetLengths()[3])
                    {
                        v += double(in_nchw(n, c, hi, wi)) * double(wei_kcyx(k, c, y, x));
+
+                       if(n == 0 && k == 0 && ho == 0 && wo == 0)
+                       {
+                           //std::cout << "cpu " << c << "," << hi << "," << wi << " * "
+                           //          << c << "," << y << "," << x << " = "
+                           //          << in_nchw(n, c, hi, wi) << " * " << wei_kcyx(k, c, y, x)
+                           //          << std::endl;
+                           // printf(" cpu %d,%d,%d * %d,%d,%d = %f * %f\n",
+                           //        c, hi, wi, c, y, x, double(in_nchw(n, c, hi, wi)),
+                           //        double(wei_kcyx(k, c, y, x)));
+                       }
                    }
                }
            }
        }
        out_nkhw(n, k, ho, wo) = v;
+       if(n == 0 && k == 0 && ho == 0 && wo == 0)
+           printf("cpu %d,%d,%d,%d = %f\n", n, k, ho, wo, v);
    };

    auto f_par = make_ParallelTensorFunctor(f,
@@ -787,9 +797,8 @@ int main(int argc, char* argv[])
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
-#elif 1
+#elif 0
    // 1x1 filter, 28x28 image
-   // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
    constexpr index_t N = 32;
    constexpr index_t C = 128;
    constexpr index_t HI = 28;
@@ -803,6 +812,20 @@ int main(int argc, char* argv[])
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
+#elif 1
+   constexpr index_t N = 8;
+   constexpr index_t C = 64;
+   constexpr index_t HI = 4;
+   constexpr index_t WI = 4;
+   constexpr index_t K = 64;
+   constexpr index_t Y = 1;
+   constexpr index_t X = 1;
+
+   using ConvStrides   = Sequence<1, 1>;
+   using ConvDilations = Sequence<1, 1>;
+
+   constexpr index_t HPad = 0;
+   constexpr index_t WPad = 0;
#endif

    auto lower_pads = Sequence<HPad, WPad>{};
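
For the new configuration added above (1x1 filter, stride/dilation 1, no padding), the implicit-GEMM problem sizes work out as follows, assuming the usual NCHW implicit-GEMM mapping GemmM = K, GemmK = C*Y*X, GemmN = N*Ho*Wo:

    constexpr int Ho = (4 - 1) / 1 + 1; // output height = 4 for a 1x1 filter, stride 1
    constexpr int Wo = (4 - 1) / 1 + 1; // output width  = 4
    constexpr int GemmM = 64;           // = K
    constexpr int GemmK = 64 * 1 * 1;   // = C*Y*X = 64
    constexpr int GemmN = 8 * Ho * Wo;  // = N*Ho*Wo = 128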
@@ -897,7 +920,7 @@ int main(int argc, char* argv[])
    if(do_verification)
    {
-#if 1
+#if 0
        if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
           ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
        {
@@ -915,6 +938,7 @@ int main(int argc, char* argv[])
                                  upper_pads);
        }

        check_error(out_nkhw_host, out_nkhw_device);
+       printf("gpu value %f\n", double(out_nkhw_device.mData[0]));

#if 0
        LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;