Commit fc148cef authored by Chao Liu

added back pipelined 2x2 to blockwise gemm

parent 0374f8de
@@ -11,21 +11,21 @@ namespace ck {
 // A and B are visible to the whole block, C is distributed among each thread
 // Assume:
 //   1. A:
-//     1. BlockMatrixA is known at compile-time
+//     1. ABlockDesc is known at compile-time
 //     2. ABlockBuffer is DynamicBuffer
 //   2. B:
-//     1. BlockMatrixB is known at compile-time
+//     1. BBlockDesc is known at compile-time
 //     2. BBlockBuffer is DynamicBuffer
 //   3. C:
-//     1. ThreadMatrixC is known at compile-time
+//     1. CThreadDesc is known at compile-time
 //     2. CThreadBuffer is StaticBuffer
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename BlockMatrixA,
-          typename BlockMatrixB,
-          typename ThreadMatrixC,
+          typename ABlockDesc,
+          typename BBlockDesc,
+          typename CThreadDesc,
           index_t MPerThreadSubC,
           index_t NPerThreadSubC,
           index_t KPerThreadLoop,
@@ -35,9 +35,9 @@ template <index_t BlockSize,
           index_t NLevel1ThreadCluster,
           index_t ThreadGemmADataPerRead_M,
           index_t ThreadGemmBDataPerRead_N,
-          typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() &&
-                                      BlockMatrixB::IsKnownAtCompileTime() &&
-                                      ThreadMatrixC::IsKnownAtCompileTime(),
+          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
+                                      BBlockDesc::IsKnownAtCompileTime() &&
+                                      CThreadDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
 struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 {
@@ -49,13 +49,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
     public:
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
-        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
+        : c_thread_begin_mtx_idx_{GetBeginOfCThreadDesc(get_thread_local_1d_id())},
           a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
           b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
     {
-        static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
-                          BlockMatrixB::IsKnownAtCompileTime() &&
-                          ThreadMatrixC::IsKnownAtCompileTime(),
+        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
+                          CThreadDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         constexpr auto I0 = Number<0>{};
@@ -66,27 +65,27 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

-        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
+        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                       "wrong! K dimension not consistent");

-        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
+        constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
+        constexpr index_t N = BBlockDesc{}.GetLength(I1);

         static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                           N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
                       "wrong! Cannot evenly divide work among threads");

-        static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
-                          ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
-                      "wrong! ThreadMatrixC lengths is wrong");
+        static_assert(CThreadDesc{}.GetLength(I0) == GetCThreadDescLengths()[I0] &&
+                          CThreadDesc{}.GetLength(I1) == GetCThreadDescLengths()[I1],
+                      "wrong! CThreadDesc lengths are wrong");
     }

-    __device__ static constexpr auto GetThreadMatrixCLengths()
+    __device__ static constexpr auto GetCThreadDescLengths()
     {
         constexpr auto I1 = Number<1>{};

-        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
+        constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
+        constexpr index_t N = BBlockDesc{}.GetLength(I1);

         constexpr index_t MRepeat =
             M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
@@ -96,7 +95,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
     }

-    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
+    __device__ static MatrixIndex GetBeginOfCThreadDesc(index_t thread_id)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -130,9 +129,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr auto a_block_mtx = BlockMatrixA{};
-        constexpr auto b_block_mtx = BlockMatrixB{};
-        constexpr auto c_thread_mtx_desc = ThreadMatrixC{};
+        constexpr auto a_block_mtx = ABlockDesc{};
+        constexpr auto b_block_mtx = BBlockDesc{};
+        constexpr auto c_thread_mtx_desc = CThreadDesc{};

         constexpr auto K = a_block_mtx.GetLength(I0);
@@ -174,7 +173,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                                  decltype(c_thread_sub_mtx)>{};

         // read A_sub_0
-        a_thread_copy_.Run(BlockMatrixA{},
+        a_thread_copy_.Run(ABlockDesc{},
                            make_tuple(I0, I0),
                            a_block_buf,
                            a_thread_mtx_desc_,
@@ -182,7 +181,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            a_thread_buf);

         // read B_sub_0
-        b_thread_copy_.Run(BlockMatrixB{},
+        b_thread_copy_.Run(BBlockDesc{},
                            make_tuple(I0, I0),
                            b_block_buf,
                            b_thread_mtx_desc_,
@@ -190,7 +189,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            b_thread_buf);

         // read B_sub_1
-        b_thread_copy_.Run(BlockMatrixB{},
+        b_thread_copy_.Run(BBlockDesc{},
                            make_tuple(I0, Number<NPerLevel1Cluster>{}),
                            b_block_buf,
                            b_thread_mtx_desc_,
@@ -198,7 +197,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            b_thread_buf);

         // read A_sub_1
-        a_thread_copy_.Run(BlockMatrixA{},
+        a_thread_copy_.Run(ABlockDesc{},
                            make_tuple(I0, Number<MPerLevel1Cluster>{}),
                            a_block_buf,
                            a_thread_mtx_desc_,
@@ -224,7 +223,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // loop over rest of k
         static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
             // read A_sub_0
-            a_thread_copy_.Run(BlockMatrixA{},
+            a_thread_copy_.Run(ABlockDesc{},
                                make_tuple(k, I0),
                                a_block_buf,
                                a_thread_mtx_desc_,
@@ -240,7 +239,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(Number<MPerThreadSubC>{}, I0));

             // read B_sub_0
-            b_thread_copy_.Run(BlockMatrixB{},
+            b_thread_copy_.Run(BBlockDesc{},
                                make_tuple(k, I0),
                                b_block_buf,
                                b_thread_mtx_desc_,
@@ -256,7 +255,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));

             // read B_sub_1
-            b_thread_copy_.Run(BlockMatrixB{},
+            b_thread_copy_.Run(BBlockDesc{},
                                make_tuple(k, Number<NPerLevel1Cluster>{}),
                                b_block_buf,
                                b_thread_mtx_desc_,
@@ -264,7 +263,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                b_thread_buf);

             // read A_sub_1
-            a_thread_copy_.Run(BlockMatrixA{},
+            a_thread_copy_.Run(ABlockDesc{},
                                make_tuple(k, Number<MPerLevel1Cluster>{}),
                                a_block_buf,
                                a_thread_mtx_desc_,
@@ -314,8 +313,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
-        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
+        constexpr index_t MPerThread = CThreadDesc{}.GetLength(I0);
+        constexpr index_t NPerThread = CThreadDesc{}.GetLength(I1);

         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
@@ -342,15 +341,15 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                     Sequence<0, 1, 2, 3>{});

     static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
+        make_tuple(Number<KPerThreadLoop>{}, CThreadDesc{}.GetLength(Number<0>{})));

     static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
+        make_tuple(Number<KPerThreadLoop>{}, CThreadDesc{}.GetLength(Number<1>{})));

     using AThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                 FloatA,
-                                                BlockMatrixA,
+                                                ABlockDesc,
                                                 decltype(a_thread_mtx_desc_),
                                                 Sequence<KPerThreadLoop, MPerThreadSubC>,
                                                 Sequence<0, 1>,
@@ -363,7 +362,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
     using BThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                 FloatB,
-                                                BlockMatrixB,
+                                                BBlockDesc,
                                                 decltype(b_thread_mtx_desc_),
                                                 Sequence<KPerThreadLoop, NPerThreadSubC>,
                                                 Sequence<0, 1>,
@@ -411,7 +410,7 @@ template <index_t BlockSize,
                                     BBlockDesc::IsKnownAtCompileTime() &&
                                     CThreadDesc::IsKnownAtCompileTime(),
                                 bool>::type = false>
-struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
+struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
 {
     using AIndex = MultiIndex<3>;
     using BIndex = MultiIndex<3>;
@@ -423,7 +422,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
     static constexpr auto I3 = Number<3>{};

     public:
-    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
+    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1()
         : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
           a_thread_copy_{
               make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
@@ -479,7 +478,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                                 FloatC,
                                                 decltype(a_thread_desc_),
                                                 decltype(b_thread_desc_),
-                                                CThreadDesc>{};
+                                                CThreadDesc,
+                                                Sequence<KPerThreadLoop>,
+                                                Sequence<M0_, M1PerThread>,
+                                                Sequence<N0_, N1PerThread>>{};

         constexpr index_t K = ABlockDesc{}.GetLength(I0);
@@ -553,5 +555,295 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
     BThreadCopy b_thread_copy_;
 };
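The block of added lines below is the commit's main payload: a v1r2 blockwise GEMM in which each thread owns a 2x2 grid of C sub-blocks and the k-loop interleaves sub-block loads with math on the previous iteration's registers (the "pipelined 2x2" of the commit message). As a sanity sketch of just that schedule — plain host C++ with illustrative sizes and hypothetical helper names, not this library's API — the following program replays the load/FMA ordering of the new Run() and checks it against a naive reference:

#include <cassert>
#include <cstddef>

// Illustrative tile sizes (assumptions, not values from this commit);
// the 2 sub-blocks per dimension are fixed, matching the "== 2" static_asserts.
constexpr std::size_t K = 8, M1 = 4, N1 = 4;

int main()
{
    // Block-visible tiles A[K][2][M1], B[K][2][N1]; per-thread C accumulator.
    float A[K][2][M1], B[K][2][N1], C[2][M1][2][N1] = {}, C_ref[2][M1][2][N1] = {};
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t i = 0; i < 2; ++i)
            for (std::size_t j = 0; j < M1; ++j)
            {
                A[k][i][j] = float(k + i + j);
                B[k][i][j] = float(k * i + j); // M1 == N1 here, so one loop fills both
            }

    // Thread-private "register" copies of one k-slice of each sub-block.
    float a_reg[2][M1], b_reg[2][N1];
    auto load_a = [&](std::size_t k, std::size_t s) {
        for (std::size_t j = 0; j < M1; ++j) a_reg[s][j] = A[k][s][j]; };
    auto load_b = [&](std::size_t k, std::size_t s) {
        for (std::size_t j = 0; j < N1; ++j) b_reg[s][j] = B[k][s][j]; };
    // C_sub_{ms,ns} += transpose(a_sub_ms) * b_sub_ns (rank-1 update per k-slice)
    auto fma = [&](std::size_t ms, std::size_t ns) {
        for (std::size_t m = 0; m < M1; ++m)
            for (std::size_t n = 0; n < N1; ++n)
                C[ms][m][ns][n] += a_reg[ms][m] * b_reg[ns][n]; };

    // Prologue: fetch all four k = 0 sub-blocks, start on two quadrants.
    load_a(0, 0); load_b(0, 0); load_b(0, 1); load_a(0, 1);
    fma(0, 0); // C_sub_00
    fma(0, 1); // C_sub_01

    // Steady state: each load for iteration k is paired with math on k-1 registers.
    for (std::size_t k = 1; k < K; ++k)
    {
        load_a(k, 0); fma(1, 0); // C_sub_10 still reads k-1 registers
        load_b(k, 0); fma(1, 1); // C_sub_11 still reads k-1 registers
        load_b(k, 1); load_a(k, 1);
        fma(0, 0); fma(0, 1);    // now on iteration-k registers
    }
    fma(1, 0); fma(1, 1); // drain the last two quadrants

    // Naive reference over all quadrants and k.
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t ms = 0; ms < 2; ++ms)
            for (std::size_t m = 0; m < M1; ++m)
                for (std::size_t ns = 0; ns < 2; ++ns)
                    for (std::size_t n = 0; n < N1; ++n)
                        C_ref[ms][m][ns][n] += A[k][ms][m] * B[k][ns][n];

    for (std::size_t ms = 0; ms < 2; ++ms)
        for (std::size_t m = 0; m < M1; ++m)
            for (std::size_t ns = 0; ns < 2; ++ns)
                for (std::size_t n = 0; n < N1; ++n)
                    assert(C[ms][m][ns][n] == C_ref[ms][m][ns][n]);
    return 0;
}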
+// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
+// A and B are visible to the whole block, C is distributed among each thread
+// Assume:
+//   1. A:
+//     1. ABlockDesc is known at compile-time
+//     2. ABlockBuffer is DynamicBuffer
+//   2. B:
+//     1. BBlockDesc is known at compile-time
+//     2. BBlockBuffer is DynamicBuffer
+//   3. C:
+//     1. CThreadDesc is known at compile-time
+//     2. CThreadBuffer is StaticBuffer
+template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename ABlockDesc,
+          typename BBlockDesc,
+          typename CThreadDesc,
+          index_t M1PerThread,
+          index_t N1PerThread,
+          index_t KPerThreadLoop,
+          index_t MLevel0ThreadCluster,
+          index_t NLevel0ThreadCluster,
+          index_t MLevel1ThreadCluster,
+          index_t NLevel1ThreadCluster,
+          index_t AThreadCopyScalarPerVector_M1,
+          index_t BThreadCopyScalarPerVector_N1,
+          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
+                                      BBlockDesc::IsKnownAtCompileTime() &&
+                                      CThreadDesc::IsKnownAtCompileTime(),
+                                  bool>::type = false>
+struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
+{
+    using AIndex = MultiIndex<3>;
+    using BIndex = MultiIndex<3>;
+    using CIndex = MultiIndex<4>;
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+
+    public:
+    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2()
+        : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
+          a_thread_copy_{
+              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
+          b_thread_copy_{
+              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
+    {
+        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
+                          CThreadDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        static_assert(BlockSize == c_thread_cluster_desc_.GetElementSize(),
+                      "wrong! wrong blocksize");
+
+        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
+                      "wrong! K dimension not consistent");
+
+        // TODO: remove this restriction
+        static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
+                          CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
+                      "wrong");
+    }
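// Note: the "== 2" restriction above mirrors the hard-coded 2x2 structure of Run()
// further down — the four C quadrants are addressed with the literal sub-block
// indices I0/I1, so exactly two sub-blocks per dimension are assumed.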
+    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
+    {
+        const auto thread_cluster_idx =
+            c_thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
+
+        constexpr index_t MPerLevel0Cluster = M1PerThread * MLevel0ThreadCluster;
+        constexpr index_t NPerLevel0Cluster = N1PerThread * NLevel0ThreadCluster;
+
+        return make_multi_index(
+            0,
+            thread_cluster_idx[I0] * MPerLevel0Cluster + thread_cluster_idx[I2] * M1PerThread,
+            0,
+            thread_cluster_idx[I1] * NPerLevel0Cluster + thread_cluster_idx[I3] * N1PerThread);
+    }
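// Worked example with illustrative sizes (assumed here, not fixed by this commit):
// with M1PerThread = N1PerThread = 4 and all four thread-cluster dimensions = 2,
// BlockSize = 16 and MPerLevel0Cluster = NPerLevel0Cluster = 8. Thread 13 decomposes
// row-major over [MLevel1, NLevel1, MLevel0, NLevel0] = [2, 2, 2, 2] as (1, 1, 0, 1),
// so its C origin is (0, 1*8 + 0*4, 0, 1*8 + 1*4) = (0, 8, 0, 12): it owns the M1
// rows starting at 8 and the N1 columns starting at 12 of the block tile.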
+    __host__ __device__ static constexpr auto GetCThreadClusterDescriptor()
+    {
+        return make_cluster_descriptor_v2(Sequence<MLevel1ThreadCluster,
+                                                   NLevel1ThreadCluster,
+                                                   MLevel0ThreadCluster,
+                                                   NLevel0ThreadCluster>{},
+                                          Sequence<0, 1, 2, 3>{});
+    }
+
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+
+        constexpr auto threadwise_gemm =
+            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
+                                                FloatB,
+                                                FloatC,
+                                                decltype(a_thread_desc_),
+                                                decltype(b_thread_desc_),
+                                                CThreadDesc,
+                                                Sequence<KPerThreadLoop>,
+                                                Sequence<1, M1PerThread>,
+                                                Sequence<1, N1PerThread>>{};
+
+        constexpr index_t K = ABlockDesc{}.GetLength(I0);
+
+        // read A_sub_0
+        a_thread_copy_.Run(ABlockDesc{},
+                           make_tuple(I0, I0, I0),
+                           a_block_buf,
+                           a_thread_desc_,
+                           make_tuple(I0, I0, I0),
+                           a_thread_buf);
+
+        // read B_sub_0
+        b_thread_copy_.Run(BBlockDesc{},
+                           make_tuple(I0, I0, I0),
+                           b_block_buf,
+                           b_thread_desc_,
+                           make_tuple(I0, I0, I0),
+                           b_thread_buf);
+
+        // read B_sub_1
+        b_thread_copy_.Run(BBlockDesc{},
+                           make_tuple(I0, I1, I0),
+                           b_block_buf,
+                           b_thread_desc_,
+                           make_tuple(I0, I1, I0),
+                           b_thread_buf);
+
+        // read A_sub_1
+        a_thread_copy_.Run(ABlockDesc{},
+                           make_tuple(I0, I1, I0),
+                           a_block_buf,
+                           a_thread_desc_,
+                           make_tuple(I0, I1, I0),
+                           a_thread_buf);
+
+        // C_sub_00 += transpose(A_sub_0) * B_sub_0
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            c_thread_buf,
+                            make_tuple(I0, I0, I0, I0));
+
+        // C_sub_01 += transpose(A_sub_0) * B_sub_1
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            c_thread_buf,
+                            make_tuple(I0, I0, I1, I0));
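// Up to here is the software-pipeline prologue: all four k = 0 sub-blocks are
// fetched and the two quadrants that depend only on A_sub_0 are computed. In the
// steady-state loop below, each load for iteration k is paired with an FMA that
// reads only registers still holding iteration k-1 data, which is what hides the
// LDS-to-register latency; the last iteration's C_sub_10/C_sub_11 drain after it.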
+        // loop over rest of k
+        static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
+            // read A_sub_0
+            a_thread_copy_.Run(ABlockDesc{},
+                               make_tuple(k, I0, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, I0),
+                               a_thread_buf);
+
+            // C_sub_10 += transpose(A_sub_1) * B_sub_0
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                c_thread_buf,
+                                make_tuple(I1, I0, I0, I0));
+
+            // read B_sub_0
+            b_thread_copy_.Run(BBlockDesc{},
+                               make_tuple(k, I0, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I0, I0),
+                               b_thread_buf);
+
+            // C_sub_11 += transpose(A_sub_1) * B_sub_1
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                c_thread_buf,
+                                make_tuple(I1, I0, I1, I0));
+
+            // read B_sub_1
+            b_thread_copy_.Run(BBlockDesc{},
+                               make_tuple(k, I1, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               b_thread_buf);
+
+            // read A_sub_1
+            a_thread_copy_.Run(ABlockDesc{},
+                               make_tuple(k, I1, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               a_thread_buf);
+
+            // C_sub_00 += transpose(A_sub_0) * B_sub_0
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                c_thread_buf,
+                                make_tuple(I0, I0, I0, I0));
+
+            // C_sub_01 += transpose(A_sub_0) * B_sub_1
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                c_thread_buf,
+                                make_tuple(I0, I0, I1, I0));
+        });
+
+        // C_sub_10 += transpose(A_sub_1) * B_sub_0
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            c_thread_buf,
+                            make_tuple(I1, I0, I0, I0));
+
+        // C_sub_11 += transpose(A_sub_1) * B_sub_1
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            c_thread_buf,
+                            make_tuple(I1, I0, I1, I0));
+    }
+
+    private:
+    static constexpr auto c_thread_cluster_desc_ = GetCThreadClusterDescriptor();
+
+    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
+    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
+
+    // A[K, M0, M1]
+    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, Number<M0_>{}, Number<M1PerThread>{}));
+
+    // B[K, N0, N1]
+    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, Number<N0_>{}, Number<N1PerThread>{}));
+
+    using AThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+                                                FloatA,
+                                                ABlockDesc,
+                                                decltype(a_thread_desc_),
+                                                Sequence<KPerThreadLoop, 1, M1PerThread>,
+                                                Sequence<0, 1, 2>,
+                                                2,
+                                                AThreadCopyScalarPerVector_M1,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    using BThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
+                                                FloatB,
+                                                BBlockDesc,
+                                                decltype(b_thread_desc_),
+                                                Sequence<KPerThreadLoop, 1, N1PerThread>,
+                                                Sequence<0, 1, 2>,
+                                                2,
+                                                BThreadCopyScalarPerVector_N1,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    CIndex c_thread_origin_data_idx_;
+
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
+};
 } // namespace ck
 #endif
@@ -721,22 +721,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));

         const auto blockwise_gemm =
-            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1<BlockSize,
-                                                  FloatAB,
-                                                  FloatAB,
-                                                  FloatAcc,
-                                                  decltype(a_k_m0_m1_block_desc),
-                                                  decltype(b_k_n0_n1_block_desc),
-                                                  decltype(c_m0_m1_n0_n1_thread_desc),
-                                                  MPerThread,
-                                                  NPerThread,
-                                                  KPerThread,
-                                                  MLevel0Cluster,
-                                                  NLevel0Cluster,
-                                                  MLevel1Cluster,
-                                                  NLevel1Cluster,
-                                                  MPerThread,
-                                                  NPerThread>{};
+            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2<BlockSize,
+                                                    FloatAB,
+                                                    FloatAB,
+                                                    FloatAcc,
+                                                    decltype(a_k_m0_m1_block_desc),
+                                                    decltype(b_k_n0_n1_block_desc),
+                                                    decltype(c_m0_m1_n0_n1_thread_desc),
+                                                    MPerThread,
+                                                    NPerThread,
+                                                    KPerThread,
+                                                    MLevel0Cluster,
+                                                    NLevel0Cluster,
+                                                    MLevel1Cluster,
+                                                    NLevel1Cluster,
+                                                    MPerThread,
+                                                    NPerThread>{};

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
@@ -151,11 +151,27 @@ template <typename FloatA,
           typename ADesc,
           typename BDesc,
           typename CDesc,
+          typename KLengths,
+          typename MLengths,
+          typename NLengths,
           typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                       CDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
 struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
 {
+    __device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
+    {
+        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                          CDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        // TODO: sanity-check: compare ADesc, BDesc, CDesc sizes with KLengths, MLengths
+        // and NLengths
+
+        // TODO: remove this restriction
+        static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
+                      "wrong!");
+    }
+
     template <typename ABuffer,
               typename AOriginIdx,
               typename BBuffer,
@@ -169,10 +185,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
                         CBuffer& c_buf,
                         COriginIdx)
     {
-        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                          CDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
         static_assert(
             is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
                 is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
@@ -192,11 +204,11 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        constexpr auto K  = ADesc{}.GetLength(I0);
-        constexpr auto M0 = CDesc{}.GetLength(I0);
-        constexpr auto M1 = CDesc{}.GetLength(I1);
-        constexpr auto N0 = CDesc{}.GetLength(I2);
-        constexpr auto N1 = CDesc{}.GetLength(I3);
+        constexpr auto K  = KLengths{}[I0];
+        constexpr auto M0 = MLengths{}[I0];
+        constexpr auto M1 = MLengths{}[I1];
+        constexpr auto N0 = NLengths{}[I0];
+        constexpr auto N1 = NLengths{}[I1];

         constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
         constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
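To make the new KLengths/MLengths/NLengths contract concrete, here is a minimal host-side reference of the threadwise contraction C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] in plain C++; the sizes and the check in main are illustrative assumptions, not code from this commit:

#include <cassert>
#include <cstddef>

// Illustrative stand-ins for KLengths = <K>, MLengths = <M0, M1>, NLengths = <N0, N1>;
// the real kernel receives these as compile-time Sequence template parameters.
constexpr std::size_t K = 4, M0 = 2, M1 = 4, N0 = 2, N1 = 4;

// C[m0][m1][n0][n1] += sum_k A[k][m0][m1] * B[k][n0][n1]
// (the "transpose" in the comment refers to K being the leading axis of A).
void threadwise_gemm_reference(const float (&a)[K][M0][M1],
                               const float (&b)[K][N0][N1],
                               float (&c)[M0][M1][N0][N1])
{
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t m0 = 0; m0 < M0; ++m0)
            for (std::size_t m1 = 0; m1 < M1; ++m1)
                for (std::size_t n0 = 0; n0 < N0; ++n0)
                    for (std::size_t n1 = 0; n1 < N1; ++n1)
                        c[m0][m1][n0][n1] += a[k][m0][m1] * b[k][n0][n1];
}

int main()
{
    float a[K][M0][M1] = {}, b[K][N0][N1] = {}, c[M0][M1][N0][N1] = {};
    a[0][1][2] = 2.0f;
    b[0][1][3] = 3.0f;
    threadwise_gemm_reference(a, b, c);
    assert(c[1][2][1][3] == 6.0f); // single contributing product: 2 * 3
    return 0;
}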