Commit 15232a0d authored by Jing Zhang's avatar Jing Zhang
Browse files

debug matA vector load

parent 7db237ba
......@@ -413,7 +413,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v5<
auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
BlockSize,
decltype(a_g_k_m_kpack_global_desc),
decltype(a_g_k_m_kpack_block_desc),
......@@ -509,11 +509,15 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock)
{
ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
// load next data from device mem
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
// a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
block_sync_lds();
......@@ -531,7 +535,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
block_sync_lds();
// store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_block);
// a_blockwise_copy.RunStoreThreadBuffer(p_a_block);
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
}
......
......@@ -95,8 +95,41 @@ struct ThreadwiseGenericTensorSliceCopy_v5
*reinterpret_cast<SrcData*>(&p_dst[dst_offset]) = src_data;
}
template <typename SrcData, index_t SrcDataPerAccess>
struct vector_data_load;
template <>
struct vector_data_load<float, 1>
{
template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
return load_data<float>(p_src, src_coord_begin.GetOffset());
}
};
template <>
struct vector_data_load<float, 2>
{
template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
return load_data<float2_t>(p_src, src_coord_begin.GetOffset());
}
};
template <>
struct vector_data_load<float, 4>
{
template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
return load_data<float4_t>(p_src, src_coord_begin.GetOffset());
}
};
template <index_t SrcDataPerAccess, index_t SrcDataRange, typename SrcData, typename SrcCoord>
__device__ static auto vector_data_load(const SrcData* p_src, const SrcCoord src_coord_begin)
__device__ static auto buffer_vector_load(const SrcData* p_src, const SrcCoord src_coord_begin)
{
auto src_offset = src_coord_begin.GetOffset();
return amd_buffer_load<SrcData, SrcDataPerAccess>(p_src, src_offset, true, SrcDataRange);
......@@ -162,7 +195,8 @@ struct ThreadwiseGenericTensorSliceCopy_v5
const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
auto src_buff =
vector_data_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(p_src, src_coord);
vector_data_load<SrcData, src_data_per_access>::run(p_src, src_coord);
// buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(p_src, src_coord);
// store data from the long-vector buffer to dst
constexpr auto buff_off =
......
......@@ -109,8 +109,8 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 1;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 4;
// B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 4;
......
......@@ -29,8 +29,8 @@ int main(int argc, char* argv[])
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment