"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9da92cdb0f657d23ccf5594387398e4033a76f58"
Commit 29a118c6 authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into merge_use_division_mod

parents 1a43a538 19613902
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM  = K;
    const auto GemmN  = C * Y * X;
    const auto GemmK  = N * Ho * Wo;
    const auto GemmK0 = GemmK / GemmK1;

    // weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_grid_desc =
        transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
                                    make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                               make_merge_transform(make_tuple(N, Ho, Wo))),
                                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                    make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmk0_gemmn_gemmk1_grid_desc =
        transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmN)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto out_gemmk0_gemmm_gemmk1_grid_desc =
        transform_tensor_descriptor(out_gemmk_gemmm_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmM)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                      in_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
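For orientation, here is a minimal host-side sketch (not part of this commit) of how the transform above might be invoked. The problem sizes, the use of make_tuple for strides/dilations/pads, the helper name, and GemmK1 = 4 are illustrative assumptions; the only hard requirement visible in the code is that GemmK = N * Ho * Wo be divisible by GemmK1.

    // Hypothetical usage sketch; assumes the repository headers above are available.
    #include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"

    inline auto make_backward_weight_gemm_descs_example()
    {
        using namespace ck;

        // Example sizes only: 3x3 filter, 28x28 image, pad 1, stride 1 -> Ho = Wo = 28.
        const index_t N = 128, C = 192, K = 256;
        const index_t Hi = 28, Wi = 28, Y = 3, X = 3, Ho = 28, Wo = 28;

        const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X));
        const auto in_n_c_hi_wi_desc  = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
        const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));

        // Returns (out [GemmK0, GemmM, GemmK1], in [GemmK0, GemmN, GemmK1], wei [GemmM, GemmN]).
        return transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
            wei_k_c_y_x_desc,
            in_n_c_hi_wi_desc,
            out_n_k_ho_wo_desc,
            make_tuple(1, 1), // conv strides
            make_tuple(1, 1), // conv dilations
            make_tuple(1, 1), // input left pads
            make_tuple(1, 1), // input right pads
            Number<4>{});     // GemmK1; N * Ho * Wo (= 128 * 28 * 28) is divisible by 4
    }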
@@ -189,8 +189,7 @@ struct TensorAdaptor
        bool is_known = true;

        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
-            is_known &=
-                remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
+            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
        });

        return is_known && is_known_at_compile_time<ElementSize>::value;

@@ -185,8 +185,7 @@ struct TensorDescriptor
        bool is_known = true;

        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
-            is_known &=
-                remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
+            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
        });

        return is_known && is_known_at_compile_time<ElementSize>::value &&

@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
template <typename TensorDesc>
using TensorCoordinate_t = decltype(make_tensor_coordinate(
-    TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
+    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));

template <typename TensorDesc>
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
-    TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
+    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));

} // namespace ck
#endif
@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                       const BThreadBuffer& b_thread_buf,
                       CThreadBuffer& c_thread_buf) const
    {
-        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
-                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatC>>>::value &&
-                      "wrong! inconsistent type");
+        static_assert(
+            is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
+                is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
+                is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
+                "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
                       CBuffer& c_buf,
                       COriginIdx)
    {
-        static_assert(
-            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
-            "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
+        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
+                      "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
-                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatC>>>::value &&
-                      "wrong! inconsistent type");
+        static_assert(
+            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
+                is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
+                is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
+                "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
                       CBuffer& c_buf,
                       COriginIdx)
    {
-        static_assert(
-            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
-            "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
+        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
+                      "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
-                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatC>>>::value &&
-                      "wrong! inconsistent type");
+        static_assert(
+            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
+                is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
+                is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
+                "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

-        static_assert(
-            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
-            "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
+        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
+                      "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
-                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<FloatC>>>::value &&
-                      "wrong! inconsistent type");
+        static_assert(
+            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
+                is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
+                is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
+                "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1
        static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

-        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
+        static_assert(is_known_at_compile_time<remove_cvref_t<OriginIdx>>::value,
                      "wrong! OriginIdx need to be known at compile-time");

        // Desc is known at compile-time
-        constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};
+        constexpr auto desc = remove_cvref_t<Desc>{};

        // OriginIdx is known at compile-time
        constexpr auto origin_idx = to_multi_index(OriginIdx{});
@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");

-        static_assert(
-            is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value,
-            "wrong! SrcSliceOrigin need to known at compile-time");
+        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
+                      "wrong! SrcSliceOrigin need to known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");

-        // static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
-        //                       remove_cv_t<remove_reference_t<SrcData>>>::value,
-        //"wrong! SrcBuffer data type is wrong");
-
        // SrcDesc and src_slice_origin_idx are known at compile-time
-        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
+        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});

        constexpr auto I0 = Number<0>{};

@@ -421,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc need to known at compile-time");

-        static_assert(
-            is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value,
-            "wrong! DstSliceOrigin need to known at compile-time");
+        static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
+                      "wrong! DstSliceOrigin need to known at compile-time");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
-                              remove_cv_t<remove_reference_t<DstData>>>::value &&
+        static_assert(
+            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
                      "wrong! inconsistent type");

        // DstDesc and dst_slice_origin_idx are known at compile-time
-        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
+        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
        constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};

        constexpr auto I0 = Number<0>{};

@@ -742,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                      "wrong!");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
-                              remove_cv_t<remove_reference_t<SrcData>>>::value,
+        static_assert(
+            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
                      "wrong! SrcBuffer and SrcData data type are inconsistent");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

@@ -899,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                      "wrong!");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
-                              remove_cv_t<remove_reference_t<DstData>>>::value,
+        static_assert(
+            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
                      "wrong! SrcBuffer or DstBuffer data type is wrong");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

@@ -1315,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
-                              remove_cv_t<remove_reference_t<SrcData>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<DstData>>>::value,
-                      "wrong! SrcBuffer or DstBuffer data type is wrong");
+        static_assert(
+            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
+                is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
+            "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

-        static_assert(
-            is_known_at_compile_time<
-                remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
-            "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
-            "at compile-time");
+        static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
+                      "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
+                      "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
-        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
-        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
+        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
+        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};

        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                      "wrong!");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
-                              remove_cv_t<remove_reference_t<SrcData>>>::value,
+        static_assert(
+            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
                      "wrong! SrcBuffer and SrcData data type are inconsistent");

        // tensor descriptor for src_vector
        constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};

@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                      "wrong!");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
-                              remove_cv_t<remove_reference_t<DstData>>>::value,
+        static_assert(
+            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
                      "wrong! SrcBuffer or DstBuffer data type is wrong");

        // tensor descriptor for dst_vector
        constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};

@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

-        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
-                              remove_cv_t<remove_reference_t<SrcData>>>::value &&
-                          is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
-                                  remove_cv_t<remove_reference_t<DstData>>>::value,
-                      "wrong! SrcBuffer or DstBuffer data type is wrong");
+        static_assert(
+            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
+                is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
+            "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

-        static_assert(
-            is_known_at_compile_time<
-                remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
-                is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
-            "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
-            "at compile-time");
+        static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
+                      "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
+                      "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
-        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
-        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
+        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
+        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};

        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
@@ -225,13 +225,49 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
                                                                  index_t src_wave_addr_offset)
{
    static_assert(
-        (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
+        (is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
+            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

-    if constexpr(is_same<T, float>::value)
+    if constexpr(is_same<T, double>::value)
+    {
+        // use fp32 load to mimic fp64 load
+        if constexpr(N == 1)
+        {
+            const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+            return as_type<double>(tmp);
+        }
+        else if constexpr(N == 2)
+        {
+            const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+            return as_type<double2_t>(tmp);
+        }
+        else if constexpr(N == 4)
+        {
+            const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4(
+                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
+
+            const float4_t f32_1 =
+                llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
+                                                   src_thread_addr_offset,
+                                                   src_wave_addr_offset + 4 * sizeof(float),
+                                                   0);
+
+            vector_type<double, 4> tmp;
+
+            tmp.AsType<double2_t>()(Number<0>{}) = as_type<double2_t>(f32_0);
+            tmp.AsType<double2_t>()(Number<1>{}) = as_type<double2_t>(f32_1);
+
+            return tmp.AsType<double4_t>()(Number<0>{});
+        }
+    }
+    else if constexpr(is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {

@@ -283,25 +319,11 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
        }
        else if constexpr(N == 8)
        {
-#if 0
-            vector_type<half_t, 8> tmp;
-
-            tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
-                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
-
-            tmp.AsType<half4_t>()(Number<1>{}) =
-                llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
-                                                   src_thread_addr_offset,
-                                                   src_wave_addr_offset + 4 * sizeof(half_t),
-                                                   0);
-
-            return tmp.AsType<half8_t>()(Number<0>{});
-#else
+            // use fp32 load to mimic fp16 load
            float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
                src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

            return as_type<half8_t>(tmp);
-#endif
        }
    }
    else if constexpr(is_same<T, int32_t>::value)
@@ -433,13 +455,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                      index_t dst_wave_addr_offset)
{
    static_assert(
-        (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
-            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
-            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
+        (is_same<T, double>::value && (N == 1 || N == 2)) ||
+            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
+            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
+            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

-    if constexpr(is_same<T, float>::value)
+    if constexpr(is_same<T, double>::value)
+    {
+        // use fp32 store to mimic fp64 store
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_store_fp32x2(as_type<float2_t>(src_thread_data),
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                0);
+        }
+        else if constexpr(N == 2)
+        {
+            llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                0);
+        }
+    }
+    else if constexpr(is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {

@@ -466,6 +509,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                0);
        }
    }
+    else if constexpr(is_same<T, half_t>::value)
+    {
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset,
+                                              0);
+        }
+        else if constexpr(N == 2)
+        {
+            llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                0);
+        }
+        else if constexpr(N == 4)
+        {
+            llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                0);
+        }
+        else if constexpr(N == 8)
+        {
+            vector_type<half_t, 8> tmp{src_thread_data};
+
+            llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                0);
+
+            llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset + 4 * sizeof(half_t),
+                                                0);
+        }
+    }
    else if constexpr(is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)

@@ -552,49 +638,6 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                                0);
        }
    }
-    else if constexpr(is_same<T, half_t>::value)
-    {
-        if constexpr(N == 1)
-        {
-            llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
-                                              dst_wave_buffer_resource,
-                                              dst_thread_addr_offset,
-                                              dst_wave_addr_offset,
-                                              0);
-        }
-        else if constexpr(N == 2)
-        {
-            llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
-                                                dst_wave_buffer_resource,
-                                                dst_thread_addr_offset,
-                                                dst_wave_addr_offset,
-                                                0);
-        }
-        else if constexpr(N == 4)
-        {
-            llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
-                                                dst_wave_buffer_resource,
-                                                dst_thread_addr_offset,
-                                                dst_wave_addr_offset,
-                                                0);
-        }
-        else if constexpr(N == 8)
-        {
-            vector_type<half_t, 8> tmp{src_thread_data};
-
-            llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
-                                                dst_wave_buffer_resource,
-                                                dst_thread_addr_offset,
-                                                dst_wave_addr_offset,
-                                                0);
-
-            llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
-                                                dst_wave_buffer_resource,
-                                                dst_thread_addr_offset,
-                                                dst_wave_addr_offset + 4 * sizeof(half_t),
-                                                0);
-        }
-    }
}

template <typename T, index_t N>
@@ -720,7 +763,7 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
}

// buffer_load requires:
-// 1) p_src_wave must be in global memory space
+// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>

@@ -754,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
}

// buffer_load requires:
-// 1) p_src_wave must be in global memory space
+// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>

@@ -782,7 +825,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
}

// buffer_store requires:
-// 1) p_dst_wave must be global memory
+// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>

@@ -816,7 +859,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}

// buffer_atomic_add requires:
-// 1) p_dst_wave must be global memory
+// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
@@ -48,7 +48,7 @@ struct Array<TData, 0>
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
-    using data_type = remove_cv_t<remove_reference_t<X>>;
+    using data_type = remove_cvref_t<X>;
    return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
}
@@ -73,6 +73,13 @@ struct scalar_type<vector_type<T, N>>
};

//
+template <>
+struct scalar_type<double>
+{
+    using type = double;
+    static constexpr index_t vector_size = 1;
+};
+
template <>
struct scalar_type<float>
{

@@ -864,6 +871,10 @@ struct vector_type<T, 256>
    }
};

+// fp64
+using double2_t = typename vector_type<double, 2>::type;
+using double4_t = typename vector_type<double, 4>::type;
+
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
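The double2_t/double4_t aliases follow the existing float2_t/float4_t pattern. As a rough sketch of the kind of type such an alias is expected to resolve to (an assumption about vector_type's implementation, not taken from this commit), a clang extension vector gives the same element-wise access:

    // Illustration only: a 2-wide native vector of double, comparable to what
    // vector_type<double, 2>::type is assumed to provide.
    typedef double double2_like_t __attribute__((ext_vector_type(2)));

    double sum2(double2_like_t v) { return v.x + v.y; } // element access as on float2_t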
@@ -39,18 +39,15 @@ struct DynamicBuffer
    }

    template <typename X,
-              typename enable_if<
-                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
-                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
-                  bool>::type = false>
+              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                 bool>::type = false>
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
    {
        // X contains multiple T
-        constexpr index_t scalar_per_t_vector =
-            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
+        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

-        constexpr index_t scalar_per_x_vector =
-            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
+        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

@@ -67,15 +64,14 @@ struct DynamicBuffer
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
-                return amd_buffer_load_invalid_element_return_return_zero<
-                    remove_cv_t<remove_reference_t<T>>,
-                    t_per_x>(p_data_, i, is_valid_element, element_space_size_);
+                return amd_buffer_load_invalid_element_return_return_zero<remove_cvref_t<T>,
+                                                                          t_per_x>(
+                    p_data_, i, is_valid_element, element_space_size_);
            }
            else
            {
-                return amd_buffer_load_invalid_element_return_customized_value<
-                    remove_cv_t<remove_reference_t<T>>,
-                    t_per_x>(
+                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
+                                                                               t_per_x>(
                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
            }
        }

@@ -94,18 +90,15 @@ struct DynamicBuffer
    }

    template <typename X,
-              typename enable_if<
-                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
-                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
-                  bool>::type = false>
+              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                 bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
-        constexpr index_t scalar_per_t_vector =
-            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
+        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

-        constexpr index_t scalar_per_x_vector =
-            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
+        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

@@ -115,7 +108,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-            amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
+            amd_buffer_store<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
#else
            if(is_valid_element)

@@ -136,70 +129,65 @@ struct DynamicBuffer
            // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
            // ds_write_b128
            // TODO: remove this after compiler fix
-            if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
-                                 int8_t>::value)
+            if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value)
            {
-                static_assert(
-                    (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
-                     is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
-                        (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
-                         is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
-                        (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
-                         is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
-                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
-                         is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
-                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
-                         is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
-                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
-                         is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
-                    "wrong! not implemented for this combination, please add "
-                    "implementation");
+                static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
+                               is_same<remove_cvref_t<X>, int8_t>::value) ||
+                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
+                                   is_same<remove_cvref_t<X>, int8x2_t>::value) ||
+                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
+                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
+                                  (is_same<remove_cvref_t<T>, int8x4_t>::value &&
+                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
+                                  (is_same<remove_cvref_t<T>, int8x8_t>::value &&
+                                   is_same<remove_cvref_t<X>, int8x8_t>::value) ||
+                                  (is_same<remove_cvref_t<T>, int8x16_t>::value &&
+                                   is_same<remove_cvref_t<X>, int8x16_t>::value),
+                              "wrong! not implemented for this combination, please add "
+                              "implementation");

-                if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
-                             is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
+                if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
+                             is_same<remove_cvref_t<X>, int8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
-                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
-                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
+                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
+                                  is_same<remove_cvref_t<X>, int8x2_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
-                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
-                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
+                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
+                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
-                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
-                                          int8x4_t>::value &&
-                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
+                else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
+                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
-                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
-                                          int8x8_t>::value &&
-                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
+                else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
+                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
-                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
-                                          int8x16_t>::value &&
-                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
+                else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
+                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix

@@ -224,18 +212,15 @@ struct DynamicBuffer
    }

    template <typename X,
-              typename enable_if<
-                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
-                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
-                  bool>::type = false>
+              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                 bool>::type = false>
    __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
-        constexpr index_t scalar_per_t_vector =
-            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
+        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

-        constexpr index_t scalar_per_x_vector =
-            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
+        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

@@ -245,7 +230,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-            amd_buffer_atomic_add<remove_cv_t<remove_reference_t<T>>, t_per_x>(
+            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
#else
            if(is_valid_element)

@@ -266,9 +251,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}

-template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
+template <
+    AddressSpaceEnum_t BufferAddressSpace,
+    typename T,
+    typename ElementSpaceSize,
+    typename X,
+    typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto
-make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
+make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
        p, element_space_size, invalid_element_value};
@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
-    return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
+    return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}

} // namespace ck
@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
        return container_reduce(
            Tuple<Ts...>{},
            [](auto x, bool r) {
-                return is_known_at_compile_time<
-                           remove_cv_t<remove_reference_t<decltype(x)>>>::value &
-                       r;
+                return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
            },
            true);
    }
@@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

+template <typename T>
+using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
+
template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
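The new remove_cvref_t alias simply composes the two existing traits. A quick compile-time check of the intended behaviour (illustrative only; it assumes the alias lives in namespace ck like its neighbours and that <type_traits> is available for std::is_same):

    #include <type_traits>

    static_assert(std::is_same<ck::remove_cvref_t<const int&>, int>::value,
                  "strips const and lvalue reference");
    static_assert(std::is_same<ck::remove_cvref_t<volatile float&&>, float>::value,
                  "strips volatile and rvalue reference");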
@@ -374,13 +374,8 @@ extern "C" __global__ void
                            CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
                            CGridBlockCluster_BlockId_To_GM10_GN10{}));

-    const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wold-style-cast"
-        // TODO: how to cast?
-        (const void*)p_desc_tuple
-#pragma clang diagnostic pop
-    );
+    const auto desc_tuple =
+        *reinterpret_cast<const DescTuple*>(cast_pointer_to_generic_address_space(p_desc_tuple));

    const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
    const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
@@ -13,9 +13,15 @@ include_directories(BEFORE
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
+set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
+set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)

add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
+add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
+add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})

target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
+target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
+target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
@@ -208,20 +208,20 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1

    // clang-format off
    constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r3.hpp" #include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei, template <typename TInWei,
...@@ -14,17 +14,17 @@ template <typename TInWei, ...@@ -14,17 +14,17 @@ template <typename TInWei,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_hi_wi_c_lengths, const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_y_x_c_lengths, const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_ho_wo_k_lengths, const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c, const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_y_x_c, Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_ho_wo_k, const Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -34,55 +34,21 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
...@@ -91,54 +57,26 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
...@@ -154,70 +92,73 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
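
    // Illustrative only: the comments above note that the A-side vector load of 4 assumes
    // the config's Ho * Wo to be a multiple of 4. A minimal host-side guard for that
    // assumption (a hypothetical sketch, assuming Number<3>{} indexing of the length tuple
    // as done elsewhere in this codebase and that <stdexcept> is available) could look like:
    //
    //     const auto Ho = out_n_k_ho_wo_lengths[I2];
    //     const auto Wo = out_n_k_ho_wo_lengths[Number<3>{}];
    //     if((Ho * Wo) % GemmABlockTransferSrcScalarPerVector_GemmK1 != 0)
    //         throw std::runtime_error("Ho * Wo must be a multiple of the A-side vector load width");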
    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads,
        Number<GemmK1>{});

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc = descs[I2];
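
    // For this backward-weight GEMM, A is the output tensor (GemmK0 x GemmM x GemmK1),
    // B is the input tensor (GemmK0 x GemmN x GemmK1), and C is the weight tensor
    // (GemmM x GemmN) -- the order in which the descriptors are passed to
    // driver_gemm_xdlops_v2r3 below.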
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmM
                              Sequence<0, 0, 1, 0, 0>{}),  // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 2, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmM
                              Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 1+: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 1-: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
    for(index_t i = 0; i < 5; ++i)
    {
...@@ -227,14 +168,15 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
            TAcc,
            TOut,
            InMemoryDataOperationEnum_t::Set,
            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(wei_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
...@@ -250,53 +192,37 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                   static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                   static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                   out_gemmk0_gemmm_gemmk1_grid_desc,
                   in_gemmk0_gemmn_gemmk1_grid_desc,
                   wei_gemmm_gemmn_grid_desc,
                   out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                   in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                   wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                   out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                   in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                   nrepeat);
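
        // calculate_convolution_flops is assumed here to count the usual
        // 2 * N * K * Ho * Wo * C * Y * X flops for this problem; dividing by 10^9 and
        // by the averaged kernel time in milliseconds then reports throughput in TFlop/s.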
        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
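
// Hypothetical call site (a minimal sketch; the element types, tensor shapes, and the way
// Tensor/length objects are constructed follow the conventions of host_tensor.hpp and are
// assumptions, not part of this file):
//
//     Tensor<TInWei> in_n_c_hi_wi(...);   // N x C x Hi x Wi activations
//     Tensor<TInWei> wei_k_c_y_x(...);    // K x C x Y x X, overwritten with the weight gradient
//     Tensor<TOut>   out_n_k_ho_wo(...);  // N x K x Ho x Wo output gradient
//
//     device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
//         in_n_c_hi_wi_lengths, wei_k_c_y_x_lengths, out_n_k_ho_wo_lengths,
//         conv_strides, conv_dilations, in_left_pads, in_right_pads,
//         in_n_c_hi_wi, wei_k_c_y_x, out_n_k_ho_wo, nrepeat);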