Commit e6a23d8b authored by Jing Zhang
Browse files

add e2

parent a8169558
......@@ -7,20 +7,21 @@
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatA,
typename FloatB,
typename FloatC,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
typename ABlockDesc_E1_K_E2,
typename BBlockDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo,
index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W>
index_t ThreadGemmADataPerRead_E2>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
struct MatrixIndex
{
......@@ -29,36 +30,48 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w;
};
static constexpr index_t KPerThreadLoop = 4;
static constexpr auto E1 = ABlockDesc_E1_K_E2{}.GetLength(I0);
static constexpr auto K = ABlockDesc_E1_K_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K_E2{}.GetLength(I2);
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
static constexpr auto H = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto W = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
static constexpr auto HPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
static constexpr index_t KPerThreadLoop = KPerThread;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}));
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto b_thread_mtx_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
Number<1>{},
Number<HPerThread>{},
Number<WPerThread>{},
Number<E2>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
: c_thread_begin_mtx_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread, 0)}
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
static_assert(ABlockDesc_E1_K_E2::IsKnownAtCompileTime() &&
BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
static_assert(
ABlockDesc_E1_K_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
"wrong! E dimension not consistent\n");
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
static_assert(E1 % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");
static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
"wrong! Cannot evenly divide work among\n");
......@@ -71,15 +84,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
"wrong! wrong blocksize\n");
}
__device__ static constexpr auto GetThreadMatrixCLengths()
__device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
{
return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
__device__ static MatrixIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
{
constexpr index_t HPerBlock = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t WPerBlock = BlockMatrixB{}.GetLength(Number<3>{});
constexpr index_t HPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
constexpr index_t WPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
constexpr auto num_w_threads = WPerBlock / WPerThread;
constexpr auto num_h_threads = HPerBlock / HPerThread;
......@@ -100,42 +113,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
CThreadBuffer& c_thread_buf) const
{
static_assert(
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatAB>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatAB>>::value &&
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
static_assert(EPerBlock % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");
constexpr auto a_block_mtx = ABlockDesc_E1_K_E2{};
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatAB, a_thread_mtx_.GetElementSpaceSize(), true>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatAB,
FloatAB,
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB,
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_)>{};
static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin),
make_tuple(e_begin, k_begin, I0),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0),
make_tuple(I0, I0, I0),
a_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, I0, I0),
make_tuple(e_begin, I0, I0, I0, I0),
c_thread_buf,
make_tuple(k_begin, I0, I0, I0));
});
......@@ -145,21 +153,22 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K_E2{}, a_block_slice_move_step_idx);
}
private:
MatrixIndex c_thread_begin_mtx_idx_;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatB,
ABlockDesc_E1_K_E2,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadLoop>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
Sequence<0, 1, 2>,
2,
ThreadGemmADataPerRead_E2,
ThreadGemmADataPerRead_E2>;
AThreadCopy a_thread_copy_;
};
......
......@@ -9,16 +9,17 @@ namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
// Assume:
// 1. AThreadDesc_E_K, BThreadDesc_E_N_Ho_Wo, CThreadDesc_K_N_Ho_Wo are known at compile-time
// 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
// compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_E_K,
typename BThreadDesc_E_N_Ho_Wo,
typename AThreadDesc_E1_K_E2,
typename BThreadDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo,
typename enable_if<AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
......@@ -38,8 +39,8 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
COriginIdx)
{
static_assert(AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
......@@ -54,18 +55,19 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr index_t Vec = 2;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto E = AThreadDesc_E_K{}.GetLength(I0);
constexpr auto K = AThreadDesc_E_K{}.GetLength(I1);
constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1);
constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
static_assert(E1 == 4 && E2 == 4, "");
constexpr auto H = BThreadDesc_E_N_Ho_Wo{}.GetLength(I2);
constexpr auto W = BThreadDesc_E_N_Ho_Wo{}.GetLength(I3);
constexpr auto H = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
constexpr auto W = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
......@@ -74,22 +76,23 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
static_for<0, K, 1>{}([&](auto k) {
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
static_for<0, E, Vec>{}([&](auto e) {
vector_type<FloatA, Vec> a_vec;
vector_type<FloatB, Vec> b_vec;
static_for<0, E1, 1>{}([&](auto e) {
vector_type<FloatA, E2> a_vec;
vector_type<FloatB, E2> b_vec;
static_for<0, Vec, 1>{}([&](auto v) {
constexpr index_t a_offset = AThreadDesc_E_K{}.CalculateOffset(
a_origin_idx + make_tuple(e + v, k));
constexpr index_t b_offset = BThreadDesc_E_N_Ho_Wo{}.CalculateOffset(
b_origin_idx + make_tuple(e + v, 0, h, w));
static_for<0, E2, 1>{}([&](auto v) {
constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
a_origin_idx + make_tuple(e, k, v));
constexpr index_t b_offset =
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
b_origin_idx + make_tuple(e, 0, h, w, v));
a_vec.template AsType<FloatA>()(v) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(v) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, Vec>::type;
using b_vector_t = typename vector_type<FloatB, Vec>::type;
using a_vector_t = typename vector_type<FloatA, E2>::type;
using b_vector_t = typename vector_type<FloatB, E2>::type;
constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
c_origin_idx + make_tuple(k, 0, h, w));
......
......@@ -102,26 +102,27 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 32;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 8;
constexpr index_t E1 = 16;
constexpr index_t EPerBlock = 16;
constexpr index_t E1 = 4;
constexpr index_t E2 = 4;
constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = KPerBlock;
constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
constexpr index_t EPerThread = 4;
using ABlockTransferThreadSliceLengths_E0_E1_K = Sequence<1, 4, 1>;
using ABlockTransferThreadClusterLengths_E0_E1_K = Sequence<1, 4, 16>;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 1, 1, 4>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, 4, 16, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = 1;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_E = 4;
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = 1;
constexpr index_t CThreadTransferDstScalarPerVector_K = 4;
constexpr index_t CThreadTransferDstScalarPerVector_K = 1;
#endif
constexpr auto conv_driver =
......@@ -131,6 +132,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
TAcc,
TOut,
E1,
E2,
KPerBlock,
HoPerBlock,
WoPerBlock,
......@@ -139,11 +141,11 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K,
ABlockTransferThreadClusterLengths_E0_E1_K,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
BThreadTransferSrcScalarPerVector_E,
ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_E2,
ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2,
BThreadTransferSrcScalarPerVector_E2,
CThreadTransferDstScalarPerVector_K>{};
const auto ave_time =
......
......@@ -52,6 +52,7 @@ REPEAT=$6
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 1080 1920 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 1 1 8 8 1 1 1 1 1 1 1 1
################################################ layout algo verify init log repeat M___ N___ K___
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment