Commit 1ca98e75 authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent 9a99c841
...@@ -399,27 +399,27 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -399,27 +399,27 @@ struct ThreadwiseTensorSliceTransfer_v2
// 1. DstDesc is known at compile-time // 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer // 2. DstBuffer is StaticBuffer
// 3. dst_slice_origin_idx is known at compile-time // 3. dst_slice_origin_idx is known at compile-time
template <typename SrcData, template <typename SrcDatas,
typename DstData, typename DstDatas,
typename SrcDesc, typename SrcDescs,
typename DstDesc, typename DstDescs,
typename SliceLengths, typename SliceLengths,
typename DimAccessOrder, typename DimAccessOrder,
index_t SrcVectorDim, index_t SrcVectorDim,
index_t SrcScalarPerVector, index_t SrcScalarPerVectors,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVectors,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
bool InvalidElementAsNaN = false, bool InvalidElementAsNaN = false,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2r1 struct ThreadwiseTensorSliceTransfer_v2r1
{ {
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) || static_assert((InvalidElementAsNaN && !std::is_integral<DstDatas>::value) ||
(!InvalidElementAsNaN), (!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types"); "Filling invalid element as NaN is only for floating point types");
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
static constexpr index_t nSrc = SrcDescs::Size(); static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nSrc = SrcDescs::Size(); static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
...@@ -437,37 +437,36 @@ struct ThreadwiseTensorSliceTransfer_v2r1 ...@@ -437,37 +437,36 @@ struct ThreadwiseTensorSliceTransfer_v2r1
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDescs& src_descs,
const Index& src_slice_origin_idx) const Indexs& src_slice_origin_idxs)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
{ {
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0, static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible"); "wrong! Not divisible");
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) src_coords_(generate_tuple([&](auto i) { return make_tensor_coordinate(src_desc[i], src_slice_origin_idx[i]); },
{ nSrc);)
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
} }
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx> template <typename SrcBuffers, typename DstBuffers, typename DstSliceOriginIdxs>
__device__ void Run(const SrcDesc& src_desc, __device__ void Run(const SrcDescs& src_descs,
const SrcBuffer& src_buf, const SrcBuffers& src_bufs,
const DstDesc&, const DstDescs&,
const DstSliceOriginIdx&, const DstSliceOriginIdxs&,
DstBuffer& dst_buf) DstBuffers& dst_bufs)
{ {
static_assert(DstDesc::IsKnownAtCompileTime(), static_for<0, nDst, 1>{}([&](auto i) {
static_assert(remove_cvref_t<tuple_element_t<i.value, DstDescs>>::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time"); "wrong! DstDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value, static_assert(is_known_at_compile_time<remove_cvref_t<tuple_element_t<i.value, DstSliceOriginIdxs>>>::value,
"wrong! DstSliceOrigin need to known at compile-time"); "wrong! DstSliceOrigin need to known at compile-time");
static_assert( static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value && is_same<remove_cvref_t<typename tuple_element_t<i.value, DstBuffer>::type>, remove_cvref_t<tuple_element_t<i.value, DstDatas>>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
});
// DstDesc and dst_slice_origin_idx are known at compile-time // DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cvref_t<DstDesc>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment