Commit fdaaaa50 authored by Chao Liu
Browse files

Merge branch 'direct_fp16'

parents 2c9b8c24 18a81e35
......@@ -37,7 +37,8 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <class Float,
template <class SrcData,
class DstData,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
......@@ -45,9 +46,9 @@ template <class Float,
class F>
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
SrcDesc,
const Float* __restrict__ p_src,
const SrcData* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
DstData* __restrict__ p_dst,
SrcOpLengths,
DstFromSrcReorder,
F f)
......@@ -88,33 +89,38 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
}
}
// Builds a lambda that overwrites its argument with a value-constructed zero
// and forwards it, together with the descriptor type and the data pointer, to
// threadwise_4d_tensor_pointwise_operation_unary. Presumably that helper
// applies the lambda to every element described by Desc (its body is not
// visible here -- confirm).
// NOTE(review): this span is a rendered diff without +/- markers. Each
// pre-change `Float` line is immediately followed by its post-change `Data`
// replacement; the only change is the rename of the element-type parameter.
template <class Float, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p)
template <class Data, class Desc>
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
{
// The lambda takes the element by reference and assigns zero in place.
auto f_set_zero = [](Float& v) { v = Float(0); };
auto f_set_zero = [](Data& v) { v = Data(0); };
// Template arguments are spelled out explicitly; the descriptor is passed as
// an empty-braces temporary (Desc{}).
threadwise_4d_tensor_pointwise_operation_unary<Float, Desc, decltype(f_set_zero)>(
threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
Desc{}, p, f_set_zero);
}
// Element-wise copy from p_src to p_dst, implemented by handing a copy lambda
// to threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src.
// How DstFromSrcReorder maps source indices to destination indices is decided
// inside that helper, whose body is not visible here.
// NOTE(review): this span is a rendered diff without +/- markers. Each
// pre-change `Float` line is immediately followed by its replacement. The
// post-change revision gives source and destination separate element types
// (SrcData / DstData), so the two buffers no longer have to share one type.
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
template <class SrcData,
class DstData,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class DstFromSrcReorder>
__device__ void
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
const Float* __restrict__ p_src,
const SrcData* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
DstData* __restrict__ p_dst,
SrcOpLengths,
DstFromSrcReorder)
{
// Old form assigns directly; new form converts explicitly with static_cast
// because the source and destination element types may now differ.
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
// Descriptor, length and reorder arguments are passed as empty-braces
// temporaries; template arguments of the helper are left to deduction.
threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
}
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_4d_tensor_copy(
SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
{
auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
......
......@@ -2,13 +2,13 @@
#include "ConstantTensorDescriptor.hip.hpp"
// optimized for scenario if p_in, p_wei, p_out are in register
template <class Float, class InDesc, class WeiDesc, class OutDesc>
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_1(InDesc,
Float* const __restrict__ p_in,
TInWei* const __restrict__ p_in,
WeiDesc,
Float* const __restrict__ p_wei,
TInWei* const __restrict__ p_wei,
OutDesc,
Float* __restrict__ p_out)
TOut* __restrict__ p_out)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -51,25 +51,8 @@ __device__ void threadwise_direct_convolution_1(InDesc,
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
p_out[out_index] += p_wei[wei_index] * p_in[in_index];
#if 0
// if(threadIdx.x == 0)
{
printf("threadwise_direct_convolution: \t"
"threadIdx.x %u\t"
"out_index %u, p_out[out_index] %f, \t"
"wei_index %u, p_wei[wei_index] %f, \t"
"in_index %u, p_in[in_index] %f\n",
threadIdx.x,
out_index,
p_out[out_index],
wei_index,
p_wei[wei_index],
in_index,
p_in[in_index]);
}
#endif
fused_multiply_accumulate(
p_out[out_index], p_wei[wei_index], p_in[in_index]);
}
}
}
......@@ -81,13 +64,13 @@ __device__ void threadwise_direct_convolution_1(InDesc,
// Optimized for scenario if p_in and p_wei are in LDS, p_out are in register
// Copy in and wei into register before doing convolution
template <class Float, class InDesc, class WeiDesc, class OutDesc>
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_2(InDesc,
Float* const __restrict__ p_in,
TInWei* const __restrict__ p_in,
WeiDesc,
Float* const __restrict__ p_wei,
TInWei* const __restrict__ p_wei,
OutDesc,
Float* __restrict__ p_out)
TOut* __restrict__ p_out)
{
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
......@@ -97,8 +80,8 @@ __device__ void threadwise_direct_convolution_2(InDesc,
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(wei_desc.GetLengths());
// register
Float p_in_reg[in_reg_desc.GetElementSpace()];
Float p_wei_reg[wei_reg_desc.GetElementSpace()];
TInWei p_in_reg[in_reg_desc.GetElementSpace()];
TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];
// copy input tensor into register
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
......@@ -114,13 +97,13 @@ __device__ void threadwise_direct_convolution_2(InDesc,
// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load 1x1 weight into register, and do 1x1 convolution in register.
template <class Float, class InDesc, class WeiDesc, class OutDesc>
template <class Data, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_3(InDesc,
Float* const __restrict__ p_in,
Data* const __restrict__ p_in,
WeiDesc,
Float* const __restrict__ p_wei,
Data* const __restrict__ p_wei,
OutDesc,
Float* __restrict__ p_out)
Data* __restrict__ p_out)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -139,8 +122,8 @@ __device__ void threadwise_direct_convolution_3(InDesc,
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});
Float p_in_reg[in_reg_desc.GetElementSpace()];
Float p_wei_reg[wei_reg_desc.GetElementSpace()];
Data p_in_reg[in_reg_desc.GetElementSpace()];
Data p_wei_reg[wei_reg_desc.GetElementSpace()];
constexpr unsigned in_w_new_read = 1;
......
......@@ -10,7 +10,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
SrcOpLengths,
Number<DataPerRead>)
{
using vector_t = typename vector_type<Float, DataPerRead>::type;
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
SrcOpLengths::nDim == 6,
......@@ -80,7 +80,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
SrcOpLengths,
Number<DataPerRead>)
{
using vector_t = typename vector_type<Float, DataPerRead>::type;
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
SrcOpLengths::nDim == 8,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment