Commit 2488d0bf authored by Chao Liu

refactor

parent 97ec23bf
......@@ -541,21 +541,33 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
// FIXME: arbitrary # of D tensors
const auto c_ds_descs = tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]);
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, // ThreadGroup
ThisThreadBlock,
Tuple<FloatCShuffle,
remove_cvref_t<tuple_element_t<0, DsDataType>>,
remove_cvref_t<tuple_element_t<1, DsDataType>>>,
Tuple<FloatE>, // typename DstData,
decltype(c_ds_descs),
Tuple<FloatE>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation,
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary type
Sequence<1,
......@@ -566,13 +578,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
Sequence<true, false, false>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_descs,
make_tuple(make_multi_index(0, 0, 0, 0),
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
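The change above replaces the hand-written tie of the C shuffle LDS descriptor plus exactly two D descriptors with concat_tuple_of_reference/generate_tie and generate_tuple, so the blockwise copy works for any NumDTensor. Below is a minimal standalone sketch of the same idea using std::tuple; BlockDesc, GridDesc and tie_descs are hypothetical stand-ins for illustration, not the CK types.

```cpp
#include <cstdio>
#include <tuple>
#include <utility>

// Hypothetical stand-ins for the C-shuffle LDS descriptor and the D-tensor
// grid descriptors; the real CK descriptors are far richer types.
struct BlockDesc { int id; };
struct GridDesc  { int id; };

// Build a tuple of lvalue references to each element of an array of
// descriptors, analogous to generate_tie over ds_grid_desc_...[i].
template <typename Desc, std::size_t N, std::size_t... Is>
constexpr auto tie_descs_impl(Desc (&descs)[N], std::index_sequence<Is...>)
{
    return std::tie(descs[Is]...);
}

template <typename Desc, std::size_t N>
constexpr auto tie_descs(Desc (&descs)[N])
{
    return tie_descs_impl(descs, std::make_index_sequence<N>{});
}

int main()
{
    BlockDesc c_shuffle_desc{0};
    GridDesc  ds_descs[3] = {{1}, {2}, {3}}; // arbitrary number of D tensors

    // Concatenate (C shuffle LDS descriptor, D0, D1, D2, ...) as references,
    // analogous to concat_tuple_of_reference(tie(...), generate_tie(...)).
    auto c_ds_desc_refs = std::tuple_cat(std::tie(c_shuffle_desc), tie_descs(ds_descs));

    static_assert(std::tuple_size_v<decltype(c_ds_desc_refs)> == 4, "1 + NumDTensor");
    std::printf("first D id = %d\n", std::get<1>(c_ds_desc_refs).id);
    return 0;
}
```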
......@@ -619,7 +632,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_descs,
c_ds_desc_refs,
tie(c_shuffle_block_buf, ds_grid_buf[I0], ds_grid_buf[I1]),
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
......@@ -630,9 +643,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, DsDataType::Size(), 1>{}([&](auto i) {
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_descs, i + I1, cde_lds_and_global_step);
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
......
......@@ -7,6 +7,10 @@
namespace ck {
// Assume:
// 1. src_descs and dst_descs are not known at compile-time
// 2. SrcBuffers and DstBuffers are DynamicBuffer
// 3. src_slice_origins and dst_slice_origins are not known at compile-time,
// Do the following to avoid "alloca" in LLVM-IR, which would cause scratch memory usage
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
......@@ -14,11 +18,6 @@ namespace ck {
// 2. Don't construct a new tensor coordinate every time it is used; update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
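A minimal standalone sketch of the rule the comment block above spells out, assuming a toy TensorDesc type: the transfer object keeps only small per-thread state that is updated in place, and takes the descriptor as a call argument rather than storing a reference to it, so nothing has to be spilled to scratch memory.

```cpp
#include <cstdio>

// Hypothetical descriptor type; the real CK tensor descriptors are class
// templates whose contents are generally only known at run time.
struct TensorDesc { int stride; };

// The pattern the comments describe: keep only small register-resident state
// (here a single coordinate) and pass the descriptor into every call instead
// of holding a reference to it in the class.
struct ThreadwiseTransferSketch
{
    int coord = 0; // reused and updated, never reconstructed per call

    void Run(const TensorDesc& desc)
    {
        // consume desc here, then advance the persistent coordinate
        coord += desc.stride;
    }
};

int main()
{
    TensorDesc desc{4};
    ThreadwiseTransferSketch xfer;
    xfer.Run(desc);
    xfer.Run(desc);
    std::printf("coord = %d\n", xfer.coord); // prints "coord = 8"
    return 0;
}
```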
template <typename SrcDatas,
typename DstDatas,
typename SrcDescs,
......@@ -34,8 +33,6 @@ template <typename SrcDatas,
struct ThreadwiseTensorSliceTransfer_v7
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t nDim = SliceLengths::Size();
......
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP
#pragma once
#include "integral_constant.hpp"
#include "type.hpp"
......@@ -882,5 +881,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return flag;
}
template <typename Sx, typename Sy>
using sequence_merge_t = typename sequence_merge<Sx, Sy>::type;
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck
#endif
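The two new aliases are what the gridwise kernel above composes into its per-source reset flags. A standalone sketch, using std::integer_sequence analogues rather than the ck:: templates, of what sequence_merge_t<Sequence<true>, uniform_sequence_gen_t<NumDTensor, false>> expands to:

```cpp
#include <type_traits>
#include <utility>

// Standalone analogues of CK's Sequence, sequence_merge and
// uniform_sequence_gen; the names below are illustrative only.
template <bool... Bs>
using BoolSeq = std::integer_sequence<bool, Bs...>;

template <typename Sx, typename Sy>
struct seq_merge;

template <bool... Xs, bool... Ys>
struct seq_merge<BoolSeq<Xs...>, BoolSeq<Ys...>>
{
    using type = BoolSeq<Xs..., Ys...>;
};

template <std::size_t N, bool V, bool... Acc>
struct uniform_seq : uniform_seq<N - 1, V, V, Acc...>
{
};

template <bool V, bool... Acc>
struct uniform_seq<0, V, Acc...>
{
    using type = BoolSeq<Acc...>;
};

// With NumDTensor == 2 this reproduces the flags used above for
// ThreadTransferSrcResetCoordinateAfterRunFlags: <true, false, false>
// (reset the coordinate of the LDS source after each run, not the D sources).
constexpr std::size_t NumDTensor = 2;
using SrcResetFlags = typename seq_merge<BoolSeq<true>,
                                         typename uniform_seq<NumDTensor, false>::type>::type;

static_assert(std::is_same_v<SrcResetFlags, BoolSeq<true, false, false>>, "");

int main() { return 0; }
```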
#ifndef CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#pragma once
#include "functional4.hpp"
#include "tuple.hpp"
......@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
const Tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
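A small usage check of the property this helper exists for, written against std::tuple rather than CK's Tuple: the concatenation must keep the elements as references (hence the explicit Tuple<decltype(zs)...> in the lambda), because a make_tuple-style construction would decay them to values and copy the tensor descriptors.

```cpp
#include <tuple>
#include <type_traits>

int main()
{
    int a = 1, b = 2;
    auto tx = std::tie(a);
    auto ty = std::tie(b);

    // Spelling the element types out (as the helper does with
    // Tuple<decltype(zs)...>) preserves the references...
    auto refs = std::tuple_cat(tx, ty);
    static_assert(std::is_same_v<decltype(refs), std::tuple<int&, int&>>, "");

    // ...whereas a make_tuple-style construction decays them to values,
    // i.e. it would copy the descriptors instead of referring to them.
    auto vals = std::apply([](auto&&... zs) { return std::make_tuple(zs...); },
                           std::tuple_cat(tx, ty));
    static_assert(std::is_same_v<decltype(vals), std::tuple<int, int>>, "");

    return 0;
}
```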
namespace detail {
template <typename F, typename X, index_t... Is>
......@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
}
} // namespace ck
#endif