"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "54e1251e05261ffb369fbc898d7478d24dd10e69"
Commit 2488d0bf authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 97ec23bf
...@@ -541,21 +541,33 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -541,21 +541,33 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
n_thread_data_on_block_idx[I2]), n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}}; ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global // tuple of reference to C/Ds tensor descriptors
// FIXME: arbitrary # of D tensors const auto c_ds_desc_refs = concat_tuple_of_reference(
const auto c_ds_descs = tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], generate_tie(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]); [&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, // ThreadGroup ThisThreadBlock,
Tuple<FloatCShuffle, Tuple<FloatCShuffle,
remove_cvref_t<tuple_element_t<0, DsDataType>>, remove_cvref_t<tuple_element_t<0, DsDataType>>,
remove_cvref_t<tuple_element_t<1, DsDataType>>>, remove_cvref_t<tuple_element_t<1, DsDataType>>>,
Tuple<FloatE>, // typename DstData, Tuple<FloatE>,
decltype(c_ds_descs), decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation, CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary type // support arbitrary type
Sequence<1, Sequence<1,
...@@ -566,13 +578,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -566,13 +578,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim, 3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, CDEShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<true, false, false>, // ThreadTransferSrcResetCoordinateAfterRunFlags sequence_merge_t<
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags Sequence<true>,
{c_ds_descs, uniform_sequence_gen_t<NumDTensor,
make_tuple(make_multi_index(0, 0, 0, 0), false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), {c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op}; cde_element_op};
...@@ -619,7 +632,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -619,7 +632,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global // each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run( cde_block_copy_lds_and_global.Run(
c_ds_descs, c_ds_desc_refs,
tie(c_shuffle_block_buf, ds_grid_buf[I0], ds_grid_buf[I1]), tie(c_shuffle_block_buf, ds_grid_buf[I0], ds_grid_buf[I1]),
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf)); tie(e_grid_buf));
...@@ -630,9 +643,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -630,9 +643,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
sfc_cde_block.GetForwardStep(access_id); sfc_cde_block.GetForwardStep(access_id);
// move on Ds // move on Ds
static_for<0, DsDataType::Size(), 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow( cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_descs, i + I1, cde_lds_and_global_step); c_ds_desc_refs, i + I1, cde_lds_and_global_step);
}); });
// move on E // move on E
......
...@@ -7,6 +7,10 @@ ...@@ -7,6 +7,10 @@
namespace ck { namespace ck {
// Assume:
// 1. src_descs and dst_descs are not known at compile-time
// 2. SrcBuffers and DstBuffers are DynamicBuffer
// 3. src_slice_origins and dst_slice_origins are not known at compile-time,
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions: // and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument // 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
...@@ -14,11 +18,6 @@ namespace ck { ...@@ -14,11 +18,6 @@ namespace ck {
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same // 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead // tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead // 3. Don't use a pointer to VGPR buffer, use vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
template <typename SrcDatas, template <typename SrcDatas,
typename DstDatas, typename DstDatas,
typename SrcDescs, typename SrcDescs,
...@@ -34,8 +33,6 @@ template <typename SrcDatas, ...@@ -34,8 +33,6 @@ template <typename SrcDatas,
struct ThreadwiseTensorSliceTransfer_v7 struct ThreadwiseTensorSliceTransfer_v7
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
......
#ifndef CK_SEQUENCE_HPP #pragma once
#define CK_SEQUENCE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -882,5 +881,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f) ...@@ -882,5 +881,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return flag; return flag;
} }
// Alias shortcut for sequence_merge<Sx, Sy>::type (merges the two Sequence types;
// see the sequence_merge metafunction defined earlier in this header).
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;
// Alias shortcut for uniform_sequence_gen<NSize, I>::type — per the generator's
// name, a Sequence of length NSize with every element equal to I (TODO confirm
// against the uniform_sequence_gen definition above).
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck } // namespace ck
#endif
#ifndef CK_TUPLE_HELPER_HPP #pragma once
#define CK_TUPLE_HELPER_HPP
#include "functional4.hpp" #include "functional4.hpp"
#include "tuple.hpp" #include "tuple.hpp"
...@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>) ...@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{}); typename arithmetic_sequence_gen<0, N, 1>::type{});
} }
// Concatenate two tuples of lvalue references into a single tuple whose
// elements are still references (not decayed rvalue copies).
// NOTE(review): this relies on decltype(zs) inside the generic lambda
// preserving reference-ness for each forwarded element, so the resulting
// Tuple<decltype(zs)...> stores references into tx/ty rather than copies —
// callers must keep the referenced objects alive while the result is used.
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
                                                             const Tuple<Y&...>& ty)
{
    // unpack2 expands the elements of both tuples into the lambda's parameter
    // pack; perfect-forward each element so its value category is retained.
    return unpack2(
        [&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
namespace detail { namespace detail {
template <typename F, typename X, index_t... Is> template <typename F, typename X, index_t... Is>
...@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, ...@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
} }
} // namespace ck } // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment