"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "855ed8bcef8fc17655ccc28f70ce34a6b3b58d65"
Commit a25f992d authored by Chao Liu

overhauling fwd-v4r4

parent 849243b8
@@ -11,13 +11,13 @@ namespace ck {
 // A and B are visible to the whole block, C is distributed among the threads
 // Assume:
 // 1. A:
-//      1. ABlockDesc is known at compile-time
+//      1. AKMBlockDesc is known at compile-time
 //      2. ABlockBuffer is DynamicBuffer
 // 2. B:
-//      1. ABlockDesc is known at compile-time
+//      1. BKNBlockDesc is known at compile-time
 //      2. BBlockBuffer is DynamicBuffer
 // 3. C:
-//      1. CThreadDesc is known at compile-time
+//      1. CM0M1N0N1ThreadDesc is known at compile-time
 //      2. CThreadBuffer is StaticBuffer
 // Also assume:
 //   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
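
The "2x2 pipelined read and fma (ABBA)" schedule named in this comment is easier to see outside the kernel. Below is a minimal host-side sketch of that schedule, not the kernel's actual code: the packed [K][2][M1] / [K][2][N1] / [2][M1][2][N1] layouts and every name in it are illustrative assumptions.

```cpp
// Host-side model of the 2x2 pipelined read + fma ("ABBA") schedule.
// Assumed layouts: A is [K][2][M1], B is [K][2][N1], C is [2][M1][2][N1].
#include <cstddef>
#include <vector>

void abba_2x2_model(const std::vector<float>& a, // [K][2][M1]
                    const std::vector<float>& b, // [K][2][N1]
                    std::vector<float>& c,       // [2][M1][2][N1]
                    std::size_t K, std::size_t M1, std::size_t N1)
{
    std::vector<float> a0(M1), a1(M1), b0(N1), b1(N1);
    for (std::size_t k = 0; k < K; ++k)
    {
        // read order A0, B0, B1, A1: on the GPU the B1/A1 reads can be
        // issued while the fma on (A0, B0) is already in flight
        for (std::size_t m = 0; m < M1; ++m) a0[m] = a[(k * 2 + 0) * M1 + m];
        for (std::size_t n = 0; n < N1; ++n) b0[n] = b[(k * 2 + 0) * N1 + n];
        for (std::size_t n = 0; n < N1; ++n) b1[n] = b[(k * 2 + 1) * N1 + n];
        for (std::size_t m = 0; m < M1; ++m) a1[m] = a[(k * 2 + 1) * M1 + m];

        // fma the four C sub-blocks C00, C01, C10, C11
        for (std::size_t m = 0; m < M1; ++m)
            for (std::size_t n = 0; n < N1; ++n)
            {
                c[((0 * M1 + m) * 2 + 0) * N1 + n] += a0[m] * b0[n];
                c[((0 * M1 + m) * 2 + 1) * N1 + n] += a0[m] * b1[n];
                c[((1 * M1 + m) * 2 + 0) * N1 + n] += a1[m] * b0[n];
                c[((1 * M1 + m) * 2 + 1) * N1 + n] += a1[m] * b1[n];
            }
    }
}
```

The scalar model only preserves the ordering; the payoff of the A0-B0-B1-A1 order exists on the GPU, where reads and fma overlap.
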
@@ -25,21 +25,21 @@ template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename ABlockDesc,
-          typename BBlockDesc,
-          typename CThreadDesc,
-          index_t M1PerThread,
-          index_t N1PerThread,
+          typename AKMBlockDesc,
+          typename BKNBlockDesc,
+          typename CM0M1N0N1ThreadDesc,
+          index_t M1PerThreadM11,
+          index_t N1PerThreadN11,
           index_t KPerThread,
-          index_t M1N1ThreadClusterM10,
-          index_t M1N1ThreadClusterN10,
-          index_t M1N1ThreadClusterM11,
-          index_t M1N1ThreadClusterN11,
-          index_t AThreadCopyScalarPerVector_M1,
-          index_t BThreadCopyScalarPerVector_N1,
-          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
-                                      BBlockDesc::IsKnownAtCompileTime() &&
-                                      CThreadDesc::IsKnownAtCompileTime(),
+          index_t M1N1ThreadClusterM100,
+          index_t M1N1ThreadClusterN100,
+          index_t M1N1ThreadClusterM101,
+          index_t M1N1ThreadClusterN101,
+          index_t AThreadCopyScalarPerVector_M11,
+          index_t BThreadCopyScalarPerVector_N11,
+          typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
                                       BKNBlockDesc::IsKnownAtCompileTime() &&
                                       CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
 struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
 {
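
The new names suggest a two-level decomposition of the thread cluster (M100/N100 outer, M101/N101 inner); what the constructor below actually enforces is that the four factors multiply to BlockSize. A self-contained sketch of that contract, with invented values:

```cpp
// Sketch of the thread-cluster consistency rule checked in the constructor
// below. All numeric values here are invented for illustration.
constexpr int BlockSize             = 256;
constexpr int M1N1ThreadClusterM100 = 8;
constexpr int M1N1ThreadClusterN100 = 8;
constexpr int M1N1ThreadClusterM101 = 2;
constexpr int M1N1ThreadClusterN101 = 2;

static_assert(BlockSize == M1N1ThreadClusterM100 * M1N1ThreadClusterN100 *
                               M1N1ThreadClusterM101 * M1N1ThreadClusterN101,
              "the four cluster factors must tile the whole workgroup");
```
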
@@ -60,36 +60,38 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
           b_thread_copy_{
               make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
     {
-        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
-                          CThreadDesc::IsKnownAtCompileTime(),
+        static_assert(AKMBlockDesc::IsKnownAtCompileTime() &&
+                          BKNBlockDesc::IsKnownAtCompileTime() &&
+                          CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
-                                       M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
+        static_assert(BlockSize == M1N1ThreadClusterM101 * M1N1ThreadClusterM100 *
+                                       M1N1ThreadClusterN101 * M1N1ThreadClusterN100,
                       "wrong! blocksize and cluster size not consistent");

-        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
+        static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
                       "wrong! K dimension not consistent");

         // TODO: remove this restriction
-        static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
-                          CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
+        static_assert(AKMBlockDesc{}.GetLength(I1) == 2 && BKNBlockDesc{}.GetLength(I1) == 2 &&
+                          CM0M1N0N1ThreadDesc{}.GetLength(I0) == 2 &&
+                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == 2,
                       "wrong");
     }

     __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
     {
-        constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
-        constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
-        constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
-        constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
+        constexpr index_t M0 = AKMBlockDesc{}.GetLength(I1);
+        constexpr index_t N0 = BKNBlockDesc{}.GetLength(I1);
+        constexpr index_t M1 = AKMBlockDesc{}.GetLength(I2);
+        constexpr index_t N1 = BKNBlockDesc{}.GetLength(I2);

         // 4-d data space into 4-d thread space
         constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
             make_tuple(make_vectorize_transform(M0, 1),
-                       make_vectorize_transform(M1PerThread, M1 / M1PerThread),
+                       make_vectorize_transform(M1PerThreadM11, M1 / M1PerThreadM11),
                        make_vectorize_transform(N0, 1),
-                       make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
+                       make_vectorize_transform(N1PerThreadN11, N1 / N1PerThreadN11)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
@@ -97,18 +99,18 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
         constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
             make_tuple(
                 make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
+                make_unmerge_transform(make_tuple(M1N1ThreadClusterM100, M1N1ThreadClusterM101)),
                 make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
+                make_unmerge_transform(make_tuple(M1N1ThreadClusterN100, M1N1ThreadClusterN101))),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

         // 4-d thread space to 1-d thread space
         constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
-                                                       M1N1ThreadClusterN10,
-                                                       M1N1ThreadClusterM11,
-                                                       M1N1ThreadClusterN11))),
+            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM100,
+                                                       M1N1ThreadClusterN100,
+                                                       M1N1ThreadClusterM101,
+                                                       M1N1ThreadClusterN101))),
             make_tuple(Sequence<0, 2, 1, 3>{}),
             make_tuple(Sequence<0>{}));
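
CalculateCThreadOriginDataIndex chains these three adaptors in reverse to turn a linear thread_id into the thread's C data origin. The following is a hand-rolled model of the same arithmetic under stated assumptions: the merge order (M100, N100, M101, N101) and the Sequence<0, 2, 1, 3> routing are read off the adaptors above, and the cluster sizes are invented.

```cpp
// Model of thread_id -> C origin, bypassing the adaptor machinery.
#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr int M100 = 8, N100 = 8, M101 = 2, N101 = 2; // invented cluster sizes
    constexpr int M1PerThreadM11 = 4, N1PerThreadN11 = 4; // invented slice widths

    for (int thread_id : {0, 1, 37, 255})
    {
        // adaptor2 in reverse: merge order (M100, N100, M101, N101),
        // so the last factor is the fastest-moving digit of thread_id
        int t = thread_id;
        const int n101 = t % N101; t /= N101;
        const int m101 = t % M101; t /= M101;
        const int n100 = t % N100; t /= N100;
        const int m100 = t % M100;

        // adaptor1: unmerge the two-level cluster coordinates
        const int m1_cluster = m100 * M101 + m101;
        const int n1_cluster = n100 * N101 + n101;

        // adaptor0: vectorize, i.e. scale by the per-thread slice width;
        // the m0/n0 coordinates are frozen to 0 (Run walks the 2x2 sub-blocks)
        const int m1 = m1_cluster * M1PerThreadM11;
        const int n1 = n1_cluster * N1PerThreadN11;

        std::printf("thread %3d -> C origin (m0, m1, n0, n1) = (0, %2d, 0, %2d)\n",
                    thread_id, m1, n1);
    }
    return 0;
}
```
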
@@ -133,15 +135,15 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                                   FloatC,
                                                   decltype(a_thread_desc_),
                                                   decltype(b_thread_desc_),
-                                                  CThreadDesc,
+                                                  CM0M1N0N1ThreadDesc,
                                                   Sequence<KPerThread>,
-                                                  Sequence<1, M1PerThread>,
-                                                  Sequence<1, N1PerThread>>{};
+                                                  Sequence<1, M1PerThreadM11>,
+                                                  Sequence<1, N1PerThreadN11>>{};

-        constexpr index_t K = ABlockDesc{}.GetLength(I0);
+        constexpr index_t K = AKMBlockDesc{}.GetLength(I0);

         // read A_sub_0
-        a_thread_copy_.Run(ABlockDesc{},
+        a_thread_copy_.Run(AKMBlockDesc{},
                            make_tuple(I0, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
@@ -149,7 +151,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                            a_thread_buf);

         // read B_sub_0
-        b_thread_copy_.Run(BBlockDesc{},
+        b_thread_copy_.Run(BKNBlockDesc{},
                            make_tuple(I0, I0, I0),
                            b_block_buf,
                            b_thread_desc_,
@@ -157,7 +159,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                            b_thread_buf);

         // read B_sub_1
-        b_thread_copy_.Run(BBlockDesc{},
+        b_thread_copy_.Run(BKNBlockDesc{},
                            make_tuple(I0, I1, I0),
                            b_block_buf,
                            b_thread_desc_,
@@ -165,7 +167,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                            b_thread_buf);

         // read A_sub_1
-        a_thread_copy_.Run(ABlockDesc{},
+        a_thread_copy_.Run(AKMBlockDesc{},
                            make_tuple(I0, I1, I0),
                            a_block_buf,
                            a_thread_desc_,
@@ -191,7 +193,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
         // loop over rest of k
         static_for<KPerThread, K, KPerThread>{}([&](auto k) {
             // read A_sub_0
-            a_thread_copy_.Run(ABlockDesc{},
+            a_thread_copy_.Run(AKMBlockDesc{},
                                make_tuple(k, I0, I0),
                                a_block_buf,
                                a_thread_desc_,
@@ -207,7 +209,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                make_tuple(I1, I0, I0, I0));

             // read B_sub_0
-            b_thread_copy_.Run(BBlockDesc{},
+            b_thread_copy_.Run(BKNBlockDesc{},
                                make_tuple(k, I0, I0),
                                b_block_buf,
                                b_thread_desc_,
@@ -223,7 +225,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                make_tuple(I1, I0, I1, I0));

             // read B_sub_1
-            b_thread_copy_.Run(BBlockDesc{},
+            b_thread_copy_.Run(BKNBlockDesc{},
                                make_tuple(k, I1, I0),
                                b_block_buf,
                                b_thread_desc_,
@@ -231,7 +233,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                b_thread_buf);

             // read A_sub_1
-            a_thread_copy_.Run(ABlockDesc{},
+            a_thread_copy_.Run(AKMBlockDesc{},
                                make_tuple(k, I1, I0),
                                a_block_buf,
                                a_thread_desc_,
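
The main loop above is driven by static_for<KPerThread, K, KPerThread>, which unrolls at compile time and hands the body its trip count as a compile-time constant (which is why `k` can be passed straight into make_tuple). A minimal stand-in for the assumed semantics, not ck's implementation:

```cpp
// Stand-in for ck::static_for as used above: a fully unrolled loop over
// Begin, Begin+Step, ... (< End) whose body receives the trip count as a
// compile-time constant, similar to ck's Number<i>.
#include <utility>

template <int Begin, int End, int Step, typename F>
constexpr void static_for_model(F&& f)
{
    if constexpr (Begin < End)
    {
        f(std::integral_constant<int, Begin>{}); // body sees a constant index
        static_for_model<Begin + Step, End, Step>(std::forward<F>(f));
    }
}

// usage sketch: visits k = KPerThread, 2*KPerThread, ..., < K
// static_for_model<KPerThread, K, KPerThread>([&](auto k) { /* ... */ });
```
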
@@ -273,37 +275,37 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
     }

     private:
-    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
-    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
+    static constexpr index_t M0_ = AKMBlockDesc{}.GetLength(I1);
+    static constexpr index_t N0_ = BKNBlockDesc{}.GetLength(I1);

     // A[K, M0, M1]
     static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));
+        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThreadM11>{}));

     // B[K, N0, N1]
     static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));
+        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThreadN11>{}));

     using AThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                 FloatA,
-                                                ABlockDesc,
+                                                AKMBlockDesc,
                                                 decltype(a_thread_desc_),
-                                                Sequence<KPerThread, 1, M1PerThread>,
+                                                Sequence<KPerThread, 1, M1PerThreadM11>,
                                                 Sequence<0, 1, 2>,
                                                 2,
-                                                AThreadCopyScalarPerVector_M1,
+                                                AThreadCopyScalarPerVector_M11,
                                                 1>;

     using BThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                 FloatB,
-                                                BBlockDesc,
+                                                BKNBlockDesc,
                                                 decltype(b_thread_desc_),
-                                                Sequence<KPerThread, 1, N1PerThread>,
+                                                Sequence<KPerThread, 1, N1PerThreadN11>,
                                                 Sequence<0, 1, 2>,
                                                 2,
-                                                BThreadCopyScalarPerVector_N1,
+                                                BThreadCopyScalarPerVector_N11,
                                                 1>;

     CIndex c_thread_origin_data_idx_;
...
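
Both private thread descriptors are "packed", i.e. contiguous with no padding, so their offset computation reduces to plain row-major arithmetic. A sketch for the A-side descriptor with invented sizes (the real descriptor carries its lengths as compile-time Number<> values):

```cpp
// Model of the packed [KPerThread][M0][M1PerThreadM11] A thread descriptor:
// packed means no padding, so element offsets follow row-major strides.
#include <cstddef>

constexpr std::size_t KPerThread     = 2; // invented
constexpr std::size_t M0             = 2; // invented
constexpr std::size_t M1PerThreadM11 = 4; // invented

constexpr std::size_t a_thread_offset(std::size_t k, std::size_t m0, std::size_t m11)
{
    return (k * M0 + m0) * M1PerThreadM11 + m11;
}

// A_sub_1 (m0 == 1) starts one M11-wide slice after A_sub_0 within each k
static_assert(a_thread_offset(0, 1, 0) == M1PerThreadM11, "packed layout");
```
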