"include/vscode:/vscode.git/clone" did not exist on "dba65b1c71197d63688c75fb0290142b7a0f30e4"
Commit 7972ab17 authored by Jing Zhang's avatar Jing Zhang
Browse files

add vector store by 4

parent 5494423f
......@@ -80,9 +80,9 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}
template <typename DstData, typename SrcData>
__device__ static DstData load_data(const SrcData* p_src, index_t src_offset)
__device__ static void load_data(DstData& dst, const SrcData* p_src, index_t src_offset)
{
return *reinterpret_cast<const DstData*>(&p_src[src_offset]);
dst = *reinterpret_cast<const DstData*>(&p_src[src_offset]);
}
template <typename DstData, typename SrcData>
......@@ -98,10 +98,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
struct vector_data_load<float, 1>
{
template <typename SrcCoord>
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
float r;
r = load_data<float>(p_src, src_coord_begin.GetOffset());
load_data(r, p_src, src_coord_begin.GetOffset());
return r;
}
};
......@@ -110,10 +110,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
struct vector_data_load<float, 2>
{
template <typename SrcCoord>
__device__ static float2_t run(const float* p_src, const SrcCoord src_coord_begin)
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
float2_t r;
r = load_data<float2_t>(p_src, src_coord_begin.GetOffset());
load_data(r, p_src, src_coord_begin.GetOffset());
return r;
}
};
......@@ -122,10 +122,10 @@ struct ThreadwiseGenericTensorSliceCopy_v5
struct vector_data_load<float, 4>
{
template <typename SrcCoord>
__device__ static float4_t run(const float* p_src, const SrcCoord src_coord_begin)
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
float4_t r;
r = load_data<float4_t>(p_src, src_coord_begin.GetOffset());
load_data(r, p_src, src_coord_begin.GetOffset());
return r;
}
};
......@@ -155,6 +155,17 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}
};
template <>
struct vector_data_store<float, 4>
{
template <typename DstCoord>
__device__ static void
run(float* p_dst, const float4_t src_data, const DstCoord dst_coord_begin)
{
store_data(src_data, p_dst, dst_coord_begin.GetOffset());
}
};
template <typename SrcData>
__device__ void Load(const SrcData* p_src)
{
......@@ -196,7 +207,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2, "");
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");
constexpr auto long_vector_size = dst_data_per_access;
......
......@@ -184,7 +184,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment