"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "d3ebdafd4b1a06dca822004407cb2e436951ce19"
Commit a6b95c39 authored by Chao Liu

rework sequence

parent df73287b
#pragma once
#include "common.hip.hpp"
-template <class PreviousStrides, class RemainLengths>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_length  = RemainLengths{}.Back();
-    constexpr index_t current_stride  = current_length * previous_stride;
-
-    return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
-                                          RemainLengths{}.PopBack());
-}
-
-template <class PreviousStrides, index_t L0, index_t L1>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_stride  = L1 * previous_stride;
-
-    return PreviousStrides{}.PushFront(Number<current_stride>{});
-}
-
template <class Lengths>
__host__ __device__ constexpr auto calculate_default_strides(Lengths)
{
-    return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
+    return reverse_inclusive_scan_sequence(Lengths{}.PopFront().PushBack(Number<1>{}),
+                                           std::multiplies<index_t>{});
}
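For orientation: the rework computes packed row-major strides as a reverse inclusive scan, under multiplication, of the lengths shifted left by one with a trailing 1. A minimal standalone sketch of the same arithmetic (our own helper, not this library's API; assumes C++17):

#include <array>
#include <cstddef>

// default_strides is a hypothetical illustration of the scan above:
// lengths [2, 3, 4] -> strides [12, 4, 1]
template <std::size_t N>
constexpr std::array<int, N> default_strides(std::array<int, N> lengths)
{
    std::array<int, N> strides{};
    strides[N - 1] = 1; // innermost dimension is contiguous
    for(std::size_t i = N - 1; i > 0; --i)
        strides[i - 1] = strides[i] * lengths[i];
    return strides;
}

static_assert(default_strides<3>({2, 3, 4})[0] == 12, "outermost stride is 3*4");
static_assert(default_strides<3>({2, 3, 4})[1] == 4, "middle stride is 4");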
// this is ugly, only for 2d
@@ -58,6 +39,7 @@ template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
+    using Type = ConstantTensorDescriptor;

    static constexpr index_t nDim = Lengths::GetSize();

    __host__ __device__ constexpr ConstantTensorDescriptor()
@@ -193,7 +175,8 @@ struct ConstantTensorDescriptor
        // folded strides
        constexpr auto fold_strides =
            Number<unfold_stride>{} *
-            reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
+            reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
+                                            std::multiplies<index_t>{});

        // left and right
        constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
@@ -9,7 +9,8 @@ struct Sequence
    static constexpr index_t mSize = sizeof...(Is);

-    const index_t mData[mSize] = {Is...};
+    const index_t mData[mSize + 1] = {
+        Is..., 0}; // the last element is a dummy, to keep the compiler from complaining about an empty Sequence

    __host__ __device__ static constexpr index_t GetSize() { return mSize; }
@@ -39,10 +40,7 @@ struct Sequence
        assert(false);
    }

-    __host__ __device__ constexpr auto Reverse() const
-    {
-        // not implemented
-    }
+    __host__ __device__ constexpr auto Reverse() const;

    __host__ __device__ constexpr index_t Front() const { return mData[0]; }
@@ -73,13 +71,13 @@ struct Sequence
    template <index_t... Ns>
    __host__ __device__ constexpr auto Extract(Number<Ns>...) const
    {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
    }

    template <index_t... Ns>
    __host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
    {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
    }
};
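A quick usage sketch of the reworked Extract (our own checks, assuming the full Sequence definition compiles as shown):

// Extract picks elements by position: <2, 3, 4> extracted at <0, 2> is <2, 4>
static_assert(Sequence<2, 3, 4>{}.Extract(Sequence<0, 2>{}).Front() == 2, "");
static_assert(Sequence<2, 3, 4>{}.Extract(Sequence<0, 2>{}).Back() == 4, "");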
@@ -89,43 +87,109 @@ struct sequence_merge;
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
-    using Type = Sequence<Xs..., Ys...>;
+    using SeqType = Sequence<Xs..., Ys...>;
};
template <index_t IBegin, index_t NSize, index_t Increment>
-struct increasing_sequence_gen
+struct increasing_sequence_gen_impl
{
    static constexpr index_t NSizeLeft = NSize / 2;

-    using Type =
-        sequence_merge<typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
-                       typename increasing_sequence_gen<IBegin + NSizeLeft * Increment,
-                                                        NSize - NSizeLeft,
-                                                        Increment>::Type>;
+    using SeqType = typename sequence_merge<
+        typename increasing_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
+        typename increasing_sequence_gen_impl<IBegin + NSizeLeft * Increment,
+                                              NSize - NSizeLeft,
+                                              Increment>::SeqType>::SeqType;
};

template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 1, Increment>
+struct increasing_sequence_gen_impl<IBegin, 1, Increment>
{
-    using Type = Sequence<IBegin>;
+    using SeqType = Sequence<IBegin>;
};

template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 0, Increment>
+struct increasing_sequence_gen_impl<IBegin, 0, Increment>
{
-    using Type = Sequence<>;
+    using SeqType = Sequence<>;
};
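The generator halves NSize at each level, so template recursion depth grows logarithmically rather than linearly in the sequence length. A sanity check (ours, assuming the definitions above):

// <IBegin=0, NSize=5, Increment=2> generates <0, 2, 4, 6, 8>
static_assert(increasing_sequence_gen_impl<0, 5, 2>::SeqType{}.Front() == 0, "");
static_assert(increasing_sequence_gen_impl<0, 5, 2>::SeqType{}.Back() == 8, "");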
+template <index_t IBegin, index_t IEnd, index_t Increment>
+struct increasing_sequence_gen
+{
+    using SeqType =
+        typename increasing_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
+};
+
template <index_t IBegin, index_t IEnd, index_t Increment>
__host__ __device__ constexpr auto
make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
{
    static_assert(IBegin <= IEnd && Increment > 0, "wrong!");

-    constexpr index_t NSize = (IEnd - IBegin) / Increment;
-
-    return increasing_sequence_gen<IBegin, NSize, Increment>{};
+    return typename increasing_sequence_gen<IBegin, IEnd, Increment>::SeqType{};
}

+template <class, class>
+struct sequence_reverse_inclusive_scan;
+template <index_t I, index_t... Is, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
+{
+    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;
+
+    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
+
+    using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
+};
+
+template <index_t I, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
+{
+    using SeqType = Sequence<I>;
+};
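The scan folds from the back, so each element holds the reduction of itself and everything to its right. An illustrative check (ours; std::plus needs <functional>, which std::multiplies elsewhere in this header already implies):

// running sums from the right: <1, 2, 3> scans to <6, 5, 3>
static_assert(
    sequence_reverse_inclusive_scan<Sequence<1, 2, 3>, std::plus<index_t>>::SeqType{}.Front() == 6,
    "");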
+template <class, class>
+struct sequence_extract;
+
+template <class Seq, index_t... Is>
+struct sequence_extract<Seq, Sequence<Is...>>
+{
+    using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
+};
+
+template <class Seq, index_t I>
+struct sequence_split
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using range0 = typename increasing_sequence_gen<0, I, 1>::SeqType;
+    using range1 = typename increasing_sequence_gen<I, NSize, 1>::SeqType;
+
+    using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
+    using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
+};
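sequence_split cuts a sequence at position I into its first I elements and the remainder, which sequence_reverse below relies on. A small check (ours, assuming the definitions above):

// splitting <5, 6, 7, 8> at 1 gives <5> and <6, 7, 8>
static_assert(sequence_split<Sequence<5, 6, 7, 8>, 1>::SeqType0{}.Back() == 5, "");
static_assert(sequence_split<Sequence<5, 6, 7, 8>, 1>::SeqType1{}.Front() == 6, "");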
+template <class Seq>
+struct sequence_reverse
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using seq_split = sequence_split<Seq, NSize / 2>;
+
+    using SeqType = typename sequence_merge<
+        typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
+        typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
+};
+
+template <index_t I>
+struct sequence_reverse<Sequence<I>>
+{
+    using SeqType = Sequence<I>;
+};
+
+template <index_t I0, index_t I1>
+struct sequence_reverse<Sequence<I0, I1>>
+{
+    using SeqType = Sequence<I1, I0>;
+};
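Reversal also recurses by halves (reverse each half, merge in swapped order), again keeping instantiation depth logarithmic. For illustration (ours):

// <0, 1, 2, 3, 4> reversed starts with 4
static_assert(sequence_reverse<Sequence<0, 1, 2, 3, 4>>::SeqType{}.Front() == 4, "");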
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
@@ -179,9 +243,9 @@ __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
{
-#if 0 // doesn't compile
    constexpr auto seq_x = Sequence<Xs...>{};

+#if 0 // doesn't compile
    static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
        constexpr auto I = decltype(Iter){};
        static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
@@ -253,95 +317,12 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
    return Sequence<Is...>{};
}

-#if 0
-// TODO: for some reason, compiler cannot instantiate this template
-template <index_t... Is, index_t I>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
-{
-    static_assert(sizeof...(Is) > 0, "empty Sequence!");
-    return Sequence<Is...>{};
-}
-#else
-// TODO: delete this very ugly mess
-template <index_t I0, index_t I1>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
-{
-    return Sequence<I0>{};
-}
-
-template <index_t I0, index_t I1, index_t I2>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2>)
-{
-    return Sequence<I0, I1>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3>)
-{
-    return Sequence<I0, I1, I2>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4>)
-{
-    return Sequence<I0, I1, I2, I3>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5>)
-{
-    return Sequence<I0, I1, I2, I3, I4>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5>{};
-}
-
-template <index_t I0,
-          index_t I1,
-          index_t I2,
-          index_t I3,
-          index_t I4,
-          index_t I5,
-          index_t I6,
-          index_t I7>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6>{};
-}
-
-template <index_t I0,
-          index_t I1,
-          index_t I2,
-          index_t I3,
-          index_t I4,
-          index_t I5,
-          index_t I6,
-          index_t I7,
-          index_t I8>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7>{};
-}
-
-template <index_t I0,
-          index_t I1,
-          index_t I2,
-          index_t I3,
-          index_t I4,
-          index_t I5,
-          index_t I6,
-          index_t I7,
-          index_t I8,
-          index_t I9>
-__host__ __device__ constexpr auto
-sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8, I9>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>{};
-}
-#endif
+template <class Seq>
+__host__ __device__ constexpr auto sequence_pop_back(Seq)
+{
+    static_assert(Seq{}.GetSize() > 0, "empty Sequence!");
+    return sequence_pop_front(Seq{}.Reverse()).Reverse();
+}
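With Reverse available, the hard-coded overload set collapses into one definition: popping the back is reversing, popping the front, and reversing again. A quick check (ours):

// pop_back(<1, 2, 3>) == <1, 2>
static_assert(sequence_pop_back(Sequence<1, 2, 3>{}).Back() == 2, "");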
template <class F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
@@ -399,38 +380,20 @@ __host__ __device__ constexpr index_t
    return Reduce{}(a, I);
}
-template <index_t NRemain>
-struct scan_sequence_impl
-{
-    template <class ScanedSeq, class RemainSeq, class Reduce>
-    __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const
-    {
-        static_assert(RemainSeq{}.GetSize() == NRemain,
-                      "wrong! RemainSeq and NRemain not consistent!");
-
-        constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front());
-        constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number<a>{});
-
-        static_if<(NRemain > 1)>{}([&](auto fwd) {
-            return scan_sequence_impl<NRemain - 1>{}(
-                scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{}));
-        }).else_([&](auto fwd) { return fwd(scaned_seq); });
-    }
-};
+template <index_t... Is>
+__host__ __device__ constexpr auto Sequence<Is...>::Reverse() const
+{
+    return typename sequence_reverse<Sequence<Is...>>::SeqType{};
+}

template <class Seq, class Reduce>
-__host__ __device__ constexpr auto scan_sequence(Seq, Reduce)
+__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
{
-    constexpr auto scaned_seq      = Sequence<Seq{}.front()>{};
-    constexpr auto remain_seq      = Seq{}.PopFront();
-    constexpr index_t remain_size  = Seq::GetSize() - 1;
-
-    return scan_sequence_impl<remain_size>{}(scaned_seq, remain_seq, Reduce{});
+    return typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType{};
}

template <class Seq, class Reduce>
-__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce)
+__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
{
-    return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse();
+    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
}
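The two scans are mirror images: the forward inclusive scan is implemented as reverse, reverse-scan, reverse. For illustration (ours):

// reverse_inclusive_scan_sequence(<2, 3, 4>, *) -> <24, 12, 4>
// inclusive_scan_sequence        (<2, 3, 4>, *) -> < 2,  6, 24>
static_assert(
    inclusive_scan_sequence(Sequence<2, 3, 4>{}, std::multiplies<index_t>{}).Back() == 24, "");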
@@ -80,29 +80,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t k_block_data_begin  = block_work_multi_id[0] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
+        const index_t n_block_data_begin  = block_work_multi_id[3] * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
-        constexpr auto wei_c_k_global_desc =
-            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
+        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);

        // LDS tensor view

        // be careful of alignment
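The rework replaces the hand-rolled div/mod chain with a descriptor over the block-work grid; GetMultiIndex is then assumed to perform the usual packed row-major decomposition of a linear id. A standalone sketch of that arithmetic (our helper name, not the library's; assumes C++17):

#include <array>

// get_multi_index is a hypothetical stand-in for what
// block_work_desc.GetMultiIndex(get_block_1d_id()) computes.
constexpr std::array<int, 4> get_multi_index(int id, std::array<int, 4> lengths)
{
    std::array<int, 4> multi_id{};
    for(int i = 3; i >= 0; --i)
    {
        multi_id[i] = id % lengths[i]; // index within dimension i
        id /= lengths[i];              // carry into the next-outer dimension
    }
    return multi_id;
}

// with work lengths {2, 3, 4, 5}, block id 47 maps to (0, 2, 1, 2)
static_assert(get_multi_index(47, {2, 3, 4, 5})[1] == 2, "");
static_assert(get_multi_index(47, {2, 3, 4, 5})[3] == 2, "");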
@@ -360,13 +357,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

-        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
-                                                                         // perfect forwarding.
-                                                                         // Using this trick to
-            // make this lambda a generic lambda, so it won't be compiled until
-            // instantiated
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
+            // fwd does nothing but perfect forwarding.
+            // Using this trick makes this lambda a generic lambda, so it won't be compiled
+            // until being instantiated here
            static_assert(
-                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
@@ -374,37 +370,32 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 =
-                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

-            constexpr auto out_10d_global_desc =
-                make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-                                                       K1,
-                                                       K2,
-                                                       Ho,
-                                                       Wo / (W1 * W2),
-                                                       W1,
-                                                       W2,
-                                                       N / f_dummy(N1 * N2),
-                                                       N1,
-                                                       N2>{});
-
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+            constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
+                                                     .Fold(I3, Number<N1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<W2>{})
+                                                     .Fold(I0, Number<K1>{}, Number<K2>{});
+
+            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
+                                                     .Fold(I3, Number<1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<1>{})
+                                                     .Fold(I0, Number<1>{}, Number<K2>{});
#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                                               "a: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+                                               "a: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
            }
#endif
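Fold(IDim, Number<F1>{}, ..., Number<Fn>{}) is used here to split one dimension of the 4-d KHWN descriptor into several, producing the 10-d view that was previously spelled out by hand; the data layout is unchanged, only the index bookkeeping. A sketch of the length arithmetic under that assumed semantics (our helper, not the library's API):

#include <array>

// fold_length is a hypothetical illustration: a dimension of extent `length`
// folded by factors f1, f2 becomes three dimensions [length/(f1*f2), f1, f2]
constexpr std::array<int, 3> fold_length(int length, int f1, int f2)
{
    return {length / (f1 * f2), f1, f2};
}

// e.g. Wo = 24 folded by W1 = 2, W2 = 3 -> lengths [4, 2, 3]
static_assert(fold_length(24, 2, 3)[0] == 4, "");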
@@ -419,8 +410,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                                      n_block_data_begin + n_thread_data_begin),
                out_10d_thread_desc.GetLengths(),
                Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
-            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+        }).else_([&](auto fwd) {
+            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");
@@ -429,32 +420,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
-            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

-            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
+            constexpr auto out_10d_global_desc =
+                fwd(out_k_h_w_n_global_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
+                    .Fold(I0, Number<K1>{}, Number<K2>{});

-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+            constexpr auto out_10d_thread_desc =
+                fwd(out_k_h_w_n_thread_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
+                    .Fold(I0, Number<1>{}, Number<K2>{});
#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                                               "b: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
-
-                for(index_t i = 0; i < 64; ++i)
-                {
-                    printf("out %f, ", p_out_thread[i]);
-                }
+                                               "b: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
            }
#endif
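For readers unfamiliar with the static_if pattern used in both branches above, here is a minimal standalone sketch of such a utility (an assumption on our part; the actual definition in this codebase may differ). Routing compile-time constants through the generic-lambda parameter fwd makes the lambda body dependent, so the branch that is not selected is never instantiated:

// identity functor handed to the selected branch; fwd(x) inside the lambda
// is a dependent expression, which delays instantiation of the body
struct forwarder
{
    template <class T>
    constexpr T operator()(T x) const
    {
        return x;
    }
};

template <bool Predicate>
struct static_if;

template <>
struct static_if<true>
{
    template <class F>
    auto operator()(F f) const
    {
        f(forwarder{}); // instantiate and run the true branch
        return *this;
    }

    template <class F>
    void else_(F) const
    {
        // predicate is true: the false branch is never instantiated
    }
};

template <>
struct static_if<false>
{
    template <class F>
    auto operator()(F) const
    {
        return *this; // predicate is false: the true branch is never instantiated
    }

    template <class F>
    void else_(F f) const
    {
        f(forwarder{}); // instantiate and run the false branch
    }
};

// usage mirroring the kernel code above:
//   static_if<Cond>{}([&](auto fwd) { /* taken if Cond */ })
//       .else_([&](auto fwd) { /* taken otherwise */ });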