Commit 7cf350d6 authored by Jing Zhang's avatar Jing Zhang
Browse files

global store test

parent 2049ab57
...@@ -617,7 +617,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -617,7 +617,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % (M2 * M1) / M2, m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2, m_thread_data_on_global % M2,
n_thread_data_on_global)) n_thread_data_on_global))
.Store(c_thread_vec.GetVector(Number<BlkSize>{})[Number<blk_id>{}], p_c_global); .GlobalStore(c_thread_vec, p_c_global);
//.GlobalStore(c_thread_vec.GetVector(Number<BlkSize>{})[Number<blk_id>{}],
// p_c_global);
}); });
} }
} }
......
...@@ -185,6 +185,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -185,6 +185,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id); const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
auto src_buff = GetRegBuffer<SrcData, SrcDataPerRead>(); auto src_buff = GetRegBuffer<SrcData, SrcDataPerRead>();
src_buff.GetVector(Number<SrcDataPerRead>{})(Number<0>{}) = src_buff.GetVector(Number<SrcDataPerRead>{})(Number<0>{}) =
buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(p_src, buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(p_src,
src_coord); src_coord);
...@@ -243,6 +244,39 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -243,6 +244,39 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}); });
} }
// Store a thread-private register buffer out to the destination (presumably
// global memory, per the name — the actual address space is decided by
// vector_data_store; TODO confirm). The slice is walked in vectors of
// DstDataPerWrite elements along dimension DstVectorWriteDim, visiting
// dimensions in DstDimAccessOrder.
//
// SrcData: register-buffer type providing GetVector(Number<VS>) (e.g. the
//          float_vecN_t unions defined elsewhere in this file).
// DstData: destination element type; p_dst points at the destination tensor.
//
// NOTE(review): thread_buff is taken by value — assumed cheap to copy since it
// is a small register-resident vector union; verify for larger buffer types.
template <typename SrcData, typename DstData>
__device__ void GlobalStore(SrcData thread_buff, DstData* p_dst)
{
constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
// Only 1/2/4-wide writes are supported (matches the available vector widths).
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");
constexpr auto long_vector_size = dst_data_per_access;
// Shrink the access-dim length so each iteration covers one whole vector.
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
// Compile-time loop over every vector-sized chunk of the slice.
static_ford<decltype(long_vector_access_lengths), DstDimAccessOrder>{}(
[&](auto long_vector_access_id) {
// Scale the access index back up to a data index along the vector dim.
constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
Number<vector_access_dim>{},
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
// Offset into the register buffer, in units of whole vectors.
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
auto src_buff =
thread_buff.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
// Destination coordinate = slice origin + per-chunk offset.
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
});
}
template <typename T, bool PositiveDirection> template <typename T, bool PositiveDirection>
__device__ void MoveSrcSliceWindow(const T& step_sizes_, __device__ void MoveSrcSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>) integral_constant<bool, PositiveDirection>)
......
...@@ -216,10 +216,10 @@ union float_vec32_t ...@@ -216,10 +216,10 @@ union float_vec32_t
union float_vec64_t union float_vec64_t
{ {
StaticallyIndexedArray<float, 64> s1; StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float_vec16_t, 4> s16;
StaticallyIndexedArray<float_vec32_t, 2> s32; StaticallyIndexedArray<float_vec32_t, 2> s32;
StaticallyIndexedArray<float32_t, 2> v32; StaticallyIndexedArray<float32_t, 2> v32;
StaticallyIndexedArray<float64_t, 1> s64; StaticallyIndexedArray<float64_t, 1> s64;
// float n[64];
__host__ __device__ constexpr float_vec64_t() { s64(Number<0>{}) = 0; } __host__ __device__ constexpr float_vec64_t() { s64(Number<0>{}) = 0; }
template <index_t vs> template <index_t vs>
...@@ -231,6 +231,12 @@ union float_vec64_t ...@@ -231,6 +231,12 @@ union float_vec64_t
return s1; return s1;
} }
// Specialization: view this 64-float buffer as four 16-float vectors (the s16
// member added alongside this accessor in the enclosing union).
template <>
__host__ __device__ auto& GetVector(Number<16>)
{
return s16;
}
template <> template <>
__host__ __device__ auto& GetVector(Number<32>) __host__ __device__ auto& GetVector(Number<32>)
{ {
...@@ -245,7 +251,6 @@ union float_vec128_t ...@@ -245,7 +251,6 @@ union float_vec128_t
StaticallyIndexedArray<float_vec32_t, 4> s32; StaticallyIndexedArray<float_vec32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64; StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128; StaticallyIndexedArray<float128_t, 1> s128;
// float n[128];
__host__ __device__ constexpr float_vec128_t() { s128(Number<0>{}) = 0; } __host__ __device__ constexpr float_vec128_t() { s128(Number<0>{}) = 0; }
template <index_t vs> template <index_t vs>
...@@ -327,6 +332,7 @@ constexpr auto GetRegBuffer<float, 128>() ...@@ -327,6 +332,7 @@ constexpr auto GetRegBuffer<float, 128>()
return float_vec128_t{}; return float_vec128_t{};
} }
#if 1
struct c_vec32_4_t struct c_vec32_4_t
{ {
union VecType union VecType
...@@ -473,6 +479,7 @@ struct c_vec4_1_t ...@@ -473,6 +479,7 @@ struct c_vec4_1_t
return c; return c;
} }
}; };
#endif
template <class T, index_t N> template <class T, index_t N>
struct vector_type struct vector_type
......
...@@ -64,11 +64,11 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( ...@@ -64,11 +64,11 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{}); make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
// read params: tunning parameters // read params: tunning parameters
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 32;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 16;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 16;
constexpr index_t GemmKPack = 4; constexpr index_t GemmKPack = 4;
// read params: dependent parameters // read params: dependent parameters
...@@ -83,7 +83,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( ...@@ -83,7 +83,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// A matrix copy // A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4; constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64; constexpr index_t GemmABlockCopyClusterLengths_GemmM = 32;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1; constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK = constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
...@@ -114,7 +114,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( ...@@ -114,7 +114,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// B matrix Copy // B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 2; constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 2;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 32; constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 8;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 4; constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 4;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK = constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
......
...@@ -25,10 +25,10 @@ int main(int argc, char* argv[]) ...@@ -25,10 +25,10 @@ int main(int argc, char* argv[])
// 1x1, 56x56 // 1x1, 56x56
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 14; constexpr index_t HI = 14;
constexpr index_t WI = 14; constexpr index_t WI = 14;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment