Unverified Commit d075adf1 authored by Chao Liu's avatar Chao Liu Committed by GitHub

Use Tuple and vector_type instead of Array for holding tensor data (#30)

* replacing array with tuple and vector for tensor data
parent e4790c25
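This change moves thread-local tensor data out of raw Array storage and into tuple-like, statically indexed buffers (StaticBuffer) plus vector_type, so elements are addressed with compile-time indices and can stay in registers. A minimal self-contained sketch of the idea, with hypothetical names (static_buffer_sketch stands in for the ck types and is not part of this commit):

#include <utility>

// Hypothetical stand-in for a statically indexed buffer: elements are
// addressed with std::integral_constant (playing the role of ck::Number<>),
// so the compiler can keep each element in its own register instead of
// addressing a run-time-indexed array.
template <typename T, int N>
struct static_buffer_sketch
{
    T data_[N];

    template <int I>
    constexpr const T& operator[](std::integral_constant<int, I>) const
    {
        static_assert(I < N, "out of range");
        return data_[I];
    }

    template <int I>
    constexpr T& operator()(std::integral_constant<int, I>)
    {
        static_assert(I < N, "out of range");
        return data_[I];
    }
};

int main()
{
    float c_old[4] = {};                    // old style: plain array, run-time indexing
    static_buffer_sketch<float, 4> c_new{}; // new style: compile-time indexing
    c_old[2] = 1.0f;
    c_new(std::integral_constant<int, 2>{}) = 1.0f;
    return c_old[2] == c_new[std::integral_constant<int, 2>{}] ? 0 : 1;
}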
...@@ -6,12 +6,10 @@ ...@@ -6,12 +6,10 @@
namespace ck { namespace ck {
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, index calculation is greatly reduced:
// KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize, template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename BlockMatrixA, typename BlockMatrixA,
typename BlockMatrixB, typename BlockMatrixB,
typename ThreadMatrixC, typename ThreadMatrixC,
...@@ -30,9 +28,34 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -30,9 +28,34 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
index_t w; index_t w;
}; };
index_t mMyThreadOffsetA; // HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v3() __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
{ {
static_assert(BlockMatrixA::IsKnownAtCompileTime() && static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() && BlockMatrixB::IsKnownAtCompileTime() &&
...@@ -61,11 +84,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -61,11 +84,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n"); "wrong! wrong blocksize\n");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA =
BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread));
} }
__device__ static constexpr auto GetThreadMatrixCLengths() __device__ static constexpr auto GetThreadMatrixCLengths()
...@@ -91,37 +109,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -91,37 +109,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
return MatrixIndex{k_thread_id, h_thread_id, w_thread_id}; return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
} }
template <typename SrcDesc, template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
typename DstDesc, __device__ void Run(const ABlockBuffer& a_block_buf,
index_t NSliceRow, const BThreadBuffer& b_thread_buf,
index_t NSliceCol, CThreadBuffer& c_thread_buf) const
index_t DataPerAccess>
struct ThreadwiseSliceCopy_a
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
}
};
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
{ {
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value,
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -132,8 +131,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -132,8 +131,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto EPerBlock = a_block_mtx.GetLength(I0); constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
constexpr auto KPerThreadSubC = 4; // HACK: fix this @Jing Zhang
constexpr auto HoPerThreadSubC = 2; constexpr auto HoPerThreadSubC = 2;
constexpr auto WoPerThreadSubC = 2; constexpr auto WoPerThreadSubC = 2;
...@@ -141,63 +139,53 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -141,63 +139,53 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
static_assert(HPerThread % HoPerThreadSubC == 0, ""); static_assert(HPerThread % HoPerThreadSubC == 0, "");
static_assert(WPerThread % WoPerThreadSubC == 0, ""); static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A, B for GEMM // thread A buffer for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA, constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
decltype(a_thread_mtx), FloatB,
EPerThreadLoop, FloatC,
KPerThreadSubC, decltype(a_thread_mtx_),
ThreadGemmADataPerRead_K>{}; decltype(b_thread_mtx_),
decltype(c_thread_mtx_),
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx),
HoPerThreadSubC, HoPerThreadSubC,
WoPerThreadSubC>{}; WoPerThreadSubC>{};
// loop over k
#pragma unroll
for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
{
#pragma unroll
for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
{
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) +
mMyThreadOffsetA,
p_a_thread);
#pragma unroll static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC) static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
{
#pragma unroll a_thread_copy_.Run(a_block_mtx,
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC) make_tuple(e_begin, k_begin),
{ a_block_buf,
threadwise_gemm.Run(p_a_thread, a_thread_mtx_,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple( make_tuple(I0, I0),
e_begin, 0, h_begin, w_begin)), a_thread_buf);
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(
k_begin, 0, h_begin, w_begin))); static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
} static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
} threadwise_gemm.Run(a_thread_buf,
} make_tuple(I0, I0),
} b_thread_buf,
make_tuple(e_begin, I0, h_begin, w_begin),
c_thread_buf,
make_tuple(k_begin, I0, h_begin, w_begin));
});
});
});
});
} }
template <typename FloatA, typename FloatB, typename FloatC> template <typename ABlockSliceMoveStepIdx>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const __device__ void MoveASliceWindow(const BlockMatrixA&,
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{ {
Run_naive(p_a_block, p_b_thread, p_c_thread); a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
} }
private:
MatrixIndex c_thread_begin_mtx_idx_;
AThreadCopy a_thread_copy_;
}; };
} // namespace ck } // namespace ck
......
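For reference, the blockwise and threadwise GEMMs above all compute the same contraction, C[M, N] += transpose(A[K, M]) * B[K, N], with A and B stored K-major. A plain host-side sketch of that semantics (illustrative only, not the device code):

// Reference loop nest for C[M, N] += transpose(A[K, M]) * B[K, N], with all
// three matrices stored row-major in flat arrays.
void gemm_km_kn_mn_reference(const float* a, // K x M
                             const float* b, // K x N
                             float* c,       // M x N
                             int K, int M, int N)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}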
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
#include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v2.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck { namespace ck {
...@@ -256,7 +257,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -256,7 +257,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{})); make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize, BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc), decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc), decltype(b_k_n_block_desc),
decltype(c_m0m1_n0n1_thread_desc), decltype(c_m0m1_n0n1_thread_desc),
...@@ -281,10 +285,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -281,10 +285,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output // register allocation for output
FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()]; auto c_thread_buf =
make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
// zero out threadwise output ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread); decltype(c_m0m1_n0n1_thread_desc),
Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
.Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...@@ -300,6 +307,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -300,6 +307,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto b_k_n_global_move_slice_window_iterator_hack = constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{}; BGlobalMoveSliceWindowIteratorHacks{};
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
...@@ -311,12 +330,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -311,12 +330,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0; index_t k_block_data_begin = 0;
// LDS double buffer: main body // LDS double buffer: main body
...@@ -340,7 +353,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -340,7 +353,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
...@@ -363,7 +376,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -363,7 +376,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
...@@ -390,7 +403,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -390,7 +403,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
...@@ -399,16 +412,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -399,16 +412,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size, blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
p_b_block_double + b_block_space_size,
p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
{ {
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
} }
// output: register to global memory // output: register to global memory
...@@ -461,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -461,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
n_thread_data_on_global % N1)) n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc, .Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_c_thread, c_thread_buf,
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks); c_m0_m1_n0_n1_global_tensor_iterator_hacks);
......
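The main loop above follows the usual LDS double-buffering pattern: preload tile 0, then in each step prefetch the next tile into the buffer that is not being consumed while the GEMM runs on the other one, and finish with a tail GEMM on the last preloaded tile. A simplified control-flow sketch with illustrative callbacks (the real kernel processes two tiles per do-while iteration and synchronizes the block between the LDS write and the GEMM):

// Double-buffering control flow: overlap the load of tile i + 1 with the
// GEMM on tile i by alternating between two buffers.
template <typename LoadTile, typename GemmOnTile>
void double_buffered_loop(int num_tiles, LoadTile load, GemmOnTile gemm)
{
    load(/*tile=*/0, /*buffer=*/0); // preload

    for(int i = 0; i + 1 < num_tiles; ++i)
    {
        load(i + 1, (i + 1) % 2); // prefetch the next tile into the other buffer
        gemm(i % 2);              // consume the tile loaded in the previous step
    }

    gemm((num_tiles - 1) % 2); // tail: consume the last preloaded tile
}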
...@@ -145,8 +145,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -145,8 +145,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
const auto blockwise_gemm = auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize, FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc), decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc), decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
...@@ -223,11 +225,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -223,11 +225,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
FloatAB* p_a_block = p_shared_block; FloatAB* p_a_block = p_shared_block;
auto a_block_buf = make_dynamic_buffer(p_a_block);
// register allocation for output // register allocation for output
FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()]; StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
// zero out threadwise output // initialize output thread tensor
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread); ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
...@@ -242,12 +249,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -242,12 +249,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{}; BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize(); // double register buffer for b
FloatAB p_b_thread[b_thread_space_size * 2]; StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
b_thread_odd_buf;
FloatAB* p_b_thread_double = p_b_thread; // LDS double buffer: preload data
// LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks); a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
...@@ -255,7 +261,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -255,7 +261,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global, p_b_global,
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_b_thread_double, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block); a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
...@@ -263,13 +269,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -263,13 +269,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
__syncthreads(); __syncthreads();
index_t b_block_data_begin = 0;
#if 1
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
FloatAB* p_b_thread_even = p_b_thread_double; index_t e_block_data_begin = 0;
FloatAB* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
// LDS double buffer: main body // LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow // use Do-While loop instead of For loop to simplify control flow
...@@ -283,16 +285,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -283,16 +285,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global, p_b_global,
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_b_thread_odd, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
p_b_thread_even,
p_c_thread);
b_block_data_begin += EPerBlock; blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step); b_thread_slice_copy_step);
...@@ -301,18 +301,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -301,18 +301,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global, p_b_global,
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_b_thread_even, b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_odd, blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
p_c_thread);
b_block_data_begin += EPerBlock; e_block_data_begin += 2 * EPerBlock;
} while(b_block_data_begin < E - 2 * EPerBlock); } while(e_block_data_begin < E - 2 * EPerBlock);
} }
// LDS double buffer: tail // LDS double buffer: tail
...@@ -325,34 +324,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -325,34 +324,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global, p_b_global,
b_e_n_ho_wo_thread_desc, b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_b_thread_double + b_thread_space_size, b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks); b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_double,
p_c_thread);
b_block_data_begin += EPerBlock; blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run( blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_double + b_thread_space_size,
p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
{ {
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run( blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
p_b_thread_double,
p_c_thread);
} }
#endif
#if 1
// output: register to global memory // output: register to global memory
{ {
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
...@@ -380,12 +368,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -380,12 +368,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc, .Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_c_thread, c_thread_buf,
c_k_n_ho_wo_global_desc, c_k_n_ho_wo_global_desc,
p_c_global, p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
} }
#endif
} }
// pass tensor descriptor by reference // pass tensor descriptor by reference
......
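MoveASliceWindow and MoveSrcSliceWindow above express the same idea: the copy object keeps a running origin into its source tensor, and each call advances that origin by a step so the next Run reads the next tile without the caller recomputing offsets. A minimal 1-D sketch with illustrative names (not the ck implementation):

// 1-D sketch of a "slice window" copier: Run copies the window at the current
// origin, MoveSrcSliceWindow shifts the origin for the next Run.
struct SliceWindowCopySketch
{
    int src_origin_ = 0; // flattened origin into the source buffer

    void MoveSrcSliceWindow(int step) { src_origin_ += step; }

    void Run(const float* src, float* dst, int window_len) const
    {
        for(int i = 0; i < window_len; ++i)
            dst[i] = src[src_origin_ + i];
    }
};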
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// Assume:
// 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time
// 4. use #-iterator
template <typename Data,
typename Desc,
typename SliceLengths,
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
template <typename OriginIdx, typename Buffer>
__device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
{
static_assert(Desc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
"wrong! OriginIdx need to be known at compile-time");
// Desc is known at compile-time
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};
// OriginIdx is known at compile-time
constexpr auto origin_idx = to_multi_index(OriginIdx{});
static_ford<SliceLengths>{}([&](auto access_idx) {
constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);
constexpr bool is_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
constexpr index_t offset = coord.GetOffset();
if constexpr(is_valid)
{
buf(Number<offset>{}) = initial_value;
}
});
}
};
} // namespace ck
#endif
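ThreadwiseDynamicTensorSliceSet_v1 walks a compile-time slice of a thread-private StaticBuffer through its descriptor and writes one value at every valid offset; the gridwise kernels above use it to zero their accumulators. A simplified self-contained analogue of that behaviour (2-D, packed layout, no validity check):

#include <array>

// Simplified analogue: fill an M x N packed thread buffer with one value. The
// real class handles arbitrary rank and descriptors, and skips offsets that
// fall outside the valid region.
template <int M, int N, typename T>
void threadwise_slice_set(std::array<T, M * N>& buf, const T& value)
{
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N; ++j)
            buf[i * N + j] = value;
}

int main()
{
    std::array<float, 4 * 8> c_thread{};
    threadwise_slice_set<4, 8>(c_thread, 0.0f); // zero out the accumulator
    return 0;
}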
...@@ -6,100 +6,52 @@ ...@@ -6,100 +6,52 @@
namespace ck { namespace ck {
template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
{
static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto desc = Desc{};
constexpr auto M = desc.GetLength(I0);
constexpr auto N = desc.GetLength(I1);
static_for<0, M, 1>{}([&](auto i) {
static_for<0, N, 1>{}([&](auto j) {
constexpr auto offset = desc.CalculateOffset(make_tuple(i, j));
p_thread[offset] = Float(0);
});
});
}
template <typename SrcDesc,
typename DstDesc,
index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy_v2
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
}
};
// C[M, N] += transpose(A[K, M]) * B[K, N] // C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data // Element of matrix can be vectorized data
template <typename ADesc, // Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc, typename BDesc,
typename CDesc, typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1 struct ThreadwiseGemm_km_kn_mn_v1r1
{ {
template <typename FloatA, typename FloatB, typename FloatC> template <typename ABuffer,
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{ {
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{}; static_assert(
constexpr auto I1 = Number<1>{}; is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
constexpr auto M = CDesc{}.GetLength(I0); is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
constexpr auto N = CDesc{}.GetLength(I1); "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");
constexpr auto K = ADesc{}.GetLength(I0);
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
});
});
});
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
template <typename FloatA, typename FloatB, typename FloatC> remove_cv_t<remove_reference_t<FloatA>>>::value &&
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
{ remove_cv_t<remove_reference_t<FloatB>>>::value &&
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
CDesc::IsKnownAtCompileTime(), remove_cv_t<remove_reference_t<FloatC>>>::value,
"wrong! Desc should be known at compile-time"); "wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -110,61 +62,81 @@ struct ThreadwiseGemm_km_kn_mn_v1 ...@@ -110,61 +62,81 @@ struct ThreadwiseGemm_km_kn_mn_v1
constexpr auto N = CDesc{}.GetLength(I1); constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0); constexpr auto K = ADesc{}.GetLength(I0);
static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet"); constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) { static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) { static_for<0, M, 1>{}([&](auto m) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m)); constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
#if 0
if constexpr(N == 2) if constexpr(N == 2)
{ {
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0)); constexpr index_t b_offset_0 =
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr index_t b_offset_1 =
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
constexpr index_t c_offset_0 =
amd_assembly_outer_product_1x2(p_a[a_offset], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
p_b[b_offset_0], constexpr index_t c_offset_1 =
p_b[b_offset_1], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
p_c[c_offset_0],
p_c[c_offset_1]); amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}));
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0)); constexpr index_t b_offset_0 =
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2)); constexpr index_t b_offset_1 =
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr index_t b_offset_2 =
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1)); constexpr index_t b_offset_3 =
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3));
constexpr index_t c_offset_0 =
amd_assembly_outer_product_1x4(p_a[a_offset], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
p_b[b_offset_0], constexpr index_t c_offset_1 =
p_b[b_offset_1], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
p_b[b_offset_2], constexpr index_t c_offset_2 =
p_b[b_offset_3], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
p_c[c_offset_0], constexpr index_t c_offset_3 =
p_c[c_offset_1], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
p_c[c_offset_2],
p_c[c_offset_3]); amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
} b_buf[Number<b_offset_0>{}],
}); b_buf[Number<b_offset_1>{}],
}); b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
} }
else
#endif #endif
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{ {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM static_for<0, N, 1>{}([&](auto n) {
Run_amd_asm(p_a, p_b, p_c);
#else constexpr index_t b_offset =
Run_source(p_a, p_b, p_c); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
#endif constexpr index_t c_offset =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
}
});
});
} }
}; };
......
...@@ -6,35 +6,15 @@ ...@@ -6,35 +6,15 @@
namespace ck { namespace ck {
template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread)
{
static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto desc = Desc{};
constexpr auto K = desc.GetLength(I0);
constexpr auto H = desc.GetLength(I2);
constexpr auto W = desc.GetLength(I3);
static_for<0, K, 1>{}([&](auto i) {
static_for<0, H, 1>{}([&](auto j) {
static_for<0, W, 1>{}([&](auto k) {
constexpr auto offset = desc.CalculateOffset(make_tuple(i, 0, j, k));
p_thread[offset] = Float(0);
});
});
});
}
// C[M, N] += transpose(A[K, M]) * B[K, N] // C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data // Element of matrix can be vectorized data
template <typename ADesc, // Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc, typename BDesc,
typename CDesc, typename CDesc,
index_t H, index_t H,
...@@ -44,13 +24,37 @@ template <typename ADesc, ...@@ -44,13 +24,37 @@ template <typename ADesc,
bool>::type = false> bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v3 struct ThreadwiseGemm_km_kn_mn_v3
{ {
template <typename FloatA, typename FloatB, typename FloatC> template <typename ABuffer,
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{ {
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value,
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -59,79 +63,100 @@ struct ThreadwiseGemm_km_kn_mn_v3 ...@@ -59,79 +63,100 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto E = ADesc{}.GetLength(I0); constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1); constexpr auto K = ADesc{}.GetLength(I1);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, E, 1>{}([&](auto e) { static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) { static_for<0, K, 1>{}([&](auto k) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k)); constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));
if constexpr(H == 2 && W == 2) if constexpr(H == 2 && W == 2)
{ {
constexpr index_t b_offset_0 =
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1)); constexpr index_t b_offset_1 =
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1)); constexpr index_t b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0)); constexpr index_t b_offset_3 =
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1)); constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
amd_assembly_outer_product_1x4(p_a[a_offset], constexpr index_t c_offset_1 =
p_b[b_offset_0], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
p_b[b_offset_1], constexpr index_t c_offset_2 =
p_b[b_offset_2], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
p_b[b_offset_3], constexpr index_t c_offset_3 =
p_c[c_offset_0], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));
p_c[c_offset_1],
p_c[c_offset_2], amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
p_c[c_offset_3]); b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
} }
else if constexpr(H == 4 && W == 1) else if constexpr(H == 4 && W == 1)
{ {
constexpr index_t b_offset_0 =
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0)); constexpr index_t b_offset_1 =
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0)); constexpr index_t b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0)); constexpr index_t b_offset_3 =
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0)); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0)); constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
amd_assembly_outer_product_1x4(p_a[a_offset], constexpr index_t c_offset_1 =
p_b[b_offset_0], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
p_b[b_offset_1], constexpr index_t c_offset_2 =
p_b[b_offset_2], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
p_b[b_offset_3], constexpr index_t c_offset_3 =
p_c[c_offset_0], CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));
p_c[c_offset_1],
p_c[c_offset_2], amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
p_c[c_offset_3]); b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
} }
else else
{ {
static_for<0, H, 1>{}([&](auto h) { static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) { static_for<0, W, 1>{}([&](auto w) {
constexpr auto b_offset =
BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset =
CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
p_c[c_offset] += inner_product_with_conversion<FloatC>{}(p_a[a_offset], constexpr index_t b_offset =
p_b[b_offset]); BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w));
constexpr index_t c_offset =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
#if 0
c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
#else
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
#endif
}); });
}); });
} }
}); });
}); });
} }
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
Run_source(p_a, p_b, p_c);
}
}; };
} // namespace ck } // namespace ck
......
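Both fast paths above (H == 2 && W == 2 and H == 4 && W == 1) feed one A element and four B elements into amd_assembly_outer_product_1x4, i.e. a rank-1 update of a 1x4 strip of C. Its scalar semantics as a host-side sketch:

// Scalar semantics of the 1x4 outer-product update: c[n] += a * b[n] for n = 0..3.
void outer_product_1x4_reference(float a, const float b[4], float c[4])
{
    for(int n = 0; n < 4; ++n)
        c[n] += a * b[n];
}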
...@@ -5,6 +5,75 @@ ...@@ -5,6 +5,75 @@
namespace ck { namespace ck {
// c += inner_product(a, b)
__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#endif
}
__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}
__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
c);
amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
c);
}
__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
c);
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
c);
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
c);
amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
c);
}
// c0 += inner_product(a, b0) // c0 += inner_product(a, b0)
// c1 += inner_product(a, b1) // c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
......
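The int8x4_t overload of amd_assembly_inner_product above relies on v_dot4_i32_i8 (or __builtin_amdgcn_sdot4), which accumulates a 4-wide signed int8 dot product into a 32-bit integer. Its scalar semantics, sketched on the host:

#include <cstdint>

// Scalar semantics of the 4-wide int8 dot product with i32 accumulation:
// c += a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]
void dot4_i32_i8_reference(const int8_t a[4], const int8_t b[4], int32_t& c)
{
    for(int i = 0; i < 4; ++i)
        c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
}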
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<T, N>{};
}
template <typename T>
struct DynamicBuffer
{
using type = T;
T* p_data_;
__host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr const auto Get(index_t i) const
{
return *reinterpret_cast<const X*>(&p_data_[i]);
}
template <typename X,
typename std::enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, const X& x)
{
*reinterpret_cast<X*>(&p_data_[i]) = x;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
return DynamicBuffer<T>{p};
}
} // namespace ck
#endif
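A short usage sketch of the two buffer types above (assumes the headers in this commit are included; the function and variable names are illustrative): a StaticBuffer lives in registers and is indexed with compile-time Number<> constants, while a DynamicBuffer wraps a raw pointer, e.g. into LDS or global memory, and is indexed at run time.

__device__ void buffer_usage_sketch(float* p_lds)
{
    StaticBuffer<float, 8> reg_buf;            // register-resident, statically indexed
    reg_buf(Number<0>{}) = 1.0f;               // write with a compile-time index

    auto lds_buf = make_dynamic_buffer(p_lds); // wrap the raw LDS pointer
    lds_buf(0) = reg_buf[Number<0>{}];         // write with a run-time index
}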
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "statically_indexed_array.hpp" #include "statically_indexed_array.hpp"
#include "container_element_picker.hpp" #include "container_element_picker.hpp"
#include "float_type.hpp" #include "float_type.hpp"
#include "buffer.hpp"
#include "functional.hpp" #include "functional.hpp"
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#define CK_DEVICE_BACKEND_AMD 1 #define CK_DEVICE_BACKEND_AMD 1
// GPU ID // GPU ID
#if 1 #if 0
#define CK_AMD_GPU_GFX906 1 #define CK_AMD_GPU_GFX906 1
#elif 0 #elif 0
#define CK_AMD_GPU_GFX908 1 #define CK_AMD_GPU_GFX908 1
#elif 0 #elif 1
#define CK_AMD_GPU_GFX1030 1 #define CK_AMD_GPU_GFX1030 1
#endif #endif
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#endif #endif
// launch bounds // launch bounds
#define CK_USE_LAUNCH_BOUNDS 0 #define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS #ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256 #define CK_MAX_THREAD_PER_BLOCK 256
......
#ifndef CK_FLOAT_TYPE_AMD_HPP #ifndef CK_FLOAT_TYPE_AMD_HPP
#define CK_FLOAT_TYPE_AMD_HPP #define CK_FLOAT_TYPE_AMD_HPP
#include "statically_indexed_array.hpp"
namespace ck { namespace ck {
using half_t = _Float16; using half_t = _Float16;
...@@ -43,6 +45,15 @@ struct vector_type_maker<vector_type<T, N1>, N0> ...@@ -43,6 +45,15 @@ struct vector_type_maker<vector_type<T, N1>, N0>
using type = vector_type<T, N0 * N1>; using type = vector_type<T, N0 * N1>;
}; };
template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
template <typename T, index_t N>
__host__ __device__ constexpr auto make_vector_type(Number<N>)
{
return typename vector_type_maker<T, N>::type{};
}
// scalar_type // scalar_type
template <typename TV> template <typename TV>
struct scalar_type; struct scalar_type;
...@@ -403,32 +414,249 @@ struct vector_type<T, 16> ...@@ -403,32 +414,249 @@ struct vector_type<T, 16>
} }
}; };
template <typename T>
struct vector_type<T, 32>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
using type = d32_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
};
template <typename T>
struct vector_type<T, 64>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
};
// fp32 // fp32
using float2_t = typename vector_type<float, 2>::type; using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type; using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type; using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16 // fp16
using half2_t = typename vector_type<half_t, 2>::type; using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type; using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type; using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type; using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16 // bfp16
using ushort2_t = typename vector_type<ushort, 2>::type; using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type; using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type; using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;
// i32 // i32
using int32x2_t = typename vector_type<int32_t, 2>::type; using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type; using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type; using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8 // i8
using int8x2_t = typename vector_type<int8_t, 2>::type; using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type; using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type; using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type; using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// data type conversion // data type conversion
template <typename T> template <typename T>
......
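The AsType accessors added above let one register-resident vector be reinterpreted as a StaticallyIndexedArray of narrower subvectors; that is how the int8x8_t and int8x16_t inner products earlier in this commit are decomposed into dot4 operations. A usage sketch mirroring that pattern (assumes the headers in this commit; vector_type<int8_t, 16> is assumed to provide the same constructor as the 32- and 64-wide specializations shown above):

// View a 16-byte int8 vector as four int8x4_t chunks and feed each chunk to
// the 4-wide dot product, as the int8x16_t overload above does.
__device__ void int8x16_dot_sketch(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
    vector_type<int8_t, 16> a_vec{a};
    vector_type<int8_t, 16> b_vec{b};

    static_for<0, 4, 1>{}([&](auto i) {
        amd_assembly_inner_product(a_vec.AsType<int8x4_t>()[i],
                                   b_vec.AsType<int8x4_t>()[i],
                                   c);
    });
}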
...@@ -5,11 +5,26 @@ ...@@ -5,11 +5,26 @@
namespace ck { namespace ck {
template <index_t... Is>
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
{
return Sequence<Is...>{};
}
// F returns index_t
template <typename F, index_t N> template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence(F, Number<N>) __host__ __device__ constexpr auto generate_sequence(F, Number<N>)
{ {
return typename sequence_gen<N, F>::type{}; return typename sequence_gen<N, F>::type{};
} }
// F returns Number<>
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
{
return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
} // namespace ck } // namespace ck
#endif #endif
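generate_sequence_v2 above builds a Sequence by applying a compile-time function, which must return a Number<>, to each index 0 .. N-1. A usage sketch (assumption: unpack hands each element of the generated arithmetic sequence to f as a Number<>, as the implementation suggests; the doubling lambda is only an example):

// Would produce Sequence<0, 2, 4, 6>: f maps Number<i> to Number<2 * i>.
constexpr auto even_seq =
    generate_sequence_v2([](auto i) { return Number<2 * decltype(i)::value>{}; }, Number<4>{});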
...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr auto C0 = C / Number<InWeiVectorSize>{}; constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{}; constexpr auto C1 = Number<InWeiVectorSize>{};
#if 1 #if 0
// run-time variables // run-time variables
constexpr auto in_n_hi_wi_c0_desc = constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 0 #if 1
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -64,7 +64,7 @@ int main(int argc, char* argv[]) ...@@ -64,7 +64,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -630,7 +630,7 @@ int main(int argc, char* argv[]) ...@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
...@@ -724,13 +724,12 @@ int main(int argc, char* argv[]) ...@@ -724,13 +724,12 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
out_data_t> out_data_t>(
in_nchw_desc,
(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
......