Commit 7cf350d6 authored by Jing Zhang's avatar Jing Zhang
Browse files

global store test

parent 2049ab57
......@@ -617,7 +617,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global))
.Store(c_thread_vec.GetVector(Number<BlkSize>{})[Number<blk_id>{}], p_c_global);
.GlobalStore(c_thread_vec, p_c_global);
//.GlobalStore(c_thread_vec.GetVector(Number<BlkSize>{})[Number<blk_id>{}],
// p_c_global);
});
}
}
......
......@@ -185,6 +185,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
const auto src_coord = mSrcSliceOrigin + to_multi_index(long_vector_data_begin_id);
auto src_buff = GetRegBuffer<SrcData, SrcDataPerRead>();
src_buff.GetVector(Number<SrcDataPerRead>{})(Number<0>{}) =
buffer_vector_load<SrcDataPerRead, SrcDesc::GetElementSpace()>(p_src,
src_coord);
......@@ -243,6 +244,39 @@ struct ThreadwiseGenericTensorSliceCopy_v5
});
}
template <typename SrcData, typename DstData>
__device__ void GlobalStore(SrcData thread_buff, DstData* p_dst)
{
constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");
constexpr auto long_vector_size = dst_data_per_access;
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
static_ford<decltype(long_vector_access_lengths), DstDimAccessOrder>{}(
[&](auto long_vector_access_id) {
constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
Number<vector_access_dim>{},
Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
constexpr auto buff_off =
ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
long_vector_size;
auto src_buff =
thread_buff.GetVector(Number<DstDataPerWrite>{})[Number<buff_off>{}];
const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
});
}
template <typename T, bool PositiveDirection>
__device__ void MoveSrcSliceWindow(const T& step_sizes_,
integral_constant<bool, PositiveDirection>)
......
......@@ -216,10 +216,10 @@ union float_vec32_t
union float_vec64_t
{
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float_vec16_t, 4> s16;
StaticallyIndexedArray<float_vec32_t, 2> s32;
StaticallyIndexedArray<float32_t, 2> v32;
StaticallyIndexedArray<float64_t, 1> s64;
// float n[64];
__host__ __device__ constexpr float_vec64_t() { s64(Number<0>{}) = 0; }
template <index_t vs>
......@@ -231,6 +231,12 @@ union float_vec64_t
return s1;
}
template <>
__host__ __device__ auto& GetVector(Number<16>)
{
return s16;
}
template <>
__host__ __device__ auto& GetVector(Number<32>)
{
......@@ -245,7 +251,6 @@ union float_vec128_t
StaticallyIndexedArray<float_vec32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128;
// float n[128];
__host__ __device__ constexpr float_vec128_t() { s128(Number<0>{}) = 0; }
template <index_t vs>
......@@ -327,6 +332,7 @@ constexpr auto GetRegBuffer<float, 128>()
return float_vec128_t{};
}
#if 1
struct c_vec32_4_t
{
union VecType
......@@ -473,6 +479,7 @@ struct c_vec4_1_t
return c;
}
};
#endif
template <class T, index_t N>
struct vector_type
......
......@@ -64,11 +64,11 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
// read params: tunning parameters
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 32;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmMPerWave = 16;
constexpr index_t GemmNPerWave = 16;
constexpr index_t GemmKPack = 4;
// read params: dependent parameters
......@@ -83,7 +83,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// A matrix copy
constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64;
constexpr index_t GemmABlockCopyClusterLengths_GemmM = 32;
constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;
constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
......@@ -114,7 +114,7 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
// B matrix Copy
constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 2;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 32;
constexpr index_t GemmBBlockCopyClusterLengths_GemmN = 8;
constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 4;
constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
......
......@@ -25,10 +25,10 @@ int main(int argc, char* argv[])
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t C = 128;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 1024;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment