Commit 474733b5 authored by Chao Liu's avatar Chao Liu
Browse files

updating v5r1

parent 415a4a5b
......@@ -518,6 +518,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......
......@@ -228,8 +228,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
index_t w;
};
MatrixIndex c_thread_begin_mtx_idx_;
// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
......@@ -255,6 +253,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
AddressSpace::Vgpr,
1>;
MatrixIndex c_thread_begin_mtx_idx_;
AThreadCopy a_thread_copy_;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
......@@ -313,9 +313,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
}
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_block_buf = make_dynamic_buffer(p_a_block);
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -334,12 +343,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
static_assert(HPerThread % HoPerThreadSubC == 0, "");
static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A, B for GEMM
FloatA p_a_thread[a_thread_mtx_.GetElementSpaceSize()];
auto a_thread_buf = make_dynamic_buffer(p_a_thread);
auto b_thread_buf = make_dynamic_buffer(p_b_thread);
auto c_thread_buf = make_dynamic_buffer(p_c_thread);
// thread A buffer for GEMM
StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
FloatB,
......
......@@ -685,7 +685,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
#else
#elif 0
// register allocation for output
FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
......@@ -695,7 +695,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
#elif 1
// register allocation for output
StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
#endif
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
......@@ -735,7 +742,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf.p_data_,
b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
......@@ -759,14 +766,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf.p_data_,
b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
b_thread_even_buf.p_data_,
p_c_thread);
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_even_buf,
c_thread_buf);
b_block_data_begin += EPerBlock;
......@@ -777,14 +785,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf.p_data_,
b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
b_thread_odd_buf.p_data_,
p_c_thread);
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_odd_buf,
c_thread_buf);
b_block_data_begin += EPerBlock;
......@@ -801,30 +810,33 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf.p_data_,
b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
b_thread_even_buf.p_data_,
p_c_thread);
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_even_buf,
c_thread_buf);
b_block_data_begin += EPerBlock;
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
b_thread_odd_buf.p_data_,
p_c_thread);
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_odd_buf,
c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
b_thread_even_buf.p_data_,
p_c_thread);
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_even_buf,
c_thread_buf);
}
// output: register to global memory
......
......@@ -422,12 +422,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
}
template <typename DstSliceOriginIdx, typename SrcIteratorHacks>
template <typename DstBuffer, typename DstSliceOriginIdx, typename SrcIteratorHacks>
__device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src,
const DstDesc&,
const DstSliceOriginIdx&,
DstData* p_dst,
DstBuffer& dst_buf,
const SrcIteratorHacks& src_iterator_hacks)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
......@@ -437,6 +437,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value,
"wrong! DstSliceOrigin need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
remove_cv_t<remove_reference_t<DstData>>>::value &&
"wrong! inconsistent type");
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
......@@ -564,7 +568,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_vector.template AsType<SrcData>()[i];
dst_buf(Number<dst_offset>{}) = src_vector.template AsType<SrcData>()[i];
});
constexpr auto move_on_dim = [&]() constexpr
......@@ -613,12 +617,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
}
}
template <typename DstSliceOriginIdx>
template <typename DstBuffer, typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src,
const DstDesc&,
const DstSliceOriginIdx&,
DstData* p_dst)
DstBuffer& dst_buf)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
......@@ -628,7 +632,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, p_dst, src_iterator_hacks);
Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
......
......@@ -209,6 +209,14 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
......
......@@ -180,6 +180,14 @@ struct ThreadwiseGemm_km_kn_mn_v3
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment