Commit 1d6022b1 authored by Jing Zhang

vector type output

parent 9a54fbd8
@@ -295,13 +295,11 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
             constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
             constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();
 
             // force unrolling the output loop to get ride of scratches
-#pragma unroll
-            for(index_t i = 0; i < NumBlks; ++i)
-            {
+            static_for<0, NumBlks, 1>{}([&](auto blk_id) {
                 // calculate origin of thread output tensor on global memory
                 // blockwise GEMM c matrix starting index
-                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
+                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(blk_id);
 
                 const index_t m_thread_data_on_global =
                     m_block_data_on_global + c_thread_mtx_on_block.row;
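
The point of this first hunk is that static_for makes the loop index a compile-time constant, which the runtime for loop with #pragma unroll could not guarantee; that is what allows the accumulator buffer to be indexed with Number<blk_id>{} further down. A minimal sketch of the idea, using a toy static_for_sketch built on a C++17 fold expression rather than the library's own static_for:

    #include <array>
    #include <cstdio>
    #include <utility>

    // Toy stand-in for the library's static_for: calls f(integral_constant<int, i>{})
    // for i = 0..N-1, fully unrolled at compile time.
    template <typename F, int... Is>
    void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
    {
        (f(std::integral_constant<int, Is>{}), ...);
    }

    template <int N, typename F>
    void static_for_sketch(F&& f)
    {
        static_for_impl(f, std::make_integer_sequence<int, N>{});
    }

    int main()
    {
        std::array<float, 4> acc{1.0f, 2.0f, 3.0f, 4.0f};

        // blk_id carries its value in its type, so it is a constant expression
        // inside the lambda and could be used as a template argument
        // (the role Number<blk_id>{} plays in the kernel above).
        static_for_sketch<4>([&](auto blk_id) {
            constexpr int i = decltype(blk_id)::value;
            std::printf("block %d -> %f\n", i, acc[i]);
        });
    }
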
@@ -309,24 +307,26 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
                 const index_t n_thread_data_on_global =
                     n_block_data_on_global + c_thread_mtx_on_block.col;
 
-                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
-                                                      decltype(c_g_m0_m1_m2_n_global_desc),
-                                                      CThreadCopySliceLengths,
-                                                      arithmetic_sequence_gen<0, 5, 1>::type,
-                                                      4,
-                                                      1,
-                                                      1,
-                                                      AddressSpace::Vgpr,
-                                                      AddressSpace::Global,
-                                                      CGlobalMemoryOp>(
+                ThreadwiseGenericTensorSliceCopy_v5<decltype(c_g_m0_m1_m2_n_thread_desc),
+                                                    decltype(c_g_m0_m1_m2_n_global_desc),
+                                                    CThreadCopySliceLengths,
+                                                    arithmetic_sequence_gen<0, 5, 1>::type,
+                                                    arithmetic_sequence_gen<0, 5, 1>::type,
+                                                    4,
+                                                    4,
+                                                    1,
+                                                    1,
+                                                    AddressSpace::Vgpr,
+                                                    AddressSpace::Global,
+                                                    CGlobalMemoryOp>(
                     make_multi_index(0, 0, 0, 0, 0),
                     make_multi_index(g_block_data_on_global,
                                      m_thread_data_on_global / (M2 * M1),
                                      m_thread_data_on_global % (M2 * M1) / M2,
                                      m_thread_data_on_global % M2,
                                      n_thread_data_on_global))
-                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
-            }
+                    .Run(c_thread_vec.At(Number<16>{})[Number<blk_id>{}], p_c_global);
+            });
         }
     }
 };
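
The make_multi_index arguments split the per-thread row index into the (M0, M1, M2) coordinates of the 5-D output descriptor by successive division and modulo. A small worked example of that arithmetic, with placeholder values for M1 and M2 (not taken from this kernel):

    #include <cstdio>

    int main()
    {
        // Hypothetical sub-tile sizes, chosen only to make the arithmetic concrete.
        constexpr int M1 = 4;
        constexpr int M2 = 4;

        // Same decomposition as the make_multi_index(...) arguments above.
        const int m  = 37;                   // m_thread_data_on_global
        const int m0 = m / (M2 * M1);        // 37 / 16      = 2
        const int m1 = m % (M2 * M1) / M2;   // (37 % 16) / 4 = 1
        const int m2 = m % M2;               // 37 % 4        = 1

        // The three coordinates recombine into the original linear row index.
        std::printf("m0=%d m1=%d m2=%d recombined=%d\n",
                    m0, m1, m2, m0 * (M1 * M2) + m1 * M2 + m2); // prints 2 1 1 37
    }
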
@@ -236,6 +236,38 @@ struct ThreadwiseGenericTensorSliceCopy_v5
             });
     }
 
+    template <typename SrcData, typename DstData>
+    __device__ void Run(SrcData src, DstData* p_dst)
+    {
+        constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
+
+        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
+
+        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");
+
+        constexpr auto long_vector_size = dst_data_per_access;
+
+        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
+            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
+
+        static_ford<decltype(long_vector_access_lengths), DstDimAccessOrder>{}(
+            [&](auto long_vector_access_id) {
+                constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
+                    Number<vector_access_dim>{},
+                    Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
+
+                constexpr auto buff_off =
+                    ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
+                    long_vector_size;
+
+                auto src_buff = src.At(Number<DstDataPerWrite>{})[Number<buff_off>{}];
+
+                const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
+
+                vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
+            });
+    }
+
     template <typename T, bool PositiveDirection>
     __device__ void MoveSrcSliceWindow(const T& step_sizes_,
                                        integral_constant<bool, PositiveDirection>)
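
The added Run overload writes the destination in whole vectors: the slice length along the write dimension is divided by DstDataPerWrite, and for each access the thread-buffer offset is divided by the same vector width, so it indexes the DstDataPerWrite-wide view of the source buffer directly (src.At(Number<DstDataPerWrite>{})). A rough sketch of that offset arithmetic for a one-dimensional slice, with plain arrays standing in for the descriptor and the vector-typed buffer:

    #include <cstdio>

    int main()
    {
        // Assumed for illustration: a 16-element thread buffer written with
        // 4-wide vector stores (DstDataPerWrite == 4).
        constexpr int slice_length       = 16;
        constexpr int dst_data_per_write = 4;
        constexpr int num_vector_writes  = slice_length / dst_data_per_write;

        float scalar_view[slice_length];
        for(int i = 0; i < slice_length; ++i)
            scalar_view[i] = static_cast<float>(i);

        // Each vector write starts at a scalar offset that is a multiple of the
        // vector width, so offset / dst_data_per_write indexes the vector view,
        // mirroring buff_off = CalculateOffset(...) / long_vector_size above.
        for(int w = 0; w < num_vector_writes; ++w)
        {
            const int scalar_offset = w * dst_data_per_write;
            const int buff_off      = scalar_offset / dst_data_per_write;

            std::printf("write %d: vector index %d, scalars [%g..%g]\n",
                        w,
                        buff_off,
                        scalar_view[scalar_offset],
                        scalar_view[scalar_offset + dst_data_per_write - 1]);
        }
    }
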
@@ -188,7 +188,6 @@ union float_vec64_t
     StaticallyIndexedArray<float, 64> s1;
     StaticallyIndexedArray<float32_t, 2> s32;
     StaticallyIndexedArray<float64_t, 1> s64;
-    float n[64];
     __host__ __device__ constexpr float_vec64_t() {}
 
     template <index_t vs>
@@ -210,10 +209,10 @@ union float_vec64_t
 union float_vec128_t
 {
     StaticallyIndexedArray<float, 64> s1;
+    StaticallyIndexedArray<float_vec16_t, 8> s16;
     StaticallyIndexedArray<float32_t, 4> s32;
     StaticallyIndexedArray<float_vec64_t, 2> s64;
     StaticallyIndexedArray<float128_t, 1> s128;
-    float n[128];
     __host__ __device__ constexpr float_vec128_t() {}
 
     template <index_t vs>
@@ -225,6 +224,12 @@ union float_vec128_t
         return s1;
     }
 
+    template <>
+    __host__ __device__ auto& At(Number<16>)
+    {
+        return s16;
+    }
+
     template <>
     __host__ __device__ auto& At(Number<32>)
     {
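
The At(Number<N>) overloads expose the same float_vec128_t register storage at several vector widths; the new Number<16> case is what the gridwise kernel above relies on when it passes c_thread_vec.At(Number<16>{})[Number<blk_id>{}] to the v5 copy. A simplified stand-in for the pattern, using a plain union and fixed-size chunk structs instead of StaticallyIndexedArray (it mirrors the library's union-based reinterpretation, which formally depends on compiler-specific behaviour):

    #include <cstdio>

    // Assumed, simplified view types: N consecutive floats treated as one chunk.
    struct float16_chunk { float data[16]; };
    struct float32_chunk { float data[32]; };

    // Toy counterpart of float_vec128_t: the same 128 floats viewed either as
    // 8 chunks of 16 or as 4 chunks of 32.
    union vec128_sketch
    {
        float         s1[128];
        float16_chunk s16[8];
        float32_chunk s32[4];
    };

    int main()
    {
        vec128_sketch buf{};
        for(int i = 0; i < 128; ++i)
            buf.s1[i] = static_cast<float>(i);

        // Picking the 16-wide view: chunk 2 covers scalars 32..47, which is how
        // the kernel hands one output block's accumulators to the copy at a time.
        const float16_chunk blk = buf.s16[2];
        std::printf("first scalar of chunk 2: %g\n", blk.data[0]); // prints 32
    }
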
@@ -238,8 +243,6 @@
     }
 };
 
 template <typename T, index_t BufferSize>
 constexpr auto GetRegBuffer();