Commit f9b92b1e authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent ff4f8ba8
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
......@@ -34,6 +35,10 @@ template <typename ThreadGroup,
bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v7
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
......@@ -106,8 +111,7 @@ struct ThreadGroupTensorSliceTransfer_v7
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(
tie(src0_desc, src1_desc, src2_desc),
threadwise_transfer_.Run(tie(src0_desc, src1_desc, src2_desc),
tie(src0_buf, src1_buf, src2_buf),
tie(dst_desc),
tie(dst_buf));
......@@ -119,7 +123,8 @@ struct ThreadGroupTensorSliceTransfer_v7
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
threadwise_transfer_.MoveSrcSliceWindow(
tie(src0_desc, Src1Desc{}, Src2Desc{}), step, I0);
}
}
......@@ -128,7 +133,8 @@ struct ThreadGroupTensorSliceTransfer_v7
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
threadwise_transfer_.MoveSrcSliceWindow(
tie(Src0Desc{}, src1_desc, Src2Desc{}), step, I1);
}
}
......@@ -137,7 +143,8 @@ struct ThreadGroupTensorSliceTransfer_v7
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
threadwise_transfer_.MoveSrcSliceWindow(
tie(Src0Desc{}, Src1Desc{}, src2_desc), step, I2);
}
}
......@@ -146,7 +153,7 @@ struct ThreadGroupTensorSliceTransfer_v7
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
threadwise_transfer_.MoveDstSliceWindow(tie(dst_desc), step, I0);
}
}
......@@ -154,8 +161,7 @@ struct ThreadGroupTensorSliceTransfer_v7
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v7<
using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7<
Tuple<remove_cvref_t<Src0Data>, remove_cvref_t<Src1Data>, remove_cvref_t<Src2Data>>,
Tuple<remove_cvref_t<DstData>>,
Tuple<remove_reference_t<Src0Desc>&,
......
......@@ -7,7 +7,6 @@
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
#include "thread_group_tensor_slice_transfer_v7.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
......@@ -124,7 +123,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<decltype(DsDataType{}.At(i))>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
......@@ -543,8 +542,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
#if 1
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r3<
#if 0
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, // ThreadGroup
CDEElementwiseOperation, // ElementwiseOperation,
EGlobalMemoryDataOperation, // DstInMemOp,
......@@ -590,8 +589,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data,
remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
remove_cvref_t<tuple_element_t<0, DsDataType>>, // typename Src1Data,
remove_cvref_t<tuple_element_t<1, DsDataType>>, // typename Src2Data,
FloatE, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
......
......@@ -109,16 +109,18 @@ struct ThreadwiseTensorSliceTransfer_v7
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
const DstBuffers& dst_bufs)
DstBuffers dst_bufs)
{
auto generate_vectors = [&](auto data_types) {
constexpr index_t num = data_types.Size();
return generate_tuple([&](auto i) {
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
}, Number<num>{});
},
Number<num>{});
};
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
......@@ -130,7 +132,7 @@ struct ThreadwiseTensorSliceTransfer_v7
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = remove_cvref_t<typename decltype(src_vectors[i])::type>;
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
......@@ -149,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v7
using DstData0 = remove_cvref_t<decltype(DstDatas{}[I0])>;
element_op_(dst_vectors[I0].template AsType<DstData0>()(i),
element_op_(dst_vectors(I0).template AsType<DstData0>()(i),
src_vectors[I0].template AsType<SrcData0>()[i],
src_vectors[I1].template AsType<SrcData1>()[i],
src_vectors[I2].template AsType<SrcData2>()[i]);
......@@ -157,7 +159,7 @@ struct ThreadwiseTensorSliceTransfer_v7
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cv_t<decltype(dst_vectors[i])>::type;
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
......@@ -230,39 +232,53 @@ struct ThreadwiseTensorSliceTransfer_v7
}
// Move the slice window of a single source (selected by compile-time index iSrc).
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
                                   const Index& src_slice_origin_step_idx,
                                   Number<ISrc> iSrc)
{
    // if src coord was not reset by RunRead(), then need to adjust the step here
    const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
                                       ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetCoordinateResetStep();

    // is it OK to construct a new step every time?
    const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);

    move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
}
// Move the slice window of a single destination (selected by compile-time index iDst).
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t IDst>
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
                                   const Index& dst_slice_origin_step_idx,
                                   Number<IDst> iDst)
{
    // if dst coord was not reset by Run(), then need to adjust the step here
    const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
                                       ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetCoordinateResetStep();

    // is it OK to construct a new step every time?
    const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);

    move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}
// Apply the same compile-time step to every source slice window in turn.
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveAllSrcSliceWindow(const SrcDescs& src_descs,
                                      const Index& src_slice_origin_step_idx)
{
    static_for<0, nSrc, 1>{}([&](auto src_idx) {
        MoveSrcSliceWindow(src_descs, src_slice_origin_step_idx, src_idx);
    });
}
// Apply the same compile-time step to every destination slice window in turn.
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveAllDstSliceWindow(const DstDescs& dst_descs,
                                      const Index& dst_slice_origin_step_idx)
{
    static_for<0, nDst, 1>{}([&](auto dst_idx) {
        MoveDstSliceWindow(dst_descs, dst_slice_origin_step_idx, dst_idx);
    });
}
private:
......
#pragma once
#include "statically_indexed_array.hpp"
namespace ck {
......
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#pragma once
#include "integral_constant.hpp"
#include "sequence.hpp"
......@@ -25,9 +24,9 @@ struct TupleElementKeyData
__host__ __device__ constexpr TupleElementKeyData() : mData{} {}
#endif
// Perfect-forwarding constructor. SFINAE-disabled when T is TupleElementKeyData
// itself so this overload does not hijack the copy/move constructors.
template <typename T,
          typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
                             bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
{
}
......@@ -36,7 +35,8 @@ struct TupleElementKeyData
};
// Key-based accessor: given a (const) TupleElementKeyData base subobject,
// return a const reference to the stored data member.
template <typename Key, typename Data>
__host__ __device__ constexpr const Data&
get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
    return static_cast<const Data&>(x.mData);
}
......@@ -179,13 +179,13 @@ struct Tuple<>
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
// std::tuple_element-style metafunction for ck::Tuple: yields the type of the
// I-th element, derived from the Tuple's own At() accessor.
template <index_t I, typename TTuple>
struct tuple_element
{
    using type = decltype(TTuple{}.At(Number<I>{}));
};
// Convenience alias for typename tuple_element<I, TTuple>::type.
template <index_t I, typename TTuple>
using tuple_element_t = typename tuple_element<I, TTuple>::type;
template <typename... Xs>
......@@ -202,4 +202,3 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
}
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment