Commit 0486af77 authored by Jing Zhang's avatar Jing Zhang
Browse files

add initialization in vector type; move buffer to gridwise gemm

parent 79137e1a
...@@ -78,28 +78,30 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -78,28 +78,30 @@ struct BlockwiseGenericTensorSliceCopy_v5
return ThreadBufferDesc::GetElementSpace(); return ThreadBufferDesc::GetElementSpace();
} }
template <typename BlockSrcData> template <typename BlockSrcData, typename ThreadBuffData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src) __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
ThreadBuffData& thread_buff)
{ {
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{ {
mThreadwiseCopy.Load(p_block_src); mThreadwiseCopy.Load(p_block_src, thread_buff);
} }
} }
template <typename BlockDstData> template <typename ThreadBuffData, typename BlockDstData>
__device__ void RunStoreThreadBuffer(BlockDstData* p_block_dst) __device__ void RunStoreThreadBuffer(ThreadBuffData thread_buff, BlockDstData* p_block_dst)
{ {
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{ {
mThreadwiseCopy.Store(p_block_dst); mThreadwiseCopy.Store(thread_buff, p_block_dst);
} }
} }
template <typename BlockSrcData, typename BlockDstData> template <typename BlockSrcData, typename BlockDstData, typename ThreadBuffData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) __device__ void
Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst, ThreadBuffData& thread_buff)
{ {
static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr, static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
"wrong! This function use vgpr as its thread " "wrong! This function use vgpr as its thread "
...@@ -110,8 +112,8 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -110,8 +112,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{ {
RunLoadThreadBuffer(p_block_src); RunLoadThreadBuffer(p_block_src, thread_buff);
RunStoreThreadBuffer(p_block_dst); RunStoreThreadBuffer(thread_buff, p_block_dst);
} }
} }
......
...@@ -496,10 +496,18 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -496,10 +496,18 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize; constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>(); auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
using ThreadBufferTypeA =
decltype(GetRegBuffer<ABFloat, a_blockwise_copy.GetThreadBufferSize()>());
using ThreadBufferTypeB =
decltype(GetRegBuffer<ABFloat, b_blockwise_copy.GetThreadBufferSize()>());
ThreadBufferTypeA thread_buff_a;
ThreadBufferTypeB thread_buff_b;
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.Run(p_a_global, p_a_block); a_blockwise_copy.Run(p_a_global, p_a_block, thread_buff_a);
b_blockwise_copy.Run(p_b_global, p_b_block); b_blockwise_copy.Run(p_b_global, p_b_block, thread_buff_b);
} }
constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{}; constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
...@@ -509,16 +517,12 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -509,16 +517,12 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock; for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock) k_block_data_begin += KPerBlock)
{ {
// ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
// load next data from device mem // load next data from device mem
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True); a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True); b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global); a_blockwise_copy.RunLoadThreadBuffer(p_a_global, thread_buff_a);
// a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); b_blockwise_copy.RunLoadThreadBuffer(p_b_global, thread_buff_b);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
block_sync_lds(); block_sync_lds();
...@@ -535,10 +539,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -535,10 +539,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
block_sync_lds(); block_sync_lds();
// store next data to LDS // store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_block); a_blockwise_copy.RunStoreThreadBuffer(thread_buff_a, p_a_block);
// a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block); b_blockwise_copy.RunStoreThreadBuffer(thread_buff_b, p_b_block);
b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
} }
// tail // tail
......
...@@ -32,12 +32,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -32,12 +32,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5
{ {
using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(SliceLengths{})); using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(SliceLengths{}));
static constexpr index_t ThreadBufferSize = ThreadBufferDesc::GetElementSpace();
using ThreadBufferType = decltype(GetRegBuffer<float, ThreadBufferSize>());
ThreadBufferType thread_buff;
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
...@@ -167,8 +161,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -167,8 +161,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
} }
}; };
template <typename SrcData> template <typename SrcData, typename DstData>
__device__ void Load(const SrcData* p_src) __device__ void Load(const SrcData* p_src, DstData& thread_buff)
{ {
constexpr auto vector_access_dim = Number<SrcVectorReadDim>{}; constexpr auto vector_access_dim = Number<SrcVectorReadDim>{};
...@@ -192,7 +186,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -192,7 +186,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5
auto src_buff = buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>( auto src_buff = buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(
p_src, src_coord); p_src, src_coord);
// vector_data_load<SrcData, src_data_per_access>::run(p_src, src_coord);
// store data from the long-vector buffer to dst // store data from the long-vector buffer to dst
constexpr auto buff_off = constexpr auto buff_off =
...@@ -203,8 +196,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -203,8 +196,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}); });
} }
template <typename DstData> template <typename SrcData, typename DstData>
__device__ void Store(DstData* p_dst) __device__ void Store(SrcData thread_buff, DstData* p_dst)
{ {
constexpr auto vector_access_dim = Number<DstVectorWriteDim>{}; constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
...@@ -236,38 +229,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -236,38 +229,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}); });
} }
template <typename SrcData, typename DstData>
__device__ void Store(SrcData src, DstData* p_dst)
{
constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");
constexpr auto long_vector_size = dst_data_per_access;
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
static_ford<decltype(long_vector_access_lengths), DstDimAccessOrder>{}(
[&](auto long_vector_access_id) {
constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
Number<vector_access_dim>{},
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
auto src_buff = src.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
});
}
template <typename T, bool PositiveDirection> template <typename T, bool PositiveDirection>
__device__ void MoveSrcSliceWindow(const T& step_sizes_, __device__ void MoveSrcSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>) integral_constant<bool, PositiveDirection>)
......
...@@ -28,7 +28,7 @@ union float_vec2_t ...@@ -28,7 +28,7 @@ union float_vec2_t
{ {
StaticallyIndexedArray<float, 2> s1; StaticallyIndexedArray<float, 2> s1;
StaticallyIndexedArray<float2_t, 1> s2; StaticallyIndexedArray<float2_t, 1> s2;
__host__ __device__ constexpr float_vec2_t() {} __host__ __device__ constexpr float_vec2_t() { s2(Number<0>{}) = 0; }
}; };
union float_vec4_t union float_vec4_t
...@@ -66,7 +66,7 @@ union float_vec8_t ...@@ -66,7 +66,7 @@ union float_vec8_t
StaticallyIndexedArray<float2_t, 4> s2; StaticallyIndexedArray<float2_t, 4> s2;
StaticallyIndexedArray<float4_t, 2> s4; StaticallyIndexedArray<float4_t, 2> s4;
StaticallyIndexedArray<float8_t, 1> s8; StaticallyIndexedArray<float8_t, 1> s8;
__host__ __device__ constexpr float_vec8_t() {} __host__ __device__ constexpr float_vec8_t() { s8(Number<0>{}) = 0; }
template <index_t vs> template <index_t vs>
__host__ __device__ auto& GetVector(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
...@@ -103,7 +103,7 @@ union float_vec16_t ...@@ -103,7 +103,7 @@ union float_vec16_t
StaticallyIndexedArray<float4_t, 4> s4; StaticallyIndexedArray<float4_t, 4> s4;
StaticallyIndexedArray<float8_t, 2> s8; StaticallyIndexedArray<float8_t, 2> s8;
StaticallyIndexedArray<float16_t, 1> s16; StaticallyIndexedArray<float16_t, 1> s16;
__host__ __device__ constexpr float_vec16_t() {} __host__ __device__ constexpr float_vec16_t() { s16(Number<0>{}) = 0; }
template <index_t vs> template <index_t vs>
__host__ __device__ auto& GetVector(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
...@@ -147,7 +147,7 @@ union float_vec32_t ...@@ -147,7 +147,7 @@ union float_vec32_t
StaticallyIndexedArray<float_vec8_t, 4> s8; StaticallyIndexedArray<float_vec8_t, 4> s8;
StaticallyIndexedArray<float_vec16_t, 2> s16; StaticallyIndexedArray<float_vec16_t, 2> s16;
StaticallyIndexedArray<float32_t, 1> s32; StaticallyIndexedArray<float32_t, 1> s32;
__host__ __device__ constexpr float_vec32_t() {} __host__ __device__ constexpr float_vec32_t() { s32(Number<0>{}) = 0; }
template <index_t vs> template <index_t vs>
__host__ __device__ auto& GetVector(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
...@@ -189,7 +189,7 @@ union float_vec64_t ...@@ -189,7 +189,7 @@ union float_vec64_t
StaticallyIndexedArray<float_vec32_t, 2> s32; StaticallyIndexedArray<float_vec32_t, 2> s32;
StaticallyIndexedArray<float32_t, 2> v32; StaticallyIndexedArray<float32_t, 2> v32;
StaticallyIndexedArray<float64_t, 1> s64; StaticallyIndexedArray<float64_t, 1> s64;
__host__ __device__ constexpr float_vec64_t() {} __host__ __device__ constexpr float_vec64_t() { s64(Number<0>{}) = 0; }
template <index_t vs> template <index_t vs>
__host__ __device__ auto& GetVector(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
...@@ -214,8 +214,8 @@ union float_vec128_t ...@@ -214,8 +214,8 @@ union float_vec128_t
StaticallyIndexedArray<float_vec32_t, 4> s32; StaticallyIndexedArray<float_vec32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64; StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128; StaticallyIndexedArray<float128_t, 1> s128;
float n[128]; // float n[128];
__host__ __device__ constexpr float_vec128_t() {} __host__ __device__ constexpr float_vec128_t() { s128(Number<0>{}) = 0; }
template <index_t vs> template <index_t vs>
__host__ __device__ auto& GetVector(Number<vs>); __host__ __device__ auto& GetVector(Number<vs>);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment