Commit 0486af77 authored by Jing Zhang

add initial value in vector types; move thread buffer to gridwise gemm

parent 79137e1a
@@ -78,28 +78,30 @@ struct BlockwiseGenericTensorSliceCopy_v5
         return ThreadBufferDesc::GetElementSpace();
     }
 
-    template <typename BlockSrcData>
-    __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src)
+    template <typename BlockSrcData, typename ThreadBuffData>
+    __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
+                                        ThreadBuffData& thread_buff)
     {
         if(BlockSize == mThreadClusterDesc.GetElementSize() or
            get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
         {
-            mThreadwiseCopy.Load(p_block_src);
+            mThreadwiseCopy.Load(p_block_src, thread_buff);
         }
     }
 
-    template <typename BlockDstData>
-    __device__ void RunStoreThreadBuffer(BlockDstData* p_block_dst)
+    template <typename ThreadBuffData, typename BlockDstData>
+    __device__ void RunStoreThreadBuffer(ThreadBuffData thread_buff, BlockDstData* p_block_dst)
     {
         if(BlockSize == mThreadClusterDesc.GetElementSize() or
            get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
         {
-            mThreadwiseCopy.Store(p_block_dst);
+            mThreadwiseCopy.Store(thread_buff, p_block_dst);
         }
     }
 
-    template <typename BlockSrcData, typename BlockDstData>
-    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst)
+    template <typename BlockSrcData, typename BlockDstData, typename ThreadBuffData>
+    __device__ void
+    Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst, ThreadBuffData& thread_buff)
     {
         static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
                       "wrong! This function use vgpr as its thread "
@@ -110,8 +112,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
         if(BlockSize == mThreadClusterDesc.GetElementSize() or
            get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
         {
-            RunLoadThreadBuffer(p_block_src);
-            RunStoreThreadBuffer(p_block_dst);
+            RunLoadThreadBuffer(p_block_src, thread_buff);
+            RunStoreThreadBuffer(thread_buff, p_block_dst);
         }
     }
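
To make the interface change concrete, here is a hypothetical host-side analogue (illustrative names and types, not the repo's API): the copy object no longer owns its staging buffer, so the caller allocates it once and threads it through the load phase, the store phase, or the combined Run.

// Minimal sketch, assuming plain pointers and std::array stand in for
// global memory, LDS, and the VGPR thread buffer.
#include <array>
#include <cstdio>

template <int N>
struct BlockwiseCopySketch
{
    template <typename Src, typename Buf>
    void RunLoadThreadBuffer(const Src* src, Buf& buf) const
    {
        for(int i = 0; i < N; ++i) buf[i] = src[i]; // "global" -> "registers"
    }

    template <typename Buf, typename Dst>
    void RunStoreThreadBuffer(const Buf& buf, Dst* dst) const
    {
        for(int i = 0; i < N; ++i) dst[i] = buf[i]; // "registers" -> "LDS"
    }

    template <typename Src, typename Dst, typename Buf>
    void Run(const Src* src, Dst* dst, Buf& buf) const
    {
        RunLoadThreadBuffer(src, buf);
        RunStoreThreadBuffer(buf, dst);
    }
};

int main()
{
    const float src[4] = {1, 2, 3, 4};
    float dst[4]       = {};
    std::array<float, 4> thread_buff{}; // caller-owned staging buffer

    BlockwiseCopySketch<4>{}.Run(src, dst, thread_buff);
    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);
}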
@@ -496,10 +496,18 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
         constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
         auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
 
+        using ThreadBufferTypeA =
+            decltype(GetRegBuffer<ABFloat, a_blockwise_copy.GetThreadBufferSize()>());
+        using ThreadBufferTypeB =
+            decltype(GetRegBuffer<ABFloat, b_blockwise_copy.GetThreadBufferSize()>());
+
+        ThreadBufferTypeA thread_buff_a;
+        ThreadBufferTypeB thread_buff_b;
+
         // preload data into LDS
         {
-            a_blockwise_copy.Run(p_a_global, p_a_block);
-            b_blockwise_copy.Run(p_b_global, p_b_block);
+            a_blockwise_copy.Run(p_a_global, p_a_block, thread_buff_a);
+            b_blockwise_copy.Run(p_b_global, p_b_block, thread_buff_b);
         }
 
         constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
@@ -509,16 +517,12 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
         for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
             k_block_data_begin += KPerBlock)
         {
-            // ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-
             // load next data from device mem
             a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
             b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
 
-            a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
-            // a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-            b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
+            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, thread_buff_a);
+            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, thread_buff_b);
 
             block_sync_lds();
@@ -535,10 +539,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
             block_sync_lds();
 
             // store next data to LDS
-            a_blockwise_copy.RunStoreThreadBuffer(p_a_block);
-            // a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
-            b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
+            a_blockwise_copy.RunStoreThreadBuffer(thread_buff_a, p_a_block);
+            b_blockwise_copy.RunStoreThreadBuffer(thread_buff_b, p_b_block);
         }
 
         // tail
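
The loop above now forms an explicit software pipeline: while the xdlops math consumes the tile already in LDS, the next K-slice sits in the caller-owned register buffers. A host-side sketch of that shape, with plain arrays standing in for global memory, LDS, VGPRs, and the GEMM (all names here are illustrative):

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    constexpr int KPerBlock = 4, K = 16;
    std::vector<float> a_global(K);
    std::iota(a_global.begin(), a_global.end(), 0.0f); // 0, 1, ..., 15

    float a_block[KPerBlock];       // stands in for LDS
    float thread_buff_a[KPerBlock]; // stands in for VGPRs
    float acc = 0.0f;

    auto load  = [&](int k0, float* buf) { for(int i = 0; i < KPerBlock; ++i) buf[i] = a_global[k0 + i]; };
    auto store = [&](const float* buf)   { for(int i = 0; i < KPerBlock; ++i) a_block[i] = buf[i]; };
    auto compute_on_lds = [&] { for(int i = 0; i < KPerBlock; ++i) acc += a_block[i]; };

    // preload tile 0 into "LDS" (the Run call before the loop)
    load(0, thread_buff_a);
    store(thread_buff_a);

    // main loop: fetch tile k+1 into registers, compute on tile k, then commit
    for(int k0 = 0; k0 < K - KPerBlock; k0 += KPerBlock)
    {
        load(k0 + KPerBlock, thread_buff_a); // RunLoadThreadBuffer
        compute_on_lds();                    // math on the current LDS tile
        store(thread_buff_a);                // RunStoreThreadBuffer
    }
    compute_on_lds(); // tail: last tile

    std::printf("acc = %g (expected %g)\n", acc, 0.5f * K * (K - 1)); // 120
}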
@@ -32,12 +32,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5
 {
     using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(SliceLengths{}));
 
     static constexpr index_t ThreadBufferSize = ThreadBufferDesc::GetElementSpace();
 
-    using ThreadBufferType = decltype(GetRegBuffer<float, ThreadBufferSize>());
-    ThreadBufferType thread_buff;
-
     static constexpr index_t nDim = SliceLengths::Size();
 
     using Index = MultiIndex<nDim>;
@@ -167,8 +161,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
         }
     };
 
-    template <typename SrcData>
-    __device__ void Load(const SrcData* p_src)
+    template <typename SrcData, typename DstData>
+    __device__ void Load(const SrcData* p_src, DstData& thread_buff)
     {
         constexpr auto vector_access_dim = Number<SrcVectorReadDim>{};
@@ -192,7 +186,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5
                 auto src_buff = buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(
                     p_src, src_coord);
-                // vector_data_load<SrcData, src_data_per_access>::run(p_src, src_coord);
 
                 // store data from the long-vector buffer to dst
                 constexpr auto buff_off =
@@ -203,8 +196,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
             });
     }
 
-    template <typename DstData>
-    __device__ void Store(DstData* p_dst)
+    template <typename SrcData, typename DstData>
+    __device__ void Store(SrcData thread_buff, DstData* p_dst)
     {
         constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
@@ -236,38 +229,6 @@ struct ThreadwiseGenericTensorSliceCopy_v5
             });
     }
 
-    template <typename SrcData, typename DstData>
-    __device__ void Store(SrcData src, DstData* p_dst)
-    {
-        constexpr auto vector_access_dim   = Number<DstVectorWriteDim>{};
-        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
-
-        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");
-
-        constexpr auto long_vector_size = dst_data_per_access;
-
-        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
-            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
-
-        static_ford<decltype(long_vector_access_lengths), DstDimAccessOrder>{}(
-            [&](auto long_vector_access_id) {
-                constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
-                    Number<vector_access_dim>{},
-                    Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
-
-                constexpr auto buff_off =
-                    ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
-                    long_vector_size;
-
-                auto src_buff = src.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
-
-                const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
-
-                vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
-            });
-    }
-
     template <typename T, bool PositiveDirection>
     __device__ void MoveSrcSliceWindow(const T& step_sizes_,
                                        integral_constant<bool, PositiveDirection>)
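
A small worked example of the access decomposition Store performs, under assumed sizes (a 2x8 slice with float4 writes along dim 1; plain loops stand in for static_ford and the packed descriptor): each long-vector access maps to one chunk of the register buffer via buff_off.

#include <cstdio>

int main()
{
    constexpr int L0 = 2, L1 = 8;      // SliceLengths (assumed)
    constexpr int DstDataPerWrite = 4; // long_vector_size along dim 1

    for(int i0 = 0; i0 < L0; ++i0)
        for(int v = 0; v < L1 / DstDataPerWrite; ++v)
        {
            int begin1 = v * DstDataPerWrite; // long_vector_data_begin_id on dim 1
            // packed row-major offset, divided by the vector width,
            // mirroring ThreadBufferDesc::CalculateOffset(...) / long_vector_size
            int buff_off = (i0 * L1 + begin1) / DstDataPerWrite;
            std::printf("access (%d,%d) -> buffer vector %d\n", i0, begin1, buff_off);
        }
    // prints: (0,0)->0  (0,4)->1  (1,0)->2  (1,4)->3
}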
@@ -28,7 +28,7 @@ union float_vec2_t
 {
     StaticallyIndexedArray<float, 2> s1;
     StaticallyIndexedArray<float2_t, 1> s2;
 
-    __host__ __device__ constexpr float_vec2_t() {}
+    __host__ __device__ constexpr float_vec2_t() { s2(Number<0>{}) = 0; }
 };
 
 union float_vec4_t
@@ -66,7 +66,7 @@ union float_vec8_t
     StaticallyIndexedArray<float2_t, 4> s2;
     StaticallyIndexedArray<float4_t, 2> s4;
     StaticallyIndexedArray<float8_t, 1> s8;
 
-    __host__ __device__ constexpr float_vec8_t() {}
+    __host__ __device__ constexpr float_vec8_t() { s8(Number<0>{}) = 0; }
 
     template <index_t vs>
     __host__ __device__ auto& GetVector(Number<vs>);
@@ -103,7 +103,7 @@ union float_vec16_t
     StaticallyIndexedArray<float4_t, 4> s4;
     StaticallyIndexedArray<float8_t, 2> s8;
     StaticallyIndexedArray<float16_t, 1> s16;
 
-    __host__ __device__ constexpr float_vec16_t() {}
+    __host__ __device__ constexpr float_vec16_t() { s16(Number<0>{}) = 0; }
 
     template <index_t vs>
     __host__ __device__ auto& GetVector(Number<vs>);
@@ -147,7 +147,7 @@ union float_vec32_t
     StaticallyIndexedArray<float_vec8_t, 4> s8;
     StaticallyIndexedArray<float_vec16_t, 2> s16;
     StaticallyIndexedArray<float32_t, 1> s32;
 
-    __host__ __device__ constexpr float_vec32_t() {}
+    __host__ __device__ constexpr float_vec32_t() { s32(Number<0>{}) = 0; }
 
     template <index_t vs>
     __host__ __device__ auto& GetVector(Number<vs>);
@@ -189,7 +189,7 @@ union float_vec64_t
     StaticallyIndexedArray<float_vec32_t, 2> s32;
     StaticallyIndexedArray<float32_t, 2> v32;
     StaticallyIndexedArray<float64_t, 1> s64;
 
-    __host__ __device__ constexpr float_vec64_t() {}
+    __host__ __device__ constexpr float_vec64_t() { s64(Number<0>{}) = 0; }
 
     template <index_t vs>
     __host__ __device__ auto& GetVector(Number<vs>);
@@ -214,8 +214,8 @@ union float_vec128_t
     StaticallyIndexedArray<float_vec32_t, 4> s32;
     StaticallyIndexedArray<float_vec64_t, 2> s64;
     StaticallyIndexedArray<float128_t, 1> s128;
-    float n[128];
+    // float n[128];
 
-    __host__ __device__ constexpr float_vec128_t() {}
+    __host__ __device__ constexpr float_vec128_t() { s128(Number<0>{}) = 0; }
 
     template <index_t vs>
     __host__ __device__ auto& GetVector(Number<vs>);
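
The constructor change is easiest to see in isolation. A standalone sketch (plain structs stand in for StaticallyIndexedArray and the clang vector types): zero-initializing the widest union member once clears the storage behind every aliasing view, so the register buffer no longer starts with indeterminate values.

#include <cstdio>

struct vec2 { float data[2]; };

union float_vec2_sketch
{
    float s1[2]; // element view
    vec2  s2;    // widest view
    constexpr float_vec2_sketch() : s2{} {} // zero through the widest member
};

int main()
{
    float_vec2_sketch v;
    // reading the other member is type punning: well-defined in C and in
    // practice on the compilers this kernel targets, though strictly UB
    // in ISO C++
    std::printf("%g %g\n", v.s1[0], v.s1[1]); // prints: 0 0
}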