Commit 1d6022b1 authored by Jing Zhang
Browse files

vector type output

parent 9a54fbd8
......@@ -295,13 +295,11 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();
// force unrolling the output loop to get rid of scratches
#pragma unroll
for(index_t i = 0; i < NumBlks; ++i)
{
// force unrolling the output loop to get rid of scratches
static_for<0, NumBlks, 1>{}([&](auto blk_id) {
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(blk_id);
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
......@@ -309,24 +307,26 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryOp>(
ThreadwiseGenericTensorSliceCopy_v5<decltype(c_g_m0_m1_m2_n_thread_desc),
decltype(c_g_m0_m1_m2_n_global_desc),
CThreadCopySliceLengths,
arithmetic_sequence_gen<0, 5, 1>::type,
arithmetic_sequence_gen<0, 5, 1>::type,
4,
4,
1,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryOp>(
make_multi_index(0, 0, 0, 0, 0),
make_multi_index(g_block_data_on_global,
m_thread_data_on_global / (M2 * M1),
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global))
.Run(c_thread_vec.n + i * BlkSize, p_c_global);
}
.Run(c_thread_vec.At(Number<16>{})[Number<blk_id>{}], p_c_global);
});
}
}
};
......
......@@ -236,6 +236,38 @@ struct ThreadwiseGenericTensorSliceCopy_v5
});
}
// Stores a vector-typed source buffer to p_dst, DstDataPerWrite elements per
// store, visiting the slice in DstDimAccessOrder.
//
// src   : taken by value; must provide At(Number<DstDataPerWrite>{}) whose
//         result can be indexed with a compile-time Number<>.
// p_dst : destination base pointer, forwarded unchanged to
//         vector_data_store<DstData, DstDataPerWrite>::run.
template <typename SrcData, typename DstData>
__device__ void Run(SrcData src, DstData* p_dst)
{
    // Only 1-, 2- and 4-wide writes are accepted.
    static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");

    constexpr auto write_dim   = Number<DstVectorWriteDim>{};
    constexpr auto write_width = Number<DstDataPerWrite>{};

    // Number of vector accesses per dimension: SliceLengths with the entry at
    // the vector-write dimension divided by the write width.
    constexpr auto access_lengths =
        SliceLengths::Modify(write_dim, SliceLengths::Get(write_dim) / write_width);

    static_ford<decltype(access_lengths), DstDimAccessOrder>{}([&](auto access_id) {
        // Scale the index along the vector dimension back to element units to
        // get the multi-index of the first element covered by this access.
        constexpr auto data_begin_id = access_id.Modify(
            Number<write_dim>{}, Number<write_width * access_id[write_dim]>{});

        // Element offset in the thread buffer, converted to an index counted
        // in whole vectors of width DstDataPerWrite.
        constexpr auto vec_idx =
            ThreadBufferDesc::CalculateOffset(to_multi_index(data_begin_id)) / write_width;

        // Destination coordinate: slice origin shifted by this access's begin index.
        const auto dst_coord = mDstSliceOrigin + to_multi_index(data_begin_id);

        // Pick the vec_idx-th DstDataPerWrite-wide vector out of the source view.
        auto src_vec = src.At(Number<DstDataPerWrite>{})[Number<vec_idx>{}];

        vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_vec, dst_coord);
    });
}
template <typename T, bool PositiveDirection>
__device__ void MoveSrcSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>)
......
......@@ -188,7 +188,6 @@ union float_vec64_t
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float32_t, 2> s32;
StaticallyIndexedArray<float64_t, 1> s64;
float n[64];
__host__ __device__ constexpr float_vec64_t() {}
template <index_t vs>
......@@ -210,10 +209,10 @@ union float_vec64_t
union float_vec128_t
{
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float_vec16_t, 8> s16;
StaticallyIndexedArray<float32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128;
float n[128];
__host__ __device__ constexpr float_vec128_t() {}
template <index_t vs>
......@@ -225,6 +224,12 @@ union float_vec128_t
return s1;
}
// Specialization of the At member template for Number<16>: returns a mutable
// reference to the s16 member, which is declared in this union as
// StaticallyIndexedArray<float_vec16_t, 8>.
template <>
__host__ __device__ auto& At(Number<16>)
{
return s16;
}
template <>
__host__ __device__ auto& At(Number<32>)
{
......@@ -238,8 +243,6 @@ union float_vec128_t
}
};
// Declaration only: a constexpr function template parameterised on a type T
// and a compile-time BufferSize. The return type is deduced (auto), so a
// definition must be visible before any use; none appears in this excerpt.
template <typename T, index_t BufferSize>
constexpr auto GetRegBuffer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment