Commit f9b92b1e authored by Chao Liu

refactor

parent ff4f8ba8
 #pragma once
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"

@@ -34,6 +35,10 @@ template <typename ThreadGroup,
           bool ThreadTransferDstResetCoordinateAfterRun>
 struct ThreadGroupTensorSliceTransfer_v7
 {
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+
     static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

     static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
@@ -106,8 +111,7 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.Run(
-                tie(src0_desc, src1_desc, src2_desc),
-                tie(src0_buf, src1_buf, src2_buf),
-                tie(dst_desc),
-                tie(dst_buf));
+            threadwise_transfer_.Run(tie(src0_desc, src1_desc, src2_desc),
+                                     tie(src0_buf, src1_buf, src2_buf),
+                                     tie(dst_desc),
+                                     tie(dst_buf));
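Note: `Run` now receives the source descriptors/buffers and the destination descriptor/buffer bundled as tuples of lvalue references built by `tie` (ck's `tie` signature appears in the `tuple.hpp` hunk at the end of this commit). A standalone, std-only sketch of the tuple-of-references pattern; the names here are illustrative, not the ck API:

```cpp
#include <cassert>
#include <tuple>

// A callee that receives a tuple of int& can mutate the caller's objects
// without copies -- the same idea as passing tie(src0_desc, ...) to Run.
void bump_all(std::tuple<int&, int&, int&> xs)
{
    std::get<0>(xs) += 1;
    std::get<1>(xs) += 1;
    std::get<2>(xs) += 1;
}

int main()
{
    int a = 0, b = 1, c = 2;
    bump_all(std::tie(a, b, c)); // std::tie builds the tuple of references
    assert(a == 1 && b == 2 && c == 3);
    return 0;
}
```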
@@ -119,7 +123,8 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
+            threadwise_transfer_.MoveSrcSliceWindow(
+                tie(src0_desc, Src1Desc{}, Src2Desc{}), step, I0);
         }
     }
@@ -128,7 +133,8 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
+            threadwise_transfer_.MoveSrcSliceWindow(
+                tie(Src0Desc{}, src1_desc, Src2Desc{}), step, I1);
         }
     }
@@ -137,7 +143,8 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
+            threadwise_transfer_.MoveSrcSliceWindow(
+                tie(Src0Desc{}, Src1Desc{}, src2_desc), step, I2);
         }
     }
@@ -146,7 +153,7 @@ struct ThreadGroupTensorSliceTransfer_v7
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
+            threadwise_transfer_.MoveDstSliceWindow(tie(dst_desc), step, I0);
         }
     }
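Note: the per-source entry points (`MoveSrc0SliceWindow`, `MoveSrc1SliceWindow`, `MoveSrc2SliceWindow`) collapse into one `MoveSrcSliceWindow` that takes all descriptors plus a compile-time index (`I0`/`I1`/`I2`), with the unused slots passed as default-constructed placeholder descriptors. A minimal std-only sketch of compile-time index selection (`Number` below is a stand-in for ck's `Number`, not the ck type):

```cpp
#include <cassert>
#include <tuple>
#include <type_traits>

template <int I>
using Number = std::integral_constant<int, I>; // stand-in for ck::Number

// Only the window at compile-time slot I is moved; the other slots are
// untouched, mirroring MoveSrcSliceWindow(tie(...), step, I1).
template <typename Windows, int I>
void move_one_window(Windows& windows, int step, Number<I>)
{
    std::get<I>(windows) += step;
}

int main()
{
    std::tuple<int, int, int> windows{0, 0, 0};
    move_one_window(windows, 4, Number<1>{});
    assert(std::get<0>(windows) == 0 && std::get<1>(windows) == 4);
    return 0;
}
```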
@@ -154,8 +161,7 @@ struct ThreadGroupTensorSliceTransfer_v7
     static constexpr auto thread_cluster_desc_ =
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

-    using ThreadwiseTransfer =
-        ThreadwiseTensorSliceTransfer_v7<
+    using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7<
         Tuple<remove_cvref_t<Src0Data>, remove_cvref_t<Src1Data>, remove_cvref_t<Src2Data>>,
         Tuple<remove_cvref_t<DstData>>,
         Tuple<remove_reference_t<Src0Desc>&,
...
@@ -7,7 +7,6 @@
 #include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
 #include "thread_group_tensor_slice_transfer_v4r1.hpp"
-#include "thread_group_tensor_slice_transfer_v6r3.hpp"
 #include "thread_group_tensor_slice_transfer_v7.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
@@ -124,7 +123,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         return generate_tuple(
             [&](auto i) {
-                using DDataType = remove_cvref_t<decltype(DsDataType{}.At(i))>;
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                 return static_cast<const DDataType*>(nullptr);
             },
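Note: `DDataType` is now computed with the type-level `tuple_element_t` trait (added in the `tuple.hpp` hunk below) instead of `decltype` over a `DsDataType{}` expression; `i.value` extracts the compile-time index from the `Number` parameter. A std-only illustration of what such a trait yields:

```cpp
#include <tuple>
#include <type_traits>

// std::tuple_element_t reads the I-th element type straight off the tuple
// type, with no tuple expression involved -- the same shape as
// tuple_element_t<i.value, DsDataType> above.
using Ds = std::tuple<float, double>;
static_assert(std::is_same_v<std::tuple_element_t<0, Ds>, float>);
static_assert(std::is_same_v<std::tuple_element_t<1, Ds>, double>);

int main() { return 0; }
```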
@@ -543,8 +542,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
                                   ck::tensor_operation::element_wise::PassThrough{}};

         // shuffle: blockwise copy C from LDS to global
-#if 1
-        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r3<
+#if 0
+        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
             ThisThreadBlock,            // ThreadGroup
             CDEElementwiseOperation,    // ElementwiseOperation,
             EGlobalMemoryDataOperation, // DstInMemOp,

@@ -590,8 +589,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
             CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
             FloatCShuffle,        // typename Src0Data,
-            remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
-            remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
+            remove_cvref_t<tuple_element_t<0, DsDataType>>, // typename Src1Data,
+            remove_cvref_t<tuple_element_t<1, DsDataType>>, // typename Src2Data,
             FloatE,               // typename DstData,
             decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
             decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
...
@@ -109,16 +109,18 @@ struct ThreadwiseTensorSliceTransfer_v7
     __device__ void Run(const SrcDescs& src_descs,
                         const SrcBuffers& src_bufs,
                         const DstDescs& dst_descs,
-                        const DstBuffers& dst_bufs)
+                        DstBuffers dst_bufs)
     {
         auto generate_vectors = [&](auto data_types) {
             constexpr index_t num = data_types.Size();

-            return generate_tuple([&](auto i) {
-                using DataType = remove_cvref_t<decltype(data_types[i])>;
-                return vector_type_maker_t<DataType, ScalarPerVector>{};
-            }, Number<num>{});
+            return generate_tuple(
+                [&](auto i) {
+                    using DataType = remove_cvref_t<decltype(data_types[i])>;
+                    return vector_type_maker_t<DataType, ScalarPerVector>{};
+                },
+                Number<num>{});
         };

         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
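Note: `generate_vectors` builds one vector-typed buffer per data type via `generate_tuple`. A standalone sketch of the assumed `generate_tuple` semantics (apply `f` to compile-time indices `0..N-1` and collect the results), in std-only C++17:

```cpp
#include <tuple>
#include <type_traits>
#include <utility>

// Apply f to integral_constant<size_t, 0..N-1> and pack the results.
template <typename F, std::size_t... Is>
constexpr auto generate_tuple_impl(F&& f, std::index_sequence<Is...>)
{
    return std::make_tuple(f(std::integral_constant<std::size_t, Is>{})...);
}

template <std::size_t N, typename F>
constexpr auto generate_tuple(F&& f)
{
    return generate_tuple_impl(std::forward<F>(f), std::make_index_sequence<N>{});
}

// Each element's value (or type) may depend on the compile-time index:
static_assert(std::get<2>(generate_tuple<3>([](auto i) { return i() * 2; })) == 4);

int main() { return 0; }
```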
@@ -130,7 +132,7 @@ struct ThreadwiseTensorSliceTransfer_v7
             // copy data from src_bufs into src_vectors
             static_for<0, nSrc, 1>{}([&](auto i) {
-                using src_vector_t = remove_cvref_t<typename decltype(src_vectors[i])::type>;
+                using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;

                 const bool is_src_valid =
                     coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
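Note: this fix swaps the nesting of `remove_cvref_t`: indexing `src_vectors` evidently yields a reference, and `decltype` of such an expression names a reference type, which has no nested `::type`; the cv-ref qualifiers must be stripped before the member type is accessed. A compilable illustration (C++20 for `std::remove_cvref_t`; `VecWrap`/`get` are hypothetical stand-ins):

```cpp
#include <type_traits>

struct VecWrap
{
    using type = float;
};

const VecWrap& get(); // declared only; used in unevaluated decltype below

// typename decltype(get())::type would be ill-formed: 'const VecWrap&' is a
// reference type and has no members. Strip cv-ref first, then access ::type.
using good = typename std::remove_cvref_t<decltype(get())>::type;
static_assert(std::is_same_v<good, float>);

int main() { return 0; }
```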
@@ -149,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v7
                     using DstData0 = remove_cvref_t<decltype(DstDatas{}[I0])>;

-                    element_op_(dst_vectors[I0].template AsType<DstData0>()(i),
+                    element_op_(dst_vectors(I0).template AsType<DstData0>()(i),
                                 src_vectors[I0].template AsType<SrcData0>()[i],
                                 src_vectors[I1].template AsType<SrcData1>()[i],
                                 src_vectors[I2].template AsType<SrcData2>()[i]);
@@ -157,7 +159,7 @@ struct ThreadwiseTensorSliceTransfer_v7
             // copy data from buf_vectors into dst_bufs
             static_for<0, nDst, 1>{}([&](auto i) {
-                using dst_vector_t = typename remove_cv_t<decltype(dst_vectors[i])>::type;
+                using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;

                 const bool is_dst_valid =
                     coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
@@ -230,39 +232,53 @@ struct ThreadwiseTensorSliceTransfer_v7
     }

     // src_slice_origin_step_idx need to be known at compile-time, for performance reason
+    template <index_t ISrc>
     __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
-                                       const Index& src_slice_origin_step_idx)
+                                       const Index& src_slice_origin_step_idx,
+                                       Number<ISrc> iSrc)
     {
-        static_for<0, nSrc, 1>{}([&](auto i) {
-            // if src coord was not reset by RunRead(), then need to adjust the step here
-            const auto adjusted_step_idx =
-                SrcResetCoordinateAfterRunFlags::At(i)
-                    ? src_slice_origin_step_idx
-                    : src_slice_origin_step_idx + GetCoordinateResetStep();
-
-            // is it OK to construct a new step every time?
-            const auto adjusted_step = make_tensor_coordinate_step(src_descs[i], adjusted_step_idx);
-
-            move_tensor_coordinate(src_descs[i], src_coords_(i), adjusted_step);
-        });
+        // if src coord was not reset by RunRead(), then need to adjust the step here
+        const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
+                                           ? src_slice_origin_step_idx
+                                           : src_slice_origin_step_idx + GetCoordinateResetStep();
+
+        // is it OK to construct a new step every time?
+        const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
+
+        move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
     }

     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
+    template <index_t IDst>
     __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
-                                       const Index& dst_slice_origin_step_idx)
+                                       const Index& dst_slice_origin_step_idx,
+                                       Number<IDst> iDst)
     {
-        static_for<0, nDst, 1>{}([&](auto i) {
-            // if dst coord was not reset by Run(), then need to adjust the step here
-            const auto adjusted_step_idx =
-                DstResetCoordinateAfterRunFlags::At(i)
-                    ? dst_slice_origin_step_idx
-                    : dst_slice_origin_step_idx + GetCoordinateResetStep();
-
-            // is it OK to construct a new step every time?
-            const auto adjusted_step = make_tensor_coordinate_step(dst_descs[i], adjusted_step_idx);
-
-            move_tensor_coordinate(dst_descs[i], dst_coords_(i), adjusted_step);
-        });
+        // if dst coord was not reset by Run(), then need to adjust the step here
+        const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
+                                           ? dst_slice_origin_step_idx
+                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();
+
+        // is it OK to construct a new step every time?
+        const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
+
+        move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
+    }
+
+    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
+    __device__ void MoveAllSrcSliceWindow(const SrcDescs& src_descs,
+                                          const Index& src_slice_origin_step_idx)
+    {
+        static_for<0, nSrc, 1>{}(
+            [&](auto i) { MoveSrcSliceWindow(src_descs, src_slice_origin_step_idx, i); });
+    }
+
+    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
+    __device__ void MoveAllDstSliceWindow(const DstDescs& dst_descs,
+                                          const Index& dst_slice_origin_step_idx)
+    {
+        static_for<0, nDst, 1>{}(
+            [&](auto i) { MoveDstSliceWindow(dst_descs, dst_slice_origin_step_idx, i); });
     }

     private:
...
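Aside on the `ThreadwiseTensorSliceTransfer_v7` hunk above: `MoveSrcSliceWindow`/`MoveDstSliceWindow` now move a single window selected by a compile-time index, and the new `MoveAllSrcSliceWindow`/`MoveAllDstSliceWindow` recover the old move-everything behavior by looping with `static_for`. A std-only sketch of that loop's assumed semantics (invoke `f` with `Number<0> .. Number<N-1>`):

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

template <int I>
using Number = std::integral_constant<int, I>; // stand-in for ck::Number

// Fold over the compile-time indices, calling f(Number<Is>{}) for each.
template <typename F, int... Is>
constexpr void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(Number<Is>{}), ...);
}

template <int N, typename F>
constexpr void static_for(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

int main()
{
    int sum = 0;
    static_for<3>([&](auto i) { sum += i; }); // 0 + 1 + 2
    assert(sum == 3);
    return 0;
}
```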
 #pragma once
 #include "statically_indexed_array.hpp"
 namespace ck {
...
-#ifndef CK_TUPLE_HPP
-#define CK_TUPLE_HPP
+#pragma once

 #include "integral_constant.hpp"
 #include "sequence.hpp"
@@ -25,9 +24,9 @@ struct TupleElementKeyData
     __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
 #endif

-    template <
-        typename T,
-        typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value, bool>::type = false>
+    template <typename T,
+              typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
+                                 bool>::type = false>
     __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
     {
     }
@@ -36,7 +35,8 @@ struct TupleElementKeyData
 };

 template <typename Key, typename Data>
-__host__ __device__ constexpr const Data& get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
+__host__ __device__ constexpr const Data&
+get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
 {
     return static_cast<const Data&>(x.mData);
 }
@@ -179,13 +179,13 @@ struct Tuple<>
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
 };

-template<index_t I, typename TTuple>
+template <index_t I, typename TTuple>
 struct tuple_element
 {
     using type = decltype(TTuple{}.At(Number<I>{}));
 };

-template<index_t I, typename TTuple>
+template <index_t I, typename TTuple>
 using tuple_element_t = typename tuple_element<I, TTuple>::type;

 template <typename... Xs>
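Note: `tuple_element` above derives the element type from the expression `TTuple{}.At(Number<I>{})` in an unevaluated `decltype`. A compilable miniature of the same mechanism, with minimal stand-ins for ck's `Number` and a two-element tuple (`MiniTuple` and this `Number` are simplifications, not the ck types):

```cpp
#include <type_traits>

template <int I>
using Number = std::integral_constant<int, I>;

template <typename T0, typename T1>
struct MiniTuple
{
    T0 t0;
    T1 t1;
    constexpr T0 At(Number<0>) const { return t0; }
    constexpr T1 At(Number<1>) const { return t1; }
};

template <int I, typename TTuple>
struct tuple_element
{
    // unevaluated context: names the return type of At(Number<I>{})
    using type = decltype(TTuple{}.At(Number<I>{}));
};

template <int I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;

static_assert(std::is_same_v<tuple_element_t<1, MiniTuple<int, float>>, float>);

int main() { return 0; }
```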
@@ -202,4 +202,3 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
 }

 } // namespace ck
-#endif