Commit 498e71b0 authored by Chao Liu's avatar Chao Liu
Browse files

try using more constexpr

parent 917d7a2b
...@@ -140,7 +140,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, ...@@ -140,7 +140,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead_K = 4; constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2; constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1 #elif 0
// for 3x3, 34x34, v1r3, Pascal // for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal // for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal // for 3x3, 14x14, v1r3, Pascal
......
...@@ -443,7 +443,7 @@ int main(int argc, char* argv[]) ...@@ -443,7 +443,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 3x3 filter, 28x28 image // 3x3 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -455,7 +455,7 @@ int main(int argc, char* argv[]) ...@@ -455,7 +455,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 512;
......
...@@ -18,11 +18,24 @@ struct Array ...@@ -18,11 +18,24 @@ struct Array
__host__ __device__ constexpr index_t GetSize() const { return NSize; } __host__ __device__ constexpr index_t GetSize() const { return NSize; }
__host__ __device__ const TData& operator[](index_t i) const { return mData[i]; } __host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
__host__ __device__ TData& operator[](index_t i) { return mData[i]; } __host__ __device__ TData& operator[](index_t i) { return mData[i]; }
__host__ __device__ auto PushBack(TData x) const template <index_t I>
__host__ __device__ constexpr TData Get(Number<I>) const
{
return mData[I];
}
// Compile-time mutator: write x into element I of mData.
// Returns bool (rather than void) so the call can appear inside a constexpr
// expression context (see the "for constexpr" note below); callers ignore the
// returned value.
template <index_t I>
__host__ __device__ constexpr bool Set(Number<I>, TData x)
{
mData[I] = x;
return true; // for constexpr
}
__host__ __device__ constexpr auto PushBack(TData x) const
{ {
Array<TData, NSize + 1> new_array; Array<TData, NSize + 1> new_array;
......
...@@ -74,7 +74,8 @@ struct ConstantMergedTensorDescriptor ...@@ -74,7 +74,8 @@ struct ConstantMergedTensorDescriptor
return OriginalTensorDesc::GetElementSize(); return OriginalTensorDesc::GetElementSize();
} }
__host__ __device__ static auto #if 0
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id) GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{ {
Array<index_t, nOriginalDim> original_multi_id; Array<index_t, nOriginalDim> original_multi_id;
...@@ -98,21 +99,111 @@ struct ConstantMergedTensorDescriptor ...@@ -98,21 +99,111 @@ struct ConstantMergedTensorDescriptor
return original_multi_id; return original_multi_id;
} }
#else
// Inner-loop functor for GetOriginalMultiIndexFromMultiIndex: scatters one
// merged dimension's partial original multi-index into the full original
// multi-index array. A named functor with a constexpr operator() is used
// (instead of a lambda) so the static_for loop can run in a constexpr context.
template <class OriginalDimsPartial>
struct GetOriginalMultiIndexFromMultiIndex_impl1
{
// source: partial original indices (one per original dimension participating
// in this merged dimension)
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_ref;
// destination: full original multi-index being assembled in-place
Array<index_t, nOriginalDim>& original_multi_id_ref;
__host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl1(
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial,
Array<index_t, nOriginalDim>& original_multi_id)
: original_multi_id_partial_ref(original_multi_id_partial),
original_multi_id_ref(original_multi_id)
{
}
// Copy the I-th partial index into the slot of the original dimension it
// belongs to. Returns bool so the call is usable in a constexpr expression.
template <index_t I>
constexpr __host__ __device__ bool operator()(Number<I>) const
{
// which original dimension the I-th partial entry maps to
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
index_t itmp = original_multi_id_partial_ref.Get(Number<I>{});
original_multi_id_ref.Set(Number<idim_original>{}, itmp);
return true;
}
};
// Outer-loop functor for GetOriginalMultiIndexFromMultiIndex: for each merged
// dimension IDim, converts that dimension's index back into the partial
// multi-index of the original dimensions it merges, then scatters those values
// into the full original multi-index via GetOriginalMultiIndexFromMultiIndex_impl1.
struct GetOriginalMultiIndexFromMultiIndex_impl0
{
// input: multi-index in the merged tensor
const Array<index_t, nDim>& multi_id_ref;
// output: multi-index in the original (unmerged) tensor, filled in-place
Array<index_t, nOriginalDim>& original_multi_id_ref;
__host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl0(
const Array<index_t, nDim>& multi_id, Array<index_t, nOriginalDim>& original_multi_id)
: multi_id_ref(multi_id), original_multi_id_ref(original_multi_id)
{
}
// Process merged dimension IDim. Returns bool so the call is usable in a
// constexpr expression.
template <index_t IDim>
constexpr __host__ __device__ bool operator()(Number<IDim>) const
{
// the Sequence of original dimensions covered by merged dimension IDim
constexpr auto original_dims_partial =
std::get<IDim>(std::tuple<OriginalDimMergeSeqs...>{});
// get partial original-multi-id corresponding to this merged dimension
const auto original_multi_id_partial =
OriginalTensorDesc::Extract(original_dims_partial)
.GetMultiIndexFrom1dIndex(multi_id_ref[IDim]);
// scatter each partial entry into its slot of the full original multi-index
static_for<0, original_dims_partial.GetSize(), 1>{}(
GetOriginalMultiIndexFromMultiIndex_impl1<decltype(original_dims_partial)>(
original_multi_id_partial, original_multi_id_ref));
return true;
}
};
// Map a multi-index of this merged tensor to the corresponding multi-index of
// the underlying original tensor. constexpr-capable: the loop body is the named
// functor GetOriginalMultiIndexFromMultiIndex_impl0 rather than a lambda.
__host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
Array<index_t, nOriginalDim> original_multi_id;
// every original dimension is covered by exactly one merged dimension, so the
// loop below fills every entry of original_multi_id
static_for<0, nDim, 1>{}(
GetOriginalMultiIndexFromMultiIndex_impl0(multi_id, original_multi_id));
return original_multi_id;
}
// Fully compile-time offset computation: the multi-index arrives as a
// Sequence<Is...>, is converted to an Array, mapped to the original tensor's
// multi-index, and linearized with the original tensor's strides.
template <index_t... Is>
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
{
constexpr auto multi_id = sequence2array(Sequence<Is...>{});
constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
#endif
#if 0
// return type is Sequence<...>
template <index_t... Is>
__host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Sequence<Is...>)
{
// not implemented
return Sequence<>{};
}
#endif
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id) __host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
{ {
const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id); auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id); return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
} }
template <class... Is> template <class... Is>
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is) __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{ {
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...}); return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
} }
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id) __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{ {
constexpr auto dummy_desc = make_ConstantTensorDescriptor_packed(GetLengths()); constexpr auto dummy_desc = make_ConstantTensorDescriptor_packed(GetLengths());
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
template <class Lengths> template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths) __host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{ {
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), mod_conv::multiplies<index_t>{}) return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), mod_conv::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{}); .PushBack(Number<1>{});
} }
...@@ -91,8 +92,10 @@ struct ConstantTensorDescriptor ...@@ -91,8 +92,10 @@ struct ConstantTensorDescriptor
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get()); return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
} }
#if 0
template <index_t NSize> template <index_t NSize>
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id) __host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
{ {
static_assert(NSize == nDim, "wrong! Dimension not consistent"); static_assert(NSize == nDim, "wrong! Dimension not consistent");
...@@ -105,9 +108,43 @@ struct ConstantTensorDescriptor ...@@ -105,9 +108,43 @@ struct ConstantTensorDescriptor
return offset; return offset;
} }
#else
// Loop-body functor for GetOffsetFromMultiIndex: accumulates
// multi_id[IDim] * stride[IDim] into the running offset. A named functor with a
// constexpr operator() is used (instead of a lambda) so the static_for loop can
// run in a constexpr context.
template <index_t NSize>
struct GetOffsetFromMultiIndex_impl
{
// multi-index being linearized (read-only here, held by reference)
Array<index_t, NSize>& multi_id_ref;
// running offset, accumulated across loop iterations
index_t& offset_ref;
__host__ __device__ constexpr GetOffsetFromMultiIndex_impl(Array<index_t, NSize>& multi_id,
index_t& offset)
: multi_id_ref(multi_id), offset_ref(offset)
{
}
// Add dimension IDim's contribution (index * stride). Returns bool so the
// call is usable in a constexpr expression.
template <index_t IDim>
__host__ __device__ constexpr bool operator()(Number<IDim>) const
{
offset_ref += multi_id_ref.Get(Number<IDim>{}) * Type::GetStride(Number<IDim>{});
return true;
}
};
// Linear offset of a multi-index: sum over all dimensions of index * stride.
// constexpr-capable via the GetOffsetFromMultiIndex_impl functor; NSize must
// equal the descriptor's dimensionality.
template <index_t NSize>
__host__ __device__ static constexpr index_t
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
{
static_assert(NSize == nDim, "wrong! Dimension not consistent");
index_t offset = 0;
static_for<0, nDim, 1>{}(GetOffsetFromMultiIndex_impl<NSize>(multi_id, offset));
return offset;
}
#endif
template <class... Is> template <class... Is>
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is) __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
{ {
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...}); return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
} }
...@@ -123,7 +160,8 @@ struct ConstantTensorDescriptor ...@@ -123,7 +160,8 @@ struct ConstantTensorDescriptor
multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{}); multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
} }
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id) #if 0
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{ {
Array<index_t, nDim> multi_id; Array<index_t, nDim> multi_id;
...@@ -141,8 +179,58 @@ struct ConstantTensorDescriptor ...@@ -141,8 +179,58 @@ struct ConstantTensorDescriptor
return multi_id; return multi_id;
} }
#else
// Loop-body functor for GetMultiIndexFrom1dIndex: peels off one dimension's
// index from the running 1d id using packed strides, mutating both the id and
// the output multi-index. A named functor with a constexpr operator() is used
// (instead of a lambda) so the static_for loop can run in a constexpr context.
struct GetMultiIndexFrom1dIndex_impl
{
// strides of a packed tensor with the same lengths; used because the 1d index
// is defined in packed (row-major over GetLengths()) order
using DummyStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
// remaining 1d index; reduced each iteration
index_t& id_ref;
// output multi-index, filled one dimension per iteration
Array<index_t, nDim>& multi_id_ref;
__host__ __device__ constexpr GetMultiIndexFrom1dIndex_impl(index_t& id,
Array<index_t, nDim>& multi_id)
: id_ref(id), multi_id_ref(multi_id)
{
}
// Extract dimension IDim's index and subtract its contribution from id_ref.
// Returns bool so the call is usable in a constexpr expression.
template <index_t IDim>
__host__ __device__ constexpr bool operator()(Number<IDim>) const
{
constexpr index_t stride = DummyStrides::Get(Number<IDim>{});
multi_id_ref.Set(Number<IDim>{}, id_ref / stride);
id_ref -= multi_id_ref.Get(Number<IDim>{}) * stride;
return true;
}
};
// Convert a 1d (packed-order, linearized) index into a multi-index over this
// descriptor's lengths. The first nDim-1 dimensions are peeled off by the
// functor loop; the last dimension is handled separately below.
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
Array<index_t, nDim> multi_id;
constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());
// calculate index in each of the dimensions in the order of their dimension
static_for<0, nDim - 1, 1>{}(GetMultiIndexFrom1dIndex_impl(id, multi_id));
// last dimension: divide by its packed stride (the loop above already removed
// the contributions of all earlier dimensions from id)
index_t itmp = id / dummy_strides.Get(Number<nDim - 1>{});
multi_id.Set(Number<nDim - 1>{}, itmp);
return multi_id;
}
#endif
#if 0
// return type is Sequence<...>
// NOTE(review): disabled sketch. `f_impl` is not defined anywhere visible, and
// the argument order does not match inclusive_scan_sequence(Seq, Reduce, Number<Init>)
// as declared in this commit — confirm and fix before enabling.
template<index_t Id>
__host__ __device__ static constexpr auto GetMultiIndexFrom1dIndex(Number<Id>)
{
return inclusive_scan_sequence(f_impl, GetStrides(), Number<Id>{});
}
#endif
__host__ __device__ static auto __host__ __device__ static constexpr auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id) GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{ {
return multi_id; return multi_id;
...@@ -278,8 +366,8 @@ struct ConstantTensorDescriptor ...@@ -278,8 +366,8 @@ struct ConstantTensorDescriptor
// folded strides // folded strides
constexpr auto fold_strides = constexpr auto fold_strides =
Number<unfold_stride>{} * Number<unfold_stride>{} *
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}), reverse_inclusive_scan_sequence(
mod_conv::multiplies<index_t>{}); fold_intervals.PushBack(Number<1>{}), mod_conv::multiplies<index_t>{}, Number<1>{});
// left and right // left and right
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{}; constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
......
...@@ -139,31 +139,49 @@ struct arithmetic_sequence_gen ...@@ -139,31 +139,49 @@ struct arithmetic_sequence_gen
typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType; typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
}; };
template <class, class> // reverse scan with init
template <class, class, index_t>
struct sequence_reverse_inclusive_scan; struct sequence_reverse_inclusive_scan;
template <index_t I, index_t... Is, class Reduce> template <index_t I, index_t... Is, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce> struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
{ {
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType; using old_scan =
typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::SeqType;
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType; using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
}; };
template <index_t I, class Reduce> template <index_t I, class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce> struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{ {
using SeqType = Sequence<I>; using SeqType = Sequence<Reduce{}(I, Init)>;
}; };
template <class Reduce> template <class Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce> struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{ {
using SeqType = Sequence<>; using SeqType = Sequence<>;
}; };
#if 0
// reverse scan with token
// NOTE(review): disabled sketch of a token-based reverse inclusive scan.
// Fixed: the recursive case referred to `Reduce{}`, but this template's functor
// parameter is named `F` — that would not compile if the block were enabled.
// Still missing before this can be enabled: base-case specializations for
// Sequence<I> and Sequence<> (cf. sequence_reverse_inclusive_scan above),
// without which the recursion does not terminate.
template <class, class, index_t>
struct sequence_reverse_inclusive_token_scan;
// recursive case: reduce head element I with the front of the scan of the tail
template <index_t I, index_t... Is, class F, index_t Token>
struct sequence_reverse_inclusive_token_scan<Sequence<I, Is...>, F, Token>
{
using old_scan = typename sequence_reverse_inclusive_token_scan<Sequence<Is...>, F, Token>::SeqType;
static constexpr index_t new_reduce = F{}(I, old_scan{}.Front());
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
};
#endif
template <class, class> template <class, class>
struct sequence_extract; struct sequence_extract;
...@@ -434,16 +452,16 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>) ...@@ -434,16 +452,16 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return Sequence<f(Xs, Ys, Zs)...>{}; return Sequence<f(Xs, Ys, Zs)...>{};
} }
template <class Seq, class Reduce> template <class Seq, class Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce) __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{ {
return typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType{}; return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::SeqType{};
} }
template <class Seq, class Reduce> template <class Seq, class Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce) __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{ {
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse(); return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
} }
template <class Seq> template <class Seq>
......
...@@ -203,6 +203,7 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -203,6 +203,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) { static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if 0
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){}); constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto src_thread_data_multi_id_begin = const auto src_thread_data_multi_id_begin =
...@@ -216,6 +217,19 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -216,6 +217,19 @@ struct BlockwiseGenericTensorSliceCopy_v1
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex( const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex(
clipboard_data_multi_id_begin); // cannot not constexpr, why? clipboard_data_multi_id_begin); // cannot not constexpr, why?
#else
constexpr auto src_thread_data_multi_id_begin =
repeat_multi_id_ * data_per_cluster_per_dims;
constexpr auto clipboard_data_multi_id_begin =
repeat_multi_id_ * thread_sub_tensor_lengths;
constexpr index_t src_offset =
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#endif
threadwise_generic_tensor_slice_copy_v1(SrcDesc{}, threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
p_src + src_offset + mThreadSrcOffset, p_src + src_offset + mThreadSrcOffset,
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
struct forwarder struct forwarder
{ {
template <typename T> template <typename T>
__host__ __device__ constexpr T operator()(T&& x) const __host__ __device__ constexpr T&& operator()(T&& x) const
{ {
return std::forward<T>(x); return static_cast<T&&>(x);
} }
}; };
...@@ -76,7 +76,7 @@ template <index_t Iter, index_t Remaining, index_t Increment> ...@@ -76,7 +76,7 @@ template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl struct static_for_impl
{ {
template <class F> template <class F>
__host__ __device__ void operator()(F f) const constexpr __host__ __device__ void operator()(F f) const
{ {
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0"); static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
static_assert(Increment <= Remaining, "will go out-of-range"); static_assert(Increment <= Remaining, "will go out-of-range");
...@@ -90,7 +90,7 @@ template <index_t Iter, index_t Increment> ...@@ -90,7 +90,7 @@ template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment> struct static_for_impl<Iter, 0, Increment>
{ {
template <class F> template <class F>
__host__ __device__ void operator()(F) const constexpr __host__ __device__ void operator()(F) const
{ {
// no work left, just return // no work left, just return
return; return;
...@@ -102,13 +102,19 @@ template <index_t NBegin, index_t NEnd, index_t Increment> ...@@ -102,13 +102,19 @@ template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for struct static_for
{ {
template <class F> template <class F>
__host__ __device__ void operator()(F f) const constexpr __host__ __device__ void operator()(F f) const
{ {
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
static_assert((NEnd - NBegin) % Increment == 0, static_assert((NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
#if 0
static_if<(NBegin < NEnd)>{}( static_if<(NBegin < NEnd)>{}(
[&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); }); [&](auto fwd) { static_for_impl<NBegin, NEnd - NBegin, fwd(Increment)>{}(f); });
#else
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
#endif
} }
}; };
......
...@@ -155,7 +155,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -155,7 +155,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
decltype(wei_c_k_global_desc), decltype(wei_c_k_global_desc),
decltype(wei_c_k_block_desc), decltype(wei_c_k_block_desc),
decltype(wei_c_k_block_desc.GetLengths()), decltype(wei_c_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{}; WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
...@@ -235,8 +235,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -235,8 +235,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
} }
#endif #endif
// set threadwise output tensor to 0 // set threadwise output to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
for(index_t y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
......
...@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// choose GEMM implementation here // choose GEMM implementation here
const auto run_blockwise_gemm = [&](auto... Xs) { const auto run_blockwise_gemm = [&](auto... Xs) {
#if 1 #if 0
return blockwise_gemm.Run(Xs...); return blockwise_gemm.Run(Xs...);
#else #else
return blockwise_gemm.Run_asm(Xs...); return blockwise_gemm.Run_asm(Xs...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment