"...composable_kernel.git" did not exist on "304802889728707c2a162322ce18686169e732ea"
Commit 159559c1 authored by Jing Zhang

tuning

parent 7b002f23
@@ -141,12 +141,14 @@ struct BlockwiseGenericTensorSliceCopy_v5
     private:
     using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
+    using ThreadBufferType = decltype(GetRegBuffer<float, GetThreadBufferSize()>());
     using ThreadwiseCopy = ThreadwiseGenericTensorSliceCopy_v5<BlockSrcDesc,
                                                                BlockDstDesc,
                                                                ThreadSliceLengths,
                                                                SrcDimAccessOrder,
                                                                DstDimAccessOrder,
+                                                               ThreadBufferType,
                                                                SrcVectoReadDim,
                                                                DstVectorWriteDim,
                                                                SrcDataPerRead,
......
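The hunk above threads a register-buffer type, chosen at compile time from the slice size, into the threadwise copy. Below is a minimal standalone sketch of that decltype-of-a-constexpr-factory pattern; the stand-in structs and the GetRegBufferSketch name are illustrative, not the repo's.

// Sketch only: a constexpr factory returns a differently sized buffer type per
// template argument, and decltype() of its result is used as a type alias,
// mirroring ThreadBufferType above. Stand-in structs replace float_vec*_t.
#include <cstddef>

struct vec4_buf  { float data[4];  };
struct vec8_buf  { float data[8];  };
struct vec16_buf { float data[16]; };

template <typename T, std::size_t BufferSize>
constexpr auto GetRegBufferSketch();

template <> constexpr auto GetRegBufferSketch<float, 4>()  { return vec4_buf{};  }
template <> constexpr auto GetRegBufferSketch<float, 8>()  { return vec8_buf{};  }
template <> constexpr auto GetRegBufferSketch<float, 16>() { return vec16_buf{}; }

// Resolves to vec8_buf when a thread's slice holds 8 floats.
using ThreadBufferTypeSketch = decltype(GetRegBufferSketch<float, 8>());
static_assert(sizeof(ThreadBufferTypeSketch) == 8 * sizeof(float), "size check");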
@@ -19,6 +19,7 @@ template <typename SrcDesc,
           typename SliceLengths,
           typename SrcDimAccessOrder,
           typename DstDimAccessOrder,
+          typename BufferVectorType,
           index_t SrcVectorReadDim,
           index_t DstVectorWriteDim,
           index_t SrcDataPerRead,
@@ -60,7 +61,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
             SliceLengths{}[DstVectorWriteDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
             "wrong! cannot evenly divide");
-        static_assert(ThreadBufferSize == 4, "");
+        static_assert(ThreadBufferSize == 8 || ThreadBufferSize == 16, "");
     }

     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v5()
@@ -252,7 +253,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
                 .Else([&](auto) { mDstSliceOrigin -= step_sizes; });
     }

-    float_vec4_t thread_buff;
+    BufferVectorType thread_buff;

     private:
     SrcCoord mSrcSliceOrigin;
......
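The relaxed assert above, together with the pre-existing lcm divisibility check, can be sanity-checked with concrete numbers. A worked example follows, assuming the thread buffer holds the whole per-thread slice and using hypothetical slice lengths (2 x 4) and access widths of 4; none of these values come from the commit.

// Hypothetical numbers only: a 2 x 4 per-thread slice gives ThreadBufferSize = 8,
// which passes the new (== 8 || == 16) assert, and the length 4 along the
// vectorized write dimension divides evenly by lcm(SrcDataPerRead, DstDataPerWrite) = 4.
constexpr int SliceLen0       = 2;
constexpr int SliceLen1       = 4; // length along DstVectorWriteDim
constexpr int SrcDataPerRead  = 4;
constexpr int DstDataPerWrite = 4;
constexpr int ThreadBufferSize = SliceLen0 * SliceLen1; // 8
static_assert(ThreadBufferSize == 8 || ThreadBufferSize == 16, "");
static_assert(SliceLen1 % 4 == 0, "wrong! cannot evenly divide"); // lcm(4, 4) == 4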
@@ -36,22 +36,22 @@ union float_vec4_t
     StaticallyIndexedArray<float4_t, 1> s4;

     __host__ __device__ constexpr float_vec4_t() {}

-    template<index_t vs>
+    template <index_t vs>
     __host__ __device__ auto& At();

-    template<>
+    template <>
     __host__ __device__ auto& At<1>()
     {
         return s1;
     }

-    template<>
+    template <>
     __host__ __device__ auto& At<2>()
     {
         return s2;
     }

-    template<>
+    template <>
     __host__ __device__ auto& At<4>()
     {
         return s4;
@@ -65,9 +65,98 @@ union float_vec8_t
     StaticallyIndexedArray<float4_t, 2> s4;
     StaticallyIndexedArray<float8_t, 1> s8;

     __host__ __device__ constexpr float_vec8_t() {}

+    template <index_t vs>
+    __host__ __device__ auto& At();
+
+    template <>
+    __host__ __device__ auto& At<1>()
+    {
+        return s1;
+    }
+
+    template <>
+    __host__ __device__ auto& At<2>()
+    {
+        return s2;
+    }
+
+    template <>
+    __host__ __device__ auto& At<4>()
+    {
+        return s4;
+    }
+
+    template <>
+    __host__ __device__ auto& At<8>()
+    {
+        return s8;
+    }
 };

+union float_vec16_t
+{
+    StaticallyIndexedArray<float, 16> s1;
+    StaticallyIndexedArray<float2_t, 8> s2;
+    StaticallyIndexedArray<float4_t, 4> s4;
+    StaticallyIndexedArray<float8_t, 2> s8;
+    StaticallyIndexedArray<float16_t, 1> s16;
+
+    __host__ __device__ constexpr float_vec16_t() {}
+
+    template <index_t vs>
+    __host__ __device__ auto& At();
+
+    template <>
+    __host__ __device__ auto& At<1>()
+    {
+        return s1;
+    }
+
+    template <>
+    __host__ __device__ auto& At<2>()
+    {
+        return s2;
+    }
+
+    template <>
+    __host__ __device__ auto& At<4>()
+    {
+        return s4;
+    }
+
+    template <>
+    __host__ __device__ auto& At<8>()
+    {
+        return s8;
+    }
+
+    template <>
+    __host__ __device__ auto& At<16>()
+    {
+        return s16;
+    }
+};
+
+template <typename T, index_t BufferSize>
+constexpr auto GetRegBuffer();
+
+template <>
+constexpr auto GetRegBuffer<float, 4>()
+{
+    return float_vec4_t{};
+}
+
+template <>
+constexpr auto GetRegBuffer<float, 8>()
+{
+    return float_vec8_t{};
+}
+
+template <>
+constexpr auto GetRegBuffer<float, 16>()
+{
+    return float_vec16_t{};
+}
+
 struct c_vec32_4_t
 {
......
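A short usage sketch of the unions and GetRegBuffer added above. The At<vs>() calls are the interface from this diff; the include path comment and the __device__ wrapper function are assumptions made purely for illustration.

// #include "float_types.hpp" // hypothetical include for the unions above

// Illustrative only: obtain a 16-float register buffer and view the same
// storage at two vector widths. Because float_vec16_t is a union, s1 and s4
// (reached through At<1>() and At<4>()) alias the same registers, so a copy
// can load with float4 accesses while the consumer reads scalars.
__device__ void reg_buffer_usage_sketch()
{
    auto buf = GetRegBuffer<float, 16>(); // resolves to float_vec16_t

    auto& as_scalars = buf.At<1>(); // StaticallyIndexedArray<float, 16>
    auto& as_float4  = buf.At<4>(); // StaticallyIndexedArray<float4_t, 4>

    (void)as_scalars;
    (void)as_float4;
}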
@@ -65,11 +65,11 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
     // read params: tunning parameters
     constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerWave = 64;
+    constexpr index_t GemmNPerBlock = 256;
+    constexpr index_t GemmKPerBlock = 4;
+    constexpr index_t GemmMPerWave = 128;
     constexpr index_t GemmNPerWave = 64;
-    constexpr index_t GemmKPack = 1;
+    constexpr index_t GemmKPack = 4;

     // read params: dependent parameters
     constexpr index_t BlockSize = 256;
@@ -108,8 +108,8 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
     using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
     using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]

-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 1;
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 4;

     // B matrix Copy
     constexpr index_t GemmBBlockCopyClusterLengths_GemmK = 4;
......
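A back-of-the-envelope check of the new tuning point. This arithmetic is not stated in the commit; it assumes the usual xdlops mapping of one 64-lane wave per GemmMPerWave x GemmNPerWave tile.

// With the values above: (GemmMPerBlock / GemmMPerWave) * (GemmNPerBlock / GemmNPerWave)
// = (128 / 128) * (256 / 64) = 4 waves, i.e. 4 * 64 = 256 threads, matching BlockSize = 256.
constexpr int WavesPerBlock = (128 / 128) * (256 / 64); // 4
static_assert(WavesPerBlock * 64 == 256, "block should be fully populated");

// GemmKPack = 4 with GemmABlockCopySrcDataPerRead_GemmKPack = 4 and
// GemmABlockCopyDstDataPerWrite_GemmKPack = 4 lets the A-block copy move a
// full float4 per access along the KPack dimension instead of four scalars.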