Commit 1ca98e75 authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent 9a99c841
......@@ -399,27 +399,27 @@ struct ThreadwiseTensorSliceTransfer_v2
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. dst_slice_origin_idx is known at compile-time
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
template <typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector,
index_t SrcScalarPerVectors,
index_t SrcScalarStrideInVectors,
bool SrcResetCoordinateAfterRun,
bool InvalidElementAsNaN = false,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2r1
{
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
static_assert((InvalidElementAsNaN && !std::is_integral<DstDatas>::value) ||
(!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types");
static constexpr index_t nDim = SliceLengths::Size();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>;
......@@ -437,37 +437,36 @@ struct ThreadwiseTensorSliceTransfer_v2r1
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDescs& src_descs,
const Indexs& src_slice_origin_idxs)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
src_coords_(generate_tuple([&](auto i) { return make_tensor_coordinate(src_desc[i], src_slice_origin_idx[i]); },
nSrc);)
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
template <typename SrcBuffers, typename DstBuffers, typename DstSliceOriginIdxs>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs&,
const DstSliceOriginIdxs&,
DstBuffers& dst_bufs)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
static_for<0, nDst, 1>{}([&](auto i) {
static_assert(remove_cvref_t<tuple_element_t<i.value, DstDescs>>::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
static_assert(is_known_at_compile_time<remove_cvref_t<tuple_element_t<i.value, DstSliceOriginIdxs>>>::value,
"wrong! DstSliceOrigin need to known at compile-time");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
"wrong! inconsistent type");
static_assert(
is_same<remove_cvref_t<typename tuple_element_t<i.value, DstBuffer>::type>, remove_cvref_t<tuple_element_t<i.value, DstDatas>>>::value &&
"wrong! inconsistent type");
});
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment