"src/include/blockwise_gemm.cuh" did not exist on "dc60d16962771f360178c30285683d8fa2ea38c1"
Commit 37b82b7e authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 1f2cfceb
@@ -84,6 +84,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
         constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
         static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread");
         constexpr index_t N0 = N / (N1 * N2);
@@ -92,6 +98,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr index_t E = C * Y * X;

+        // sanity-check for vectorized memory load
+        static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
+                      "wrong! global vector load of input tensor is wrong");
+
+        static_assert(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0,
+                      "wrong! alignment requirement for vectorized global load of input tensor "
+                      "will be violated");
+
         // divide block work by [K, B]
         static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
                       "wrong! cannot divide work evenly among block");
@@ -111,15 +125,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         // input tensor
         // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
         constexpr auto in_n0_n1_n2_h_w_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
-                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
+                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
                 .Fold(I0, Number<N1>{}, Number<N2>{})
                 .Extract(Sequence<0, 1, 2, 4, 5>{});

         // batch descriptor for device memory
         constexpr auto in_c_y_x_global_desc =
-            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
-                .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                 .Extract(Sequence<1, 2, 3>{});

         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
......
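A note on the two new sanity checks: InBlockCopySrcDataPerRead_B > 1 asks a thread to fetch that many consecutive output-column positions of the input with one vector load, which is only contiguous in memory when ConvStrideW == 1; and when X > 1 the sliding window shifts by ConvDilationW, so that shift must preserve the vector-load alignment. The new `ConvStrides{}[0]` spelling works because Sequence gains an operator[] taking a plain index_t (see the Sequence diff below). A minimal host-side sketch of why that initializer is still a constant expression, using simplified stand-ins for ck::Sequence and ck::Number (the real types carry __host__ __device__ qualifiers and many more members):

    #include <cstdio>

    using index_t = int;

    template <index_t N>
    struct Number
    {
        static constexpr index_t value = N;
    };

    template <index_t... Is>
    struct Sequence
    {
        static constexpr index_t mSize = sizeof...(Is);

        // runtime-style access; still constexpr-evaluable when the index is constexpr
        constexpr index_t operator[](index_t i) const
        {
            const index_t data[mSize + 1] = {Is..., 0}; // dummy tail for empty packs
            return data[i];
        }
    };

    int main()
    {
        using ConvStrides = Sequence<2, 3>;

        // both initializers below are evaluated at compile time
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        printf("%d %d\n", ConvStrideH, ConvStrideW); // prints: 2 3
        return 0;
    }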
@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
         return OriginalTensorDesc{};
     }

-    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

     template <index_t IDim>
     __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
@@ -52,7 +52,7 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
+    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
     {
         constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
@@ -60,7 +60,7 @@ struct ConstantMergedTensorDescriptor
     }

     template <index_t IDim>
-    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
+    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
     {
         static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                       "wrong! stride of a merged dimension is undefined");
@@ -75,7 +75,7 @@ struct ConstantMergedTensorDescriptor
         return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
     }

-    __host__ __device__ static constexpr index_t GetElementSize()
+    __host__ __device__ static constexpr auto GetElementSize()
     {
         return OriginalTensorDesc::GetElementSize();
     }
......
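The recurring pattern in this file, turning `constexpr index_t` returns into `constexpr auto` returns of `Number<>`, moves the value into the return type. That matters inside generic lambdas: a function parameter is never itself constexpr, but a parameter of type Number<N> still carries N at compile time. A small self-contained sketch of the idea (hypothetical Desc, not the real descriptor):

    using index_t = int;

    template <index_t N>
    struct Number
    {
        static constexpr index_t value = N;
    };

    struct Desc // stand-in for a descriptor after this change
    {
        static constexpr auto GetNumOfDimension() { return Number<4>{}; }
    };

    template <class N>
    void consume(N n)
    {
        // "n" is a function parameter and therefore never constexpr itself,
        // but its *type* carries the value, so compile-time use is still fine:
        static_assert(N::value == 4, "dimension is known at compile time");
        (void)n;
    }

    int main()
    {
        consume(Desc::GetNumOfDimension());
        return 0;
    }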
@@ -43,22 +43,22 @@ struct ConstantTensorDescriptor
         return Sequence<IDim>{};
     }

-    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

     __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }

     __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

-    template <index_t I>
-    __host__ __device__ static constexpr index_t GetLength(Number<I>)
+    template <class IDim>
+    __host__ __device__ static constexpr auto GetLength(IDim)
     {
-        return Lengths::Get(Number<I>{});
+        return Lengths::Get(IDim{});
     }

-    template <index_t I>
-    __host__ __device__ static constexpr index_t GetStride(Number<I>)
+    template <class IDim>
+    __host__ __device__ static constexpr auto GetStride(IDim)
     {
-        return Strides::Get(Number<I>{});
+        return Strides::Get(IDim{});
     }

     struct lambda_AreDimensionsContinuous
@@ -102,17 +102,18 @@ struct ConstantTensorDescriptor
         return false;
     }

-    __host__ __device__ static constexpr index_t GetElementSize()
+    __host__ __device__ static constexpr auto GetElementSize()
     {
-        return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
+        return Number<accumulate_on_sequence(
+            Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
     }

-    __host__ __device__ static constexpr index_t GetElementSpace()
+    __host__ __device__ static constexpr auto GetElementSpace()
     {
         constexpr index_t element_space_unaligned = accumulate_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});

-        return element_space_unaligned;
+        return Number<element_space_unaligned>{};
     }

     // emulate constexpr lambda
@@ -156,13 +157,14 @@ struct ConstantTensorDescriptor
     }

     template <index_t... Is>
-    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
+    __host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
     {
         static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

         constexpr auto multi_id = Sequence<Is...>{};

-        return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
+        return Number<accumulate_on_sequence(
+            multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
     }

     // emulate constexpr lambda
......
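GetElementSpace computes 1 + sum_i (Lengths[i] - 1) * Strides[i], i.e. one past the offset of the farthest addressable element. A quick host-side check of the formula on made-up numbers (a packed 2x3x4 tensor; not the real descriptor code):

    #include <cstddef>

    // offset of the last element, plus one, per the formula in GetElementSpace()
    template <std::size_t N>
    constexpr int element_space(const int (&len)[N], const int (&str)[N])
    {
        int space = 1;
        for (std::size_t i = 0; i < N; ++i)
            space += (len[i] - 1) * str[i];
        return space;
    }

    constexpr int len[3] = {2, 3, 4};   // Lengths
    constexpr int str[3] = {12, 4, 1};  // Strides of a packed 2x3x4 tensor

    // 1 + 1*12 + 2*4 + 3*1 = 24, exactly the 2*3*4 elements of the packed case
    static_assert(element_space(len, str) == 24, "");

    int main() { return 0; }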
@@ -83,9 +83,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
         // divide work
         constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};

-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
-
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
                           "wrong! cannot evenly divide sliced tensor into sub-tensor");
@@ -95,9 +93,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
         // for now, only support SubLengths == 1 on a merged dimension that contains
         // multiple original dimensions
-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
-
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             static_assert(SubLengths::Get(IDim) == 1 ||
                               (!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
                                !DstDesc::ContainMultipleOriginalDimensions(IDim)),
@@ -121,8 +117,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
             dst_block_data_multi_id_begin + thread_data_multi_id_begin);

         // partial offset on each dimension
-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             constexpr index_t idim = IDim;

             constexpr auto src_partial_original_dims =
@@ -135,8 +130,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
                 extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
         });

-        static_for<0, nDim, 1>{}([&](auto IDim_) {
-            constexpr auto IDim = decltype(IDim_){};
+        static_for<0, nDim, 1>{}([&](auto IDim) {
             constexpr index_t idim = IDim;

             constexpr auto dst_partial_original_dims =
@@ -208,6 +202,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
             thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
 #endif

+        // Position the origin of the per-thread window at the point where the
+        // multi-index of SrcDesc (which might be a merged tensor) is all-zero.
+        // This threadwise slice copy assumes each thread copies a normal (not
+        // merged) tensor; the user needs to guarantee this holds. Setting
+        // SubLengths = 1 on every merged dimension makes it always true.
+        // Enabling SubLengths > 1 on a merged dimension in the future would
+        // require special care in the implementation.
         threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
                                                 p_src + src_offset + mThreadSrcOffset,
                                                 make_zero_array<index_t, nDim>(),
@@ -259,6 +260,13 @@ struct BlockwiseGenericTensorSliceCopy_v1
         const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
 #endif

+        // Position the origin of the per-thread window at the point where the
+        // multi-index of SrcDesc (which might be a merged tensor) is all-zero.
+        // This threadwise slice copy assumes each thread copies a normal (not
+        // merged) tensor; the user needs to guarantee this holds. Setting
+        // SubLengths = 1 on every merged dimension makes it always true.
+        // Enabling SubLengths > 1 on a merged dimension in the future would
+        // require special care in the implementation.
         threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
                                                 p_clipboard + clipboard_offset,
                                                 make_zero_array<index_t, nDim>(),
@@ -293,7 +301,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
         Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
     {
         constexpr auto IDim = Number<IDim_>{};
-        constexpr index_t idim = IDim;

         static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
             // logic for a merged dimension; also works for a non-merged dimension, but its logic may
@@ -316,22 +323,21 @@ struct BlockwiseGenericTensorSliceCopy_v1
                     old_src_partial_original_multi_id, StepSize, direction);

                 // update "mThreadSrcOriginalMultiId"
-                static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
-                    constexpr auto I = decltype(I_){};
-                    constexpr index_t idim_original = src_partial_original_dims.Get(I);
+                static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
+                    constexpr auto IDimOriginal = src_partial_original_dims[I];

-                    mThreadSrcOriginalMultiId(idim_original) = new_src_partial_original_multi_id[I];
+                    mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_multi_id[I];
                 });

                 // calculate new partial offset on this merged dimension
-                const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
+                const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];

                 const index_t new_src_partial_offset =
                     src_partial_original_desc.GetOffsetFromMultiIndex(
                         new_src_partial_original_multi_id);

                 // update "mThreadSrcPartialOffsets"
-                mThreadSrcPartialOffsets(idim) = new_src_partial_offset;
+                mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;

                 // update "mThreadSrcOffset", do "+" before "-" to avoid underflow
                 mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
@@ -346,20 +352,20 @@ struct BlockwiseGenericTensorSliceCopy_v1
             // of the boundary of the tensor being sliced. Otherwise, there might be hazards like
             // unsigned integer underflow. There is NO runtime sanity check to prevent the hazard.
-            constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
+            constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();

             static_if<PositiveDirection>{}([&](auto fwd) {
                 mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);

-                mThreadSrcOriginalMultiId(idim_original) += StepSize;
+                mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;

-                mThreadSrcPartialOffsets(idim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
+                mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
             }).Else([&](auto fwd) {
                 mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);

-                mThreadSrcOriginalMultiId(idim_original) -= StepSize;
+                mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;

-                mThreadSrcPartialOffsets(idim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
+                mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
             });
         });
     }
......
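The recurring cleanup in this file, collapsing `[&](auto IDim_) { constexpr auto IDim = decltype(IDim_){}; ... }` into `[&](auto IDim) { ... }`, falls out of the Number<> changes: since the parameter's type now carries the index value, the lambda can use the parameter directly instead of re-materializing a constexpr copy. A compact sketch of a static_for driving such a lambda (assumed simplified implementation, not the real ck one):

    using index_t = int;

    template <index_t N>
    struct Number
    {
        static constexpr index_t value = N;
    };

    // assumed, simplified static_for: calls f(Number<First>{}), f(Number<First + Inc>{}), ...
    template <index_t First, index_t Last, index_t Inc>
    struct static_for
    {
        template <class F>
        void operator()(F f) const
        {
            f(Number<First>{});
            static_for<First + Inc, Last, Inc>{}(f);
        }
    };

    template <index_t Last, index_t Inc>
    struct static_for<Last, Last, Inc> // recursion terminator
    {
        template <class F>
        void operator()(F) const {}
    };

    int main()
    {
        static_for<0, 4, 1>{}([&](auto IDim) {
            // IDim's value lives in its type, so no re-materialization is needed:
            constexpr index_t idim = decltype(IDim)::value;
            static_assert(decltype(IDim)::value < 4, "in range");
            (void)idim;
        });
        return 0;
    }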
@@ -16,31 +16,32 @@ struct Sequence
     static constexpr index_t mSize = sizeof...(Is);

-    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
+    __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }

-    template <index_t I>
-    __host__ __device__ static constexpr index_t Get(Number<I>)
+    __host__ __device__ static constexpr index_t GetImpl(index_t I)
     {
-        static_assert(I < mSize, "wrong! I too large");
-
         // the last dummy element prevents the compiler complaining about an empty array when mSize = 0
         const index_t mData[mSize + 1] = {Is..., 0};
         return mData[I];
     }

     template <index_t I>
-    __host__ __device__ constexpr auto operator[](Number<I>) const
+    __host__ __device__ static constexpr auto Get(Number<I>)
     {
-        return Number<Get(Number<I>{})>{};
+        static_assert(I < mSize, "wrong! I too large");
+
+        return Number<GetImpl(Number<I>{})>{};
     }

-    // make sure I is constexpr
-    __host__ __device__ constexpr index_t operator[](index_t I) const
+    template <index_t I>
+    __host__ __device__ constexpr auto operator[](Number<I>) const
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[I];
+        return Get(Number<I>{});
     }

+    // make sure I is constexpr if you want a constexpr return type
+    __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }
+
     template <index_t... IRs>
     __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
     {
@@ -54,16 +55,16 @@ struct Sequence

     __host__ __device__ static constexpr auto Reverse();

-    __host__ __device__ static constexpr index_t Front()
+    __host__ __device__ static constexpr auto Front()
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[0];
+        static_assert(mSize > 0, "wrong!");
+        return Get(Number<0>{});
     }

-    __host__ __device__ static constexpr index_t Back()
+    __host__ __device__ static constexpr auto Back()
     {
-        const index_t mData[mSize + 1] = {Is..., 0};
-        return mData[mSize - 1];
+        static_assert(mSize > 0, "wrong!");
+        return Get(Number<mSize - 1>{});
     }

     __host__ __device__ static constexpr auto PopFront();
......
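After this diff, Get, Front, Back, and the Number<> overload of operator[] all return a Number<> (value carried in the type), while GetImpl keeps the plain index_t path for runtime indices. A condensed, host-only mock of the resulting interface, mirroring just the members shown above:

    #include <cstdio>

    using index_t = int;

    template <index_t N>
    struct Number
    {
        static constexpr index_t value = N;
    };

    template <index_t... Is>
    struct Sequence
    {
        static constexpr index_t mSize = sizeof...(Is);

        static constexpr auto GetSize() { return Number<mSize>{}; }

        static constexpr index_t GetImpl(index_t i)
        {
            // dummy tail element prevents an empty array when mSize == 0
            const index_t data[mSize + 1] = {Is..., 0};
            return data[i];
        }

        template <index_t I>
        static constexpr auto Get(Number<I>)
        {
            static_assert(I < mSize, "wrong! I too large");
            return Number<GetImpl(I)>{};
        }

        static constexpr auto Front() { return Get(Number<0>{}); }
        static constexpr auto Back() { return Get(Number<mSize - 1>{}); }
    };

    int main()
    {
        using S = Sequence<8, 16, 32>;

        static_assert(S::GetSize().value == 3, "");
        static_assert(decltype(S::Get(Number<1>{}))::value == 16, "");
        static_assert(S::Front().value == 8 && S::Back().value == 32, "");

        printf("%d\n", S::GetImpl(2)); // runtime-index path, prints: 32
        return 0;
    }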
@@ -13,30 +13,64 @@ struct integral_constant
     __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
 };

-template <class T, T X, T Y>
-__host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
-{
-    return integral_constant<T, X + Y>{};
-}
-
-template <class T, T X, T Y>
-__host__ __device__ constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
-{
-    return integral_constant<T, X * Y>{};
-}
-
-template <index_t N>
-using Number = integral_constant<index_t, N>;
-
 template <class X, class Y>
 struct is_same : public integral_constant<bool, false>
 {
 };

 template <class X>
 struct is_same<X, X> : public integral_constant<bool, true>
 {
 };

+template <index_t N>
+using Number = integral_constant<index_t, N>;
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
+{
+    return Number<X + Y>{};
+}
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
+{
+    static_assert(Y <= X, "wrong!");
+    return Number<X - Y>{};
+}
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
+{
+    return Number<X * Y>{};
+}
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
+{
+    static_assert(Y > 0, "wrong!");
+    return Number<X / Y>{};
+}
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
+{
+    static_assert(Y > 0, "wrong!");
+    return Number<X % Y>{};
+}
+
+#if 0
+static constexpr Number<0> 0_c;
+static constexpr Number<1> 1_c;
+static constexpr Number<2> 2_c;
+static constexpr Number<3> 3_c;
+static constexpr Number<4> 4_c;
+static constexpr Number<5> 5_c;
+static constexpr Number<6> 6_c;
+static constexpr Number<7> 7_c;
+static constexpr Number<8> 8_c;
+static constexpr Number<9> 9_c;
+#endif
+
 } // namespace ck
 #endif
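With the operators defined on Number<> directly, compile-time index arithmetic composes like ordinary arithmetic, and the static_asserts in operator- and operator/ turn underflow and division by zero into compile errors. A self-contained check of the semantics (only + and / reproduced here):

    using index_t = int;

    template <index_t N>
    struct Number
    {
        static constexpr index_t value = N;
    };

    template <index_t X, index_t Y>
    constexpr auto operator+(Number<X>, Number<Y>)
    {
        return Number<X + Y>{};
    }

    template <index_t X, index_t Y>
    constexpr auto operator/(Number<X>, Number<Y>)
    {
        static_assert(Y > 0, "wrong!"); // divide-by-zero caught at compile time
        return Number<X / Y>{};
    }

    int main()
    {
        constexpr auto n = Number<7>{} + Number<5>{}; // Number<12>
        constexpr auto q = n / Number<3>{};           // Number<4>
        static_assert(decltype(q)::value == 4, "");
        return 0;
    }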
@@ -42,20 +42,16 @@ struct integer_divide_ceiler
     }
 };

-template <class T>
-__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
+template <class X, class Y>
+__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
 {
-    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
-
-    return (a + b - 1) / b;
+    return (x + y - 1) / y;
 }

-template <class T>
-__host__ __device__ constexpr T integer_least_multiple(T a, T b)
+template <class X, class Y>
+__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
 {
-    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
-
-    return b * integer_divide_ceil(a, b);
+    return y * integer_divide_ceil(x, y);
 }

 template <class T>
......
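Templating integer_divide_ceil on X and Y separately lets the operand types differ, so plain integers and Number<> values can mix in one call; assuming ck's integral_constant converts implicitly to its underlying value (as std::integral_constant does), a Number<> argument simply decays to index_t inside the expression. A sketch under that assumption:

    using index_t = int;

    // assumption: like std::integral_constant, this converts to its value;
    // only the members needed for the example are reproduced
    template <class T, T v>
    struct integral_constant
    {
        static constexpr T value = v;
        constexpr operator T() const noexcept { return value; }
    };

    template <index_t N>
    using Number = integral_constant<index_t, N>;

    template <class X, class Y>
    constexpr auto integer_divide_ceil(X x, Y y)
    {
        return (x + y - 1) / y;
    }

    // plain ints, a Number<> mixed with an int, and two Number<>s all work now;
    // the old single-T signature rejected everything but matching index_t/int pairs
    static_assert(integer_divide_ceil(10, 4) == 3, "");
    static_assert(integer_divide_ceil(Number<10>{}, 4) == 3, "");
    static_assert(integer_divide_ceil(Number<10>{}, Number<4>{}) == 3, "");

    int main() { return 0; }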