Commit 10fdada7 authored by Jing Zhang
Browse files

rename e0_e1

parent 95228cd7
......@@ -12,14 +12,16 @@ template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
struct MatrixIndex
{
index_t k;
......@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w;
};
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
......@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
......@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
......@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
constexpr index_t HPerBlock = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t WPerBlock = BlockMatrixB{}.GetLength(Number<3>{});
constexpr auto num_w_threads = W / WPerThread;
constexpr auto num_h_threads = H / HPerThread;
constexpr auto num_w_threads = WPerBlock / WPerThread;
constexpr auto num_h_threads = HPerBlock / HPerThread;
constexpr auto num_hw_threads = num_w_threads * num_h_threads;
index_t k_thread_id = thread_id / num_hw_threads;
......@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
......@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
}
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveASliceWindow(const BlockMatrixA&,
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
}
......@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
private:
MatrixIndex c_thread_begin_mtx_idx_;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
AThreadCopy a_thread_copy_;
};
......
......@@ -26,14 +26,6 @@ template <typename FloatA,
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
// Device-side default constructor. It performs no runtime work: its body is a single
// static_assert that rejects any instantiation whose A, B, or C thread descriptor
// is not known at compile time.
__device__ ThreadwiseGemmDlops_km_kn_mn_v3()
{
// All three descriptors must report IsKnownAtCompileTime(); otherwise compilation fails here.
static_assert(AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment