"tests/nn/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "1ae7778447e0489cbe086fe24ea764105ffa9eb8"
Commit a6b95c39 authored by Chao Liu

rework sequence

parent df73287b
 #pragma once

 #include "common.hip.hpp"

-template <class PreviousStrides, class RemainLengths>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_length  = RemainLengths{}.Back();
-    constexpr index_t current_stride  = current_length * previous_stride;
-
-    return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
-                                          RemainLengths{}.PopBack());
-}
-
-template <class PreviousStrides, index_t L0, index_t L1>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_stride  = L1 * previous_stride;
-
-    return PreviousStrides{}.PushFront(Number<current_stride>{});
-}
-
 template <class Lengths>
 __host__ __device__ constexpr auto calculate_default_strides(Lengths)
 {
-    return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
+    return reverse_inclusive_scan_sequence(Lengths{}.PopFront().PushBack(Number<1>{}),
+                                           std::multiplies<index_t>{});
 }
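The new implementation computes packed strides in one shot: drop the outermost length, append 1, and take a reverse inclusive scan under multiplication. A minimal compile-time check (a sketch; it assumes only the `Sequence`, `Number`, and `calculate_default_strides` utilities shown in this diff):

    // lengths {2, 3, 4, 5}: the scan runs over {3, 4, 5, 1} and yields
    // {3*4*5*1, 4*5*1, 5*1, 1} = {60, 20, 5, 1}, the packed row-major strides
    constexpr auto strides = calculate_default_strides(Sequence<2, 3, 4, 5>{});
    static_assert(strides.Get(Number<0>{}) == 60 && strides.Get(Number<1>{}) == 20 &&
                      strides.Get(Number<2>{}) == 5 && strides.Get(Number<3>{}) == 1,
                  "unexpected default strides");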
 // this is ugly, only for 2d
@@ -58,6 +39,7 @@ template <class Lengths, class Strides>
 struct ConstantTensorDescriptor
 {
     using Type = ConstantTensorDescriptor;

     static constexpr index_t nDim = Lengths::GetSize();

     __host__ __device__ constexpr ConstantTensorDescriptor()
@@ -193,7 +175,8 @@ struct ConstantTensorDescriptor
         // folded strides
         constexpr auto fold_strides =
             Number<unfold_stride>{} *
-            reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
+            reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
+                                            std::multiplies<index_t>{});

         // left and right
         constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
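The folded strides follow the same pattern: scan the fold intervals (with a trailing 1) and scale by the stride of the unfolded dimension. A worked sketch with hypothetical numbers, not code from this file:

    // unfold_stride = 4, fold_intervals = {2, 3}:
    // scan of {2, 3, 1} is {6, 3, 1}; scaled by 4 -> folded strides {24, 12, 4}
    constexpr auto fold_strides_example =
        Number<4>{} * reverse_inclusive_scan_sequence(Sequence<2, 3>{}.PushBack(Number<1>{}),
                                                      std::multiplies<index_t>{});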
...
@@ -9,7 +9,8 @@ struct Sequence
     static constexpr index_t mSize = sizeof...(Is);

-    const index_t mData[mSize] = {Is...};
+    const index_t mData[mSize + 1] = {
+        Is..., 0}; // the last element is a dummy, to keep the compiler from complaining
+                   // about a zero-length array when the Sequence is empty

     __host__ __device__ static constexpr index_t GetSize() { return mSize; }
@@ -39,10 +40,7 @@ struct Sequence
         assert(false);
     }

-    __host__ __device__ constexpr auto Reverse() const
-    {
-        // not implemented
-    }
+    __host__ __device__ constexpr auto Reverse() const;

     __host__ __device__ constexpr index_t Front() const { return mData[0]; }
@@ -73,13 +71,13 @@ struct Sequence
     template <index_t... Ns>
     __host__ __device__ constexpr auto Extract(Number<Ns>...) const
     {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
     }

     template <index_t... Ns>
     __host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
     {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
     }
 };
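Routing the call through a freshly constructed `Type{}` instead of the implicit `this` keeps `Get` usable inside the template-argument list as a constant expression; that reading is inferred from the change, not stated in the commit. Usage sketch:

    // pick elements 0 and 2 of {10, 20, 30}
    constexpr auto picked = Sequence<10, 20, 30>{}.Extract(Number<0>{}, Number<2>{});
    static_assert(picked.Get(Number<0>{}) == 10 && picked.Get(Number<1>{}) == 30, "");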
@@ -89,43 +87,109 @@ struct sequence_merge;
 template <index_t... Xs, index_t... Ys>
 struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
 {
-    using Type = Sequence<Xs..., Ys...>;
+    using SeqType = Sequence<Xs..., Ys...>;
 };

 template <index_t IBegin, index_t NSize, index_t Increment>
-struct increasing_sequence_gen
+struct increasing_sequence_gen_impl
 {
     static constexpr index_t NSizeLeft = NSize / 2;

-    using Type =
-        sequence_merge<typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
-                       typename increasing_sequence_gen<IBegin + NSizeLeft * Increment,
-                                                        NSize - NSizeLeft,
-                                                        Increment>::Type>;
+    using SeqType = typename sequence_merge<
+        typename increasing_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
+        typename increasing_sequence_gen_impl<IBegin + NSizeLeft * Increment,
+                                              NSize - NSizeLeft,
+                                              Increment>::SeqType>::SeqType;
 };

 template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 1, Increment>
+struct increasing_sequence_gen_impl<IBegin, 1, Increment>
 {
-    using Type = Sequence<IBegin>;
+    using SeqType = Sequence<IBegin>;
 };

 template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 0, Increment>
+struct increasing_sequence_gen_impl<IBegin, 0, Increment>
 {
-    using Type = Sequence<>;
+    using SeqType = Sequence<>;
 };

+template <index_t IBegin, index_t IEnd, index_t Increment>
+struct increasing_sequence_gen
+{
+    using SeqType =
+        typename increasing_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
+};
+
 template <index_t IBegin, index_t IEnd, index_t Increment>
 __host__ __device__ constexpr auto
 make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
 {
-    static_assert(IBegin <= IEnd && Increment > 0, "wrong!");
-
-    constexpr index_t NSize = (IEnd - IBegin) / Increment;
-
-    return increasing_sequence_gen<IBegin, NSize, Increment>{};
+    return typename increasing_sequence_gen<IBegin, IEnd, Increment>::SeqType{};
 }
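`make_increasing_sequence` now takes the half-open range [IBegin, IEnd) directly. Note that `increasing_sequence_gen` uses `IEnd - IBegin` as the element count, so the step only affects the generated values; every caller in this commit passes `Increment = 1`. Sketch:

    // [2, 5) with step 1 -> Sequence<2, 3, 4>
    constexpr auto r = make_increasing_sequence(Number<2>{}, Number<5>{}, Number<1>{});
    static_assert(r.GetSize() == 3 && r.Front() == 2 && r.Back() == 4, "");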
+template <class, class>
+struct sequence_reverse_inclusive_scan;
+
+template <index_t I, index_t... Is, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
+{
+    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;
+
+    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
+
+    using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
+};
+
+template <index_t I, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
+{
+    using SeqType = Sequence<I>;
+};
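Each element is combined with the scan of everything to its right, so with multiplication the result is a suffix product, which is exactly the length-to-stride conversion used above. Sketch:

    // reverse inclusive scan of {2, 3, 4} under multiplication:
    // {2*3*4, 3*4, 4} = {24, 12, 4}
    using scanned = sequence_reverse_inclusive_scan<Sequence<2, 3, 4>,
                                                    std::multiplies<index_t>>::SeqType;
    static_assert(scanned{}.Get(Number<0>{}) == 24 && scanned{}.Get(Number<1>{}) == 12 &&
                      scanned{}.Get(Number<2>{}) == 4, "");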
+template <class, class>
+struct sequence_extract;
+
+template <class Seq, index_t... Is>
+struct sequence_extract<Seq, Sequence<Is...>>
+{
+    using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
+};
+
+template <class Seq, index_t I>
+struct sequence_split
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using range0 = typename increasing_sequence_gen<0, I, 1>::SeqType;
+    using range1 = typename increasing_sequence_gen<I, NSize, 1>::SeqType;
+
+    using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
+    using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
+};
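`sequence_split` cuts a Sequence at position I by extracting the index ranges [0, I) and [I, NSize). Sketch:

    // split {10, 20, 30, 40, 50} at position 2 -> {10, 20} and {30, 40, 50}
    using split = sequence_split<Sequence<10, 20, 30, 40, 50>, 2>;
    static_assert(split::SeqType0{}.Back() == 20 && split::SeqType1{}.Front() == 30, "");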
+template <class Seq>
+struct sequence_reverse
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using seq_split = sequence_split<Seq, NSize / 2>;
+
+    using SeqType = typename sequence_merge<
+        typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
+        typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
+};
+
+template <index_t I>
+struct sequence_reverse<Sequence<I>>
+{
+    using SeqType = Sequence<I>;
+};
+
+template <index_t I0, index_t I1>
+struct sequence_reverse<Sequence<I0, I1>>
+{
+    using SeqType = Sequence<I1, I0>;
+};
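Reversal recurses on the two halves and merges them in swapped order, so template instantiation depth grows as O(log N) rather than O(N). Sketch:

    // reverse({1, 2, 3, 4, 5}) = merge(reverse({3, 4, 5}), reverse({1, 2})) = {5, 4, 3, 2, 1}
    using rev = sequence_reverse<Sequence<1, 2, 3, 4, 5>>::SeqType;
    static_assert(rev{}.Front() == 5 && rev{}.Back() == 1, "");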
 template <index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
@@ -179,9 +243,9 @@ __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
 template <index_t... Xs, index_t Y>
 __host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
 {
-#if 0 // doesn't compile
     constexpr auto seq_x = Sequence<Xs...>{};

+#if 0 // doesn't compile
     static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
         constexpr auto I = decltype(Iter){};
         static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
@@ -253,95 +317,12 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
     return Sequence<Is...>{};
 }
-#if 0
-// TODO: for some reason, compiler cannot instantiate this template
-template <index_t... Is, index_t I>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
-{
-    static_assert(sizeof...(Is) > 0, "empty Sequence!");
-    return Sequence<Is...>{};
-}
-#else
-// TODO: delete these very ugly mess
-template <index_t I0, index_t I1>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
-{
-    return Sequence<I0>{};
-}
-
-template <index_t I0, index_t I1, index_t I2>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2>)
-{
-    return Sequence<I0, I1>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3>)
-{
-    return Sequence<I0, I1, I2>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4>)
-{
-    return Sequence<I0, I1, I2, I3>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5>)
-{
-    return Sequence<I0, I1, I2, I3, I4>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6,
-          index_t I7>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6,
-          index_t I7, index_t I8>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6,
-          index_t I7, index_t I8, index_t I9>
-__host__ __device__ constexpr auto
-sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8, I9>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>{};
-}
-#endif
+template <class Seq>
+__host__ __device__ constexpr auto sequence_pop_back(Seq)
+{
+    static_assert(Seq{}.GetSize() > 0, "empty Sequence!");
+    return sequence_pop_front(Seq{}.Reverse()).Reverse();
+}
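The single generic overload replaces the ten hand-written ones by composing two primitives that are easy to express: reverse, drop the front, reverse back. Sketch:

    // pop_back({1, 2, 3}): reverse -> {3, 2, 1}, pop_front -> {2, 1}, reverse -> {1, 2}
    constexpr auto popped = sequence_pop_back(Sequence<1, 2, 3>{});
    static_assert(popped.GetSize() == 2 && popped.Back() == 2, "");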
 template <class F, index_t... Xs>
 __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
@@ -399,38 +380,20 @@ __host__ __device__ constexpr index_t
     return Reduce{}(a, I);
 }
-template <index_t NRemain>
-struct scan_sequence_impl
-{
-    template <class ScanedSeq, class RemainSeq, class Reduce>
-    __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const
-    {
-        static_assert(RemainSeq{}.GetSize() == NRemain,
-                      "wrong! RemainSeq and NRemain not consistent!");
-
-        constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front());
-        constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number<a>{});
-
-        static_if<(NRemain > 1)>{}([&](auto fwd) {
-            return scan_sequence_impl<NRemain - 1>{}(
-                scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{}));
-        }).else_([&](auto fwd) { return fwd(scaned_seq); });
-    }
-};
-
-template <class Seq, class Reduce>
-__host__ __device__ constexpr auto scan_sequence(Seq, Reduce)
-{
-    constexpr auto scaned_seq = Sequence<Seq{}.front()>{};
-    constexpr auto remain_seq = Seq{}.PopFront();
-    constexpr index_t remain_size = Seq::GetSize() - 1;
-
-    return scan_sequence_impl<remain_size>{}(scaned_seq, remain_seq, Reduce{});
-}
-
-template <class Seq, class Reduce>
-__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce)
-{
-    return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse();
-}
+template <index_t... Is>
+__host__ __device__ constexpr auto Sequence<Is...>::Reverse() const
+{
+    return typename sequence_reverse<Sequence<Is...>>::SeqType{};
+}
+
+template <class Seq, class Reduce>
+__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
+{
+    return typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType{};
+}
+
+template <class Seq, class Reduce>
+__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
+{
+    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
+}
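The forward scan is derived from the reverse scan by reversing on the way in and on the way out; with multiplication it yields prefix products. Sketch:

    // inclusive scan of {2, 3, 4} under multiplication -> {2, 6, 24}
    constexpr auto prefix =
        inclusive_scan_sequence(Sequence<2, 3, 4>{}, std::multiplies<index_t>{});
    static_assert(prefix.Get(Number<0>{}) == 2 && prefix.Get(Number<1>{}) == 6 &&
                      prefix.Get(Number<2>{}) == 24, "");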
@@ -80,29 +80,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");
-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);

-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t k_block_data_begin  = block_work_multi_id[0] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
+        const index_t n_block_data_begin  = block_work_multi_id[3] * NPerBlock;

         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
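The hand-rolled div/mod chain is replaced by treating the block id as a linear offset into a 4-d block-work descriptor; `GetMultiIndex` presumably peels off one coordinate per dimension using the packed strides. A worked example with hypothetical sizes:

    // hypothetical block-work lengths {2, 3, 4, 5} have packed strides {60, 20, 5, 1},
    // so block id 53 decomposes as 53 = 0*60 + 2*20 + 2*5 + 3*1 -> {0, 2, 2, 3}
    constexpr auto desc = make_ConstantTensorDescriptor(Sequence<2, 3, 4, 5>{});
    const auto multi_id = desc.GetMultiIndex(53); // {0, 2, 2, 3}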
         // global tensor view
-        constexpr auto wei_c_k_global_desc =
-            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
+        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);

         // LDS tensor view
         // be careful of alignment
@@ -360,13 +357,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
-        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
-                                                                         // perfect forwarding.
-                                                                         // Using this trick to
-            // make this lambda a generic lambda, so it won't be compiled until
-            // instantiated
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
+            // fwd does nothing but perfect forwarding. The trick makes this a generic lambda,
+            // so its body is not compiled until it is instantiated here.
             static_assert(
-                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                 "wrong!");

             // output is a 10d tensor
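The idiom: `static_if<Cond>{}` invokes the lambda only for the branch that is taken, and because the lambda is generic its body is only type-checked on instantiation; wrapping values in `fwd(...)` makes the enclosed expressions dependent. A standalone sketch of the pattern (the branch functions are hypothetical):

    static_if<Cond>{}([&](auto fwd) {
        // instantiated only when Cond is true; fwd(x) is a dependent expression
        do_true_branch(fwd(x));
    }).else_([&](auto fwd) {
        do_false_branch(fwd(x));
    });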
@@ -374,37 +370,32 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
             constexpr index_t N1 = NPerBlock / N2;

             constexpr index_t W2 =
-                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
             constexpr index_t W1 = WoPerBlock / W2;

             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;

-            constexpr auto out_10d_global_desc =
-                make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-                                                       K1,
-                                                       K2,
-                                                       Ho,
-                                                       Wo / (W1 * W2),
-                                                       W1,
-                                                       W2,
-                                                       N / f_dummy(N1 * N2),
-                                                       N1,
-                                                       N2>{});
-
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+            constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
+                                                     .Fold(I3, Number<N1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<W2>{})
+                                                     .Fold(I0, Number<K1>{}, Number<K2>{});
+
+            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
+                                                     .Fold(I3, Number<1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<1>{})
+                                                     .Fold(I0, Number<1>{}, Number<K2>{});
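Comparing the deleted explicit Sequence with the new calls suggests `Fold(IDim, i1, ..., in)` splits dimension IDim of length L into sub-lengths (L / (i1 * ... * in), i1, ..., in); e.g. `.Fold(I2, Number<W1>{}, Number<W2>{})` turns Wo into (Wo / (W1 * W2), W1, W2), reproducing the old 10-d layout. A hypothetical one-dimension illustration:

    // a length-24 dimension folded with intervals {3, 2} becomes {24/(3*2), 3, 2} = {4, 3, 2}
    constexpr auto folded =
        make_ConstantTensorDescriptor(Sequence<24>{}).Fold(I0, Number<3>{}, Number<2>{});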
 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
             {
-                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "a: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
+                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
+                                               "a: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
             }
 #endif
@@ -419,8 +410,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                            n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
-            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+        }).else_([&](auto fwd) {
+            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
@@ -429,32 +420,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
             constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
             constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
-            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);

             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;

-            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
-
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+            constexpr auto out_10d_global_desc =
+                fwd(out_k_h_w_n_global_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
+                    .Fold(I0, Number<K1>{}, Number<K2>{});
+
+            constexpr auto out_10d_thread_desc =
+                fwd(out_k_h_w_n_thread_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
+                    .Fold(I0, Number<1>{}, Number<K2>{});
 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
             {
-                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
-
-                for(index_t i = 0; i < 64; ++i)
-                {
-                    printf("out %f, ", p_out_thread[i]);
-                }
+                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
+                                               "b: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
+                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
+                                               "b: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
             }
 #endif
...