"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "5b42b7393e1c2b8bad12ee060722bd5d2ba92150"
Commit 4b21c0fd authored by Chao Liu

overhauling fwd-v4r4

parent a25f992d
@@ -14,7 +14,7 @@ namespace ck {
 //     1. AKMBlockDesc is known at compile-time
 //     2. ABlockBuffer is DynamicBuffer
 //   2. B:
-//     1. AKMBlockDesc is known at compile-time
+//     1. BKNBlockDesc is known at compile-time
 //     2. BBlockBuffer is DynamicBuffer
 //   3. C:
 //     1. CM0M1N0N1ThreadDesc is known at compile-time
@@ -27,7 +27,6 @@ template <index_t BlockSize,
           typename FloatC,
           typename AKMBlockDesc,
           typename BKNBlockDesc,
-          typename CM0M1N0N1ThreadDesc,
           index_t M1PerThreadM11,
           index_t N1PerThreadN11,
           index_t KPerThread,
@@ -38,10 +37,9 @@ template <index_t BlockSize,
           index_t AThreadCopyScalarPerVector_M11,
           index_t BThreadCopyScalarPerVector_N11,
           typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
-                                      BKNBlockDesc::IsKnownAtCompileTime() &&
-                                      CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
+                                      BKNBlockDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
-struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
+struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
 {
     using AIndex = MultiIndex<3>;
     using BIndex = MultiIndex<3>;
@@ -52,40 +50,76 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

+    static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
+    static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
+    static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);
+
+    static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
+    static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
+
+    static constexpr index_t M0 = M / M1;
+    static constexpr index_t N0 = N / N1;
+
+    __host__ __device__ static constexpr auto
+    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
+    {
+        const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
+            AKMBlockDesc{},
+            make_tuple(make_pass_through_transform(Number<K>{}),
+                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return a_k_m0_m1_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& n_k_n_block_desc)
+    {
+        const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
+            BKNBlockDesc{},
+            make_tuple(make_pass_through_transform(Number<K>{}),
+                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return b_k_n0_n1_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
+    {
+        return Sequence<M0, M1PerThreadM11, N0, N1PerThreadN11>{};
+    }
+
+    static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
+    static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});
+
     public:
-    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2()
+    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
         : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
           a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
     {
-        static_assert(AKMBlockDesc::IsKnownAtCompileTime() &&
-                          BKNBlockDesc::IsKnownAtCompileTime() &&
-                          CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
+        static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         static_assert(BlockSize == M1N1ThreadClusterM101 * M1N1ThreadClusterM100 *
                                        M1N1ThreadClusterN101 * M1N1ThreadClusterN100,
                       "wrong! blocksize and cluster size not consistent");

+        static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
+
         static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
                       "wrong! K dimension not consistent");

         // TODO: remove this restriction
-        static_assert(AKMBlockDesc{}.GetLength(I1) == 2 && BKNBlockDesc{}.GetLength(I1) == 2 &&
-                          CM0M1N0N1ThreadDesc{}.GetLength(I0) == 2 &&
-                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == 2,
-                      "wrong");
+        static_assert(M0 == 2 && N0 == 2, "wrong");
     }

     __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
     {
-        constexpr index_t M0 = AKMBlockDesc{}.GetLength(I1);
-        constexpr index_t N0 = BKNBlockDesc{}.GetLength(I1);
-        constexpr index_t M1 = AKMBlockDesc{}.GetLength(I2);
-        constexpr index_t N1 = BKNBlockDesc{}.GetLength(I2);
-
         // 4-d data space into 4-d thread space
         constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
             make_tuple(make_vectorize_transform(M0, 1),
@@ -119,58 +153,68 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
         return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
     }

-    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
-    __device__ void Run(const ABlockBuffer& a_block_buf,
+    template <typename CM0M1N0N1ThreadDesc,
+              typename ABlockBuffer,
+              typename BBlockBuffer,
+              typename CThreadBuffer>
+    __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
+                        const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
                         CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf =
-            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf =
-            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
+        static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        // TODO: remove this restriction
+        static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
+                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
+                      "wrong");
+
+        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
+            a_k_m0_m1_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
+            b_k_n0_n1_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                 FloatB,
                                                 FloatC,
-                                                decltype(a_thread_desc_),
-                                                decltype(b_thread_desc_),
+                                                decltype(a_k_m0_m1_thread_desc_),
+                                                decltype(b_k_n0_n1_thread_desc_),
                                                 CM0M1N0N1ThreadDesc,
                                                 Sequence<KPerThread>,
                                                 Sequence<1, M1PerThreadM11>,
                                                 Sequence<1, N1PerThreadN11>>{};

-        constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
-
         // read A_sub_0
-        a_thread_copy_.Run(AKMBlockDesc{},
+        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                            make_tuple(I0, I0, I0),
                            a_block_buf,
-                           a_thread_desc_,
+                           a_k_m0_m1_thread_desc_,
                            make_tuple(I0, I0, I0),
                            a_thread_buf);

         // read B_sub_0
-        b_thread_copy_.Run(BKNBlockDesc{},
+        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                            make_tuple(I0, I0, I0),
                            b_block_buf,
-                           b_thread_desc_,
+                           b_k_n0_n1_thread_desc_,
                            make_tuple(I0, I0, I0),
                            b_thread_buf);

         // read B_sub_1
-        b_thread_copy_.Run(BKNBlockDesc{},
+        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                            make_tuple(I0, I1, I0),
                            b_block_buf,
-                           b_thread_desc_,
+                           b_k_n0_n1_thread_desc_,
                            make_tuple(I0, I1, I0),
                            b_thread_buf);

         // read A_sub_1
-        a_thread_copy_.Run(AKMBlockDesc{},
+        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                            make_tuple(I0, I1, I0),
                            a_block_buf,
-                           a_thread_desc_,
+                           a_k_m0_m1_thread_desc_,
                            make_tuple(I0, I1, I0),
                            a_thread_buf);
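Stripped of the thread distribution, the register staging, and the 2x2 software pipeline in the Run method above, the quantity this blockwise GEMM accumulates is an ordinary inner-product update with A stored K-major, C[m][n] += sum over k of A[k][m] * B[k][n], as the AKM/BKN descriptor names suggest. Below is a host-side reference of just that arithmetic; it is a sketch with illustrative sizes, not the kernel's configuration.

#include <cstddef>
#include <vector>

// Reference semantics only: C[M][N] += sum over k of A[k][m] * B[k][n],
// with A stored as [K, M] and B as [K, N], both row-major.
void reference_km_kn_mn_gemm(const std::vector<float>& a, // K x M
                             const std::vector<float>& b, // K x N
                             std::vector<float>& c,       // M x N
                             std::size_t K, std::size_t M, std::size_t N)
{
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}

int main()
{
    const std::size_t K = 8, M = 4, N = 4; // illustrative sizes
    std::vector<float> a(K * M, 1.0f), b(K * N, 1.0f), c(M * N, 0.0f);
    reference_km_kn_mn_gemm(a, b, c, K, M, N);
    // With all-ones inputs every C entry accumulates K products, i.e. 8.
    return c[0] == static_cast<float>(K) ? 0 : 1;
}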
@@ -193,10 +237,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
         // loop over rest of k
         static_for<KPerThread, K, KPerThread>{}([&](auto k) {
             // read A_sub_0
-            a_thread_copy_.Run(AKMBlockDesc{},
+            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                                make_tuple(k, I0, I0),
                                a_block_buf,
-                               a_thread_desc_,
+                               a_k_m0_m1_thread_desc_,
                                make_tuple(I0, I0, I0),
                                a_thread_buf);
@@ -209,10 +253,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                    make_tuple(I1, I0, I0, I0));

             // read B_sub_0
-            b_thread_copy_.Run(BKNBlockDesc{},
+            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                                make_tuple(k, I0, I0),
                                b_block_buf,
-                               b_thread_desc_,
+                               b_k_n0_n1_thread_desc_,
                                make_tuple(I0, I0, I0),
                                b_thread_buf);
@@ -225,18 +269,18 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                    make_tuple(I1, I0, I1, I0));

             // read B_sub_1
-            b_thread_copy_.Run(BKNBlockDesc{},
+            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                                make_tuple(k, I1, I0),
                                b_block_buf,
-                               b_thread_desc_,
+                               b_k_n0_n1_thread_desc_,
                                make_tuple(I0, I1, I0),
                                b_thread_buf);

             // read A_sub_1
-            a_thread_copy_.Run(AKMBlockDesc{},
+            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                                make_tuple(k, I1, I0),
                                a_block_buf,
-                               a_thread_desc_,
+                               a_k_m0_m1_thread_desc_,
                                make_tuple(I0, I1, I0),
                                a_thread_buf);
@@ -275,22 +319,19 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
     }

     private:
-    static constexpr index_t M0_ = AKMBlockDesc{}.GetLength(I1);
-    static constexpr index_t N0_ = BKNBlockDesc{}.GetLength(I1);
-
     // A[K, M0, M1]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThreadM11>{}));
+    static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));

     // B[K, N0, N1]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThreadN11>{}));
+    static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));

     using AThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                 FloatA,
-                                                AKMBlockDesc,
-                                                decltype(a_thread_desc_),
+                                                decltype(a_k_m0_m1_block_desc_),
+                                                decltype(a_k_m0_m1_thread_desc_),
                                                 Sequence<KPerThread, 1, M1PerThreadM11>,
                                                 Sequence<0, 1, 2>,
                                                 2,
@@ -300,8 +341,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
     using BThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                 FloatB,
-                                                BKNBlockDesc,
-                                                decltype(b_thread_desc_),
+                                                decltype(b_k_n0_n1_block_desc_),
+                                                decltype(b_k_n0_n1_thread_desc_),
                                                 Sequence<KPerThread, 1, N1PerThreadN11>,
                                                 Sequence<0, 1, 2>,
                                                 2,
...
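For readers new to the descriptor transforms, the new MakeAKM0M1BlockDescriptor / MakeBKN0N1BlockDescriptor members above only re-index the existing [K, M] and [K, N] blocks: M is split into an outer index M0 and an inner index M1 (M0 = M / M1), and likewise N into (N0, N1), with no data movement. A minimal standalone sketch of the same index arithmetic, assuming a packed row-major [K, M] buffer; the sizes are hypothetical and chosen only for illustration.

#include <cassert>
#include <cstddef>

// Hypothetical sizes for illustration only; M must be divisible by M1.
constexpr std::size_t K  = 8;
constexpr std::size_t M  = 128;
constexpr std::size_t M1 = 16; // e.g. a product of thread-cluster and per-thread factors
constexpr std::size_t M0 = M / M1;

// Offset of (k, m) in a packed row-major [K, M] buffer.
constexpr std::size_t offset_k_m(std::size_t k, std::size_t m) { return k * M + m; }

// Offset of (k, m0, m1) after unmerging M into (M0, M1): same memory, new indexing.
constexpr std::size_t offset_k_m0_m1(std::size_t k, std::size_t m0, std::size_t m1)
{
    return k * M + m0 * M1 + m1;
}

int main()
{
    // The two views address the same element whenever m == m0 * M1 + m1.
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t m = 0; m < M; ++m)
            assert(offset_k_m(k, m) == offset_k_m0_m1(k, m / M1, m % M1));
    return 0;
}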
@@ -283,59 +283,27 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         // b_mtx[KPerBlocl, NPerBlock] is in LDS
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register

-        // sanity check
-        static_assert(
-            MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
-                NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
-            "wrong!");
-
-        constexpr index_t M0PerThread =
-            MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
-        constexpr index_t N0PerThread =
-            NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);
-
-        constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
-            a_k_m_block_desc,
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                       make_unmerge_transform(make_tuple(
-                           Number<M0PerThread>{},
-                           Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
-
-        constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
-            b_k_n_block_desc,
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                       make_unmerge_transform(make_tuple(
-                           Number<N0PerThread>{},
-                           Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
-
-        constexpr auto c_m0_m1_n0_n1_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<M0PerThread>{},
-                                                                      Number<M1PerThread>{},
-                                                                      Number<N0PerThread>{},
-                                                                      Number<N1PerThread>{}));
-
         const auto blockwise_gemm =
-            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
-                                                                 FloatAB,
-                                                                 FloatAB,
-                                                                 FloatAcc,
-                                                                 decltype(a_k_m0_m1_block_desc),
-                                                                 decltype(b_k_n0_n1_block_desc),
-                                                                 decltype(
-                                                                     c_m0_m1_n0_n1_thread_desc),
-                                                                 M1PerThread,
-                                                                 N1PerThread,
-                                                                 KPerThread,
-                                                                 M1N1ThreadClusterM10,
-                                                                 M1N1ThreadClusterN10,
-                                                                 M1N1ThreadClusterM11,
-                                                                 M1N1ThreadClusterN11,
-                                                                 M1PerThread,
-                                                                 N1PerThread>{};
+            BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
+                                                           FloatAB,
+                                                           FloatAB,
+                                                           FloatAcc,
+                                                           decltype(a_k_m_block_desc),
+                                                           decltype(b_k_n_block_desc),
+                                                           M1PerThread,
+                                                           N1PerThread,
+                                                           KPerThread,
+                                                           M1N1ThreadClusterM10,
+                                                           M1N1ThreadClusterN10,
+                                                           M1N1ThreadClusterM11,
+                                                           M1N1ThreadClusterN11,
+                                                           M1PerThread,
+                                                           N1PerThread>{};
+
+        constexpr auto c_m0_m1_n0_n1_thread_tensor_lengths =
+            decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
+
+        constexpr auto c_m0_m1_n0_n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            sequence_to_tuple_of_number(c_m0_m1_n0_n1_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
@@ -351,10 +319,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
             c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<
-            FloatAcc,
-            decltype(c_m0_m1_n0_n1_thread_desc),
-            Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_m0_m1_n0_n1_thread_desc),
+                                           decltype(c_m0_m1_n0_n1_thread_tensor_lengths)>{}
             .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
@@ -415,7 +382,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store next data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
@@ -438,7 +406,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

             // LDS double buffer: store next data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
@@ -465,7 +434,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
@@ -474,14 +444,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
         }

         // output: register to global memory
@@ -495,18 +467,17 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto c_thread_data_idx_on_block =
             blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());

-        ThreadwiseDynamicTensorSliceTransfer_v1r3<
-            FloatAcc,
-            FloatC,
-            decltype(c_m0_m1_n0_n1_thread_desc),
-            decltype(c_m0_m1_n0_n1_global_desc),
-            Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
-            CThreadTransferSrcDstAccessOrder,
-            CThreadTransferSrcDstVectorDim,
-            CThreadTransferDstScalarPerVector,
-            CGlobalMemoryDataOperation,
-            1,
-            true>{
+        ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
+                                                  FloatC,
+                                                  decltype(c_m0_m1_n0_n1_thread_desc),
+                                                  decltype(c_m0_m1_n0_n1_global_desc),
+                                                  decltype(c_m0_m1_n0_n1_thread_tensor_lengths),
+                                                  CThreadTransferSrcDstAccessOrder,
+                                                  CThreadTransferSrcDstVectorDim,
+                                                  CThreadTransferDstScalarPerVector,
+                                                  CGlobalMemoryDataOperation,
+                                                  1,
+                                                  true>{
             c_m0_m1_n0_n1_global_desc,
             make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
                              c_thread_data_idx_on_block[I1],
...
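The main loop above is a classic LDS double buffer: while the blockwise GEMM consumes the "even" tile, the next K-slab is fetched from global memory and written into the "odd" tile, and the roles swap every iteration, with a separate tail for the last one or two slabs. A much-simplified host-side sketch of that even/odd alternation follows; the buffer names and the trivial "compute" are placeholders, and the real kernel additionally needs __syncthreads() and overlaps the global reads with math.

#include <algorithm>
#include <array>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    // Pretend "global memory" holds num_slabs K-slabs of data.
    constexpr int slab_size = 4;
    constexpr int num_slabs = 6;
    std::vector<int> global_data(slab_size * num_slabs);
    std::iota(global_data.begin(), global_data.end(), 0);

    // Two "LDS" tiles used in ping-pong fashion.
    std::array<std::array<int, slab_size>, 2> lds{};
    long long acc = 0;

    // Preload slab 0 into the even tile (mirrors the prologue before the loop).
    std::copy_n(global_data.begin(), slab_size, lds[0].begin());

    for (int i = 0; i < num_slabs; ++i)
    {
        const int cur = i % 2;       // tile consumed this iteration
        const int nxt = (i + 1) % 2; // tile filled for the next iteration

        // "Prefetch" the next slab into the other tile, if one remains.
        if (i + 1 < num_slabs)
            std::copy_n(global_data.begin() + (i + 1) * slab_size, slab_size, lds[nxt].begin());

        // "Compute" on the current tile (stands in for blockwise_gemm.Run).
        for (int v : lds[cur])
            acc += v;
    }

    std::printf("acc = %lld\n", acc); // 0 + 1 + ... + 23 = 276
    return 0;
}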
@@ -26,5 +26,11 @@ __host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
                                typename arithmetic_sequence_gen<0, N, 1>::type{});
 }

+template <index_t... Is>
+__host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
+{
+    return Sequence<Is...>{};
+}
+
 } // namespace ck
 #endif
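The new to_sequence helper deduces the integer pack Is... from a Tuple<Number<Is>...> argument and re-wraps it as Sequence<Is...>. A standard-library analogue of the same deduction trick is sketched below, using std::tuple of std::integral_constant in and std::integer_sequence out rather than the ck types.

#include <tuple>
#include <type_traits>
#include <utility>

// Deduce the integer pack from a tuple of integral_constants and re-wrap it
// as an integer_sequence -- the same deduction to_sequence performs on
// Tuple<Number<Is>...>.
template <int... Is>
constexpr auto to_sequence_analogue(std::tuple<std::integral_constant<int, Is>...>)
{
    return std::integer_sequence<int, Is...>{};
}

// Example: compile-time lengths {2, 4, 2, 4} become integer_sequence<int, 2, 4, 2, 4>.
using Lengths = std::tuple<std::integral_constant<int, 2>,
                           std::integral_constant<int, 4>,
                           std::integral_constant<int, 2>,
                           std::integral_constant<int, 4>>;

static_assert(std::is_same_v<decltype(to_sequence_analogue(Lengths{})),
                             std::integer_sequence<int, 2, 4, 2, 4>>,
              "pack deduction works as expected");

int main() { return 0; }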