Commit 10fdada7 authored by Jing Zhang's avatar Jing Zhang
Browse files

rename e0_e1

parent 95228cd7
...@@ -12,14 +12,16 @@ template <index_t BlockSize, ...@@ -12,14 +12,16 @@ template <index_t BlockSize,
typename BlockMatrixA, typename BlockMatrixA,
typename BlockMatrixB, typename BlockMatrixB,
typename ThreadMatrixC, typename ThreadMatrixC,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t EPerThreadLoop, index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K, index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W> index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{ {
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
struct MatrixIndex struct MatrixIndex
{ {
index_t k; index_t k;
...@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w; index_t w;
}; };
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
// HACK: fix this @Jing Zhang // HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4; static constexpr index_t KPerThreadSubC = 4;
...@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{})); Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
...@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
ThreadMatrixC::IsKnownAtCompileTime(), ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n"); "wrong! K dimension not consistent\n");
...@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{ {
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{}); constexpr index_t HPerBlock = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{}); constexpr index_t WPerBlock = BlockMatrixB{}.GetLength(Number<3>{});
constexpr auto num_w_threads = W / WPerThread; constexpr auto num_w_threads = WPerBlock / WPerThread;
constexpr auto num_h_threads = H / HPerThread; constexpr auto num_h_threads = HPerBlock / HPerThread;
constexpr auto num_hw_threads = num_w_threads * num_h_threads; constexpr auto num_hw_threads = num_w_threads * num_h_threads;
index_t k_thread_id = thread_id / num_hw_threads; index_t k_thread_id = thread_id / num_hw_threads;
...@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value && is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0); constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
...@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
} }
template <typename ABlockSliceMoveStepIdx> template <typename ABlockSliceMoveStepIdx>
__device__ void MoveASliceWindow(const BlockMatrixA&, __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{ {
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
} }
...@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
private: private:
MatrixIndex c_thread_begin_mtx_idx_; MatrixIndex c_thread_begin_mtx_idx_;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
}; };
......
...@@ -26,14 +26,6 @@ template <typename FloatA, ...@@ -26,14 +26,6 @@ template <typename FloatA,
struct ThreadwiseGemmDlops_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
{ {
__device__ ThreadwiseGemmDlops_km_kn_mn_v3()
{
static_assert(AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
}
template <typename ABuffer, template <typename ABuffer,
typename AOriginIdx, typename AOriginIdx,
typename BBuffer, typename BBuffer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment