Commit acdafc67 authored by Jing Zhang
Browse files

add other vector_sizes

parent d5c856bc
......@@ -101,7 +101,31 @@ struct ThreadwiseGenericTensorSliceCopy_v5
__device__ static float run(const float* p_src, const SrcCoord src_coord_begin)
{
float r;
r = load_data<float, float>(p_src, src_coord_begin.GetOffset());
r = load_data<float>(p_src, src_coord_begin.GetOffset());
return r;
}
};
// Specialization of vector_data_load for float data accessed two elements at a time.
template <>
struct vector_data_load<float, 2>
{
    // Forwards p_src and the offset reported by src_coord_begin.GetOffset()
    // to load_data<float2_t> and hands back the resulting float2_t by value.
    template <typename SrcCoord>
    __device__ static float2_t run(const float* p_src, const SrcCoord src_coord_begin)
    {
        return load_data<float2_t>(p_src, src_coord_begin.GetOffset());
    }
};
// Specialization of vector_data_load for float data accessed four elements at a time.
template <>
struct vector_data_load<float, 4>
{
    // Forwards p_src and the offset reported by src_coord_begin.GetOffset()
    // to load_data<float4_t> and hands back the resulting float4_t by value.
    template <typename SrcCoord>
    __device__ static float4_t run(const float* p_src, const SrcCoord src_coord_begin)
    {
        return load_data<float4_t>(p_src, src_coord_begin.GetOffset());
    }
};
......@@ -116,7 +140,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
__device__ static void
run(float* p_dst, const float src_data, const DstCoord dst_coord_begin)
{
store_data<float, float>(src_data, p_dst, dst_coord_begin.GetOffset());
store_data<float>(src_data, p_dst, dst_coord_begin.GetOffset());
}
};
......@@ -127,7 +151,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
static_assert(SrcDataPerRead == 1, "");
static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4, "");
constexpr auto long_vector_size = src_data_per_access;
......@@ -147,7 +171,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// store data from the long-vector buffer to dst
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
thread_buff.template At<SrcDataPerRead>()(Number<buff_off>{}) = src_buff;
});
......@@ -174,7 +199,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id));
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
auto src_buff = thread_buff.template At<DstDataPerWrite>()[Number<buff_off>{}];
......
......@@ -25,15 +25,16 @@ typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// Union that overlays a per-float view (s1) and a float2_t-based view (s2)
// on the same storage.
// NOTE(review): `s2` is declared twice below (once as a bare `float2_t`, once
// as `StaticallyIndexedArray<float2_t, 1>`) and the default constructor also
// appears twice. This looks like the before/after lines of a diff hunk shown
// together without +/- markers -- confirm which pair is the current one.
union float_vec2_t
{
StaticallyIndexedArray<float, 2> s1;
float2_t s2;
__host__ __device__ constexpr float_vec2_t() {s2 = {0, 0};}
StaticallyIndexedArray<float2_t, 1> s2;
__host__ __device__ constexpr float_vec2_t() {}
};
union float_vec4_t
{
StaticallyIndexedArray<float, 4> s1;
float4_t s4;
__host__ __device__ constexpr float_vec4_t() {s4 = {0, 0, 0, 0};}
StaticallyIndexedArray<float2_t, 2> s2;
StaticallyIndexedArray<float4_t, 1> s4;
__host__ __device__ constexpr float_vec4_t() {}
template<index_t vs>
__host__ __device__ auto& At();
......@@ -44,6 +45,12 @@ union float_vec4_t
return s1;
}
template<>
__host__ __device__ auto& At<2>()
{
return s2;
}
template<>
__host__ __device__ auto& At<4>()
{
......@@ -54,9 +61,9 @@ union float_vec4_t
// Union that overlays several views on the same storage: s1 (eight floats),
// s2 (four 2-wide entries), s4 (two 4-wide entries) and s8 (one 8-wide entry).
// NOTE(review): `s2`, `s4` and `s8` are each declared twice below, first with
// the float_vec2_t / float_vec4_t / bare float8_t types and then as
// StaticallyIndexedArray of float2_t / float4_t / float8_t. This looks like the
// before/after lines of a diff hunk shown together without +/- markers --
// confirm which set is the current one.
union float_vec8_t
{
StaticallyIndexedArray<float, 8> s1;
StaticallyIndexedArray<float_vec2_t, 4> s2;
StaticallyIndexedArray<float_vec4_t, 2> s4;
float8_t s8;
StaticallyIndexedArray<float2_t, 4> s2;
StaticallyIndexedArray<float4_t, 2> s4;
StaticallyIndexedArray<float8_t, 1> s8;
// Default constructor with an empty body; no member is initialized here.
__host__ __device__ constexpr float_vec8_t() {}
};
......
......@@ -183,7 +183,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment