Commit a84f254a authored by Chao Liu's avatar Chao Liu
Browse files

added conv+bias+relu

parent 8159be33
...@@ -19,7 +19,6 @@ template <typename GridwiseGemm, ...@@ -19,7 +19,6 @@ template <typename GridwiseGemm,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2, typename C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
...@@ -34,12 +33,10 @@ __global__ void ...@@ -34,12 +33,10 @@ __global__ void
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_grid, const FloatC* __restrict__ p_c0_grid,
const FloatC* __restrict__ p_c1_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
...@@ -54,13 +51,11 @@ __global__ void ...@@ -54,13 +51,11 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_c0_grid, p_c0_grid,
p_c1_grid,
p_shared_block, p_shared_block,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
...@@ -76,7 +71,6 @@ template <index_t BlockSize, ...@@ -76,7 +71,6 @@ template <index_t BlockSize,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename C0GridDesc_M_N, typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
...@@ -326,9 +320,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 ...@@ -326,9 +320,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{}));
using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop>
...@@ -337,13 +328,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 ...@@ -337,13 +328,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_grid, const FloatC* __restrict__ p_c0_grid,
const FloatC* __restrict__ p_c1_grid,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
...@@ -359,9 +348,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 ...@@ -359,9 +348,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c1_grid, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// divide block work by [M, N] // divide block work by [M, N]
...@@ -615,7 +601,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 ...@@ -615,7 +601,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
CElementwiseOperation, CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>, Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
...@@ -626,7 +611,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 ...@@ -626,7 +611,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
true>{ true>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0], make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0], n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1], m_thread_data_on_grid_idx[I1],
...@@ -644,9 +628,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 ...@@ -644,9 +628,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
c_grid_buf, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c0_grid_buf, c0_grid_buf);
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_grid_buf);
} }
} }
}; // namespace ck }; // namespace ck
......
...@@ -31,7 +31,6 @@ template <typename SrcData, ...@@ -31,7 +31,6 @@ template <typename SrcData,
typename SrcDesc, typename SrcDesc,
typename DstDesc, typename DstDesc,
typename Dst0Desc, // this is really one of sources, but it has same shape as DstDesc typename Dst0Desc, // this is really one of sources, but it has same shape as DstDesc
typename Dst1Desc, // this is really one of sources, but it has same shape as DstDesc
typename DstElementwiseOperation, typename DstElementwiseOperation,
typename SliceLengths, typename SliceLengths,
typename DimAccessOrder, typename DimAccessOrder,
...@@ -49,21 +48,17 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -49,21 +48,17 @@ struct ThreadwiseTensorSliceTransfer_v1r5
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r5( __device__ constexpr ThreadwiseTensorSliceTransfer_v1r5(
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Dst0Desc& dst0_desc, const Dst0Desc& dst0_desc,
const Dst1Desc& dst1_desc,
const Index& dst_slice_origin_idx, const Index& dst_slice_origin_idx,
const DstElementwiseOperation& dst_element_op) const DstElementwiseOperation& dst_element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)), dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin_idx)),
dst_element_op_{dst_element_op} dst_element_op_{dst_element_op}
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
...@@ -79,10 +74,8 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -79,10 +74,8 @@ struct ThreadwiseTensorSliceTransfer_v1r5
typename SrcBuffer, typename SrcBuffer,
typename DstBuffer, typename DstBuffer,
typename Dst0Buffer, typename Dst0Buffer,
typename Dst1Buffer,
typename DstStepHacks, typename DstStepHacks,
typename Dst0StepHacks, typename Dst0StepHacks>
typename Dst1StepHacks>
__device__ void Run(const SrcDesc&, __device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&, const SrcSliceOriginIdx&,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
...@@ -91,10 +84,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -91,10 +84,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
const DstStepHacks& dst_step_hacks, const DstStepHacks& dst_step_hacks,
const Dst0Desc& dst0_desc, const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf, const Dst0Buffer& dst0_buf,
const Dst0StepHacks& dst0_step_hacks, const Dst0StepHacks& dst0_step_hacks)
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf,
const Dst1StepHacks& dst1_step_hacks)
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -156,22 +146,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -156,22 +146,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
}, },
Number<nDim>{}); Number<nDim>{});
// make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const auto dst1_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst1_desc, forward_step_idx, dst1_step_hacks[I0][i]);
},
Number<nDim>{});
// make backward steps: dst // make backward steps: dst
const auto dst_backward_steps = generate_tuple( const auto dst_backward_steps = generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -202,22 +176,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -202,22 +176,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
}, },
Number<nDim>{}); Number<nDim>{});
// make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const auto dst1_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst1_desc, backward_step_idx, dst1_step_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy // loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) { static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward // judge move forward or move backward
...@@ -258,7 +216,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -258,7 +216,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
using dst_vector_t = using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type; typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
// load dst0 and dst1 and apply elementwise operation // load dst0 and apply elementwise operation
{ {
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1 // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this // TODO: fix this
...@@ -270,29 +228,22 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -270,29 +228,22 @@ struct ThreadwiseTensorSliceTransfer_v1r5
const SrcData src_v = src_buf[Number<src_offset>{}]; const SrcData src_v = src_buf[Number<src_offset>{}];
// load dst0 and dst1 // load dst0
const bool is_dst0_valid = const bool is_dst0_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc, coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc,
dst0_coord_); dst0_coord_);
const bool is_dst1_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst1_desc,
dst1_coord_);
const DstData dst0_v = const DstData dst0_v =
dst0_buf.template Get<DstData>(dst0_coord_.GetOffset(), is_dst0_valid); dst0_buf.template Get<DstData>(dst0_coord_.GetOffset(), is_dst0_valid);
const DstData dst1_v =
dst1_buf.template Get<DstData>(dst1_coord_.GetOffset(), is_dst1_valid);
#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE #if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE
// apply element-wise operation in SrcData type // apply element-wise operation in SrcData type
const SrcData dst_v = dst_element_op_( const SrcData dst_v = dst_element_op_(src_v, type_convert<SrcData>(dst0_v));
src_v, type_convert<SrcData>(dst0_v), type_convert<SrcData>(dst1_v));
// apply type convert // apply type convert
dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v); dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v);
#else #else
// apply element-wise operation in DstData type // apply element-wise operation in DstData type
const DstData dst_v = dst_element_op_(src_v, dst0_v, dst1_v); const DstData dst_v = dst_element_op_(src_v, dst0_v);
dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v; dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v;
#endif #endif
...@@ -361,10 +312,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -361,10 +312,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
// dst0 // dst0
move_tensor_coordinate( move_tensor_coordinate(
dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]); dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]);
// dst1
move_tensor_coordinate(
dst1_desc, dst1_coord_, dst1_forward_steps[dim_access_order[i]]);
} }
else else
{ {
...@@ -374,10 +321,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -374,10 +321,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
// dst0 // dst0
move_tensor_coordinate( move_tensor_coordinate(
dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]); dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]);
// dst1
move_tensor_coordinate(
dst1_desc, dst1_coord_, dst1_backward_steps[dim_access_order[i]]);
} }
} }
}); });
...@@ -397,7 +340,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -397,7 +340,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
typename SrcBuffer, typename SrcBuffer,
typename DstBuffer, typename DstBuffer,
typename Dst0Buffer, typename Dst0Buffer,
typename Dst1Buffer,
typename DstStepHacks> typename DstStepHacks>
__device__ void Run(const SrcDesc&, __device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&, const SrcSliceOriginIdx&,
...@@ -406,9 +348,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -406,9 +348,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
DstBuffer& dst_buf, DstBuffer& dst_buf,
const DstStepHacks& dst_step_hacks, const DstStepHacks& dst_step_hacks,
const Dst0Desc& dst0_desc, const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf, const Dst0Buffer& dst0_buf)
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf)
{ {
auto f_step_hacks = [&](auto desc) { auto f_step_hacks = [&](auto desc) {
constexpr index_t ntransform = decltype(desc)::GetNumOfTransform(); constexpr index_t ntransform = decltype(desc)::GetNumOfTransform();
...@@ -430,10 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -430,10 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
dst_step_hacks, dst_step_hacks,
dst0_desc, dst0_desc,
dst0_buf, dst0_buf,
f_step_hacks(dst0_desc), f_step_hacks(dst0_desc));
dst1_desc,
dst1_buf,
f_step_hacks(dst1_desc));
} }
__device__ static constexpr auto GetDstCoordinateResetStep() __device__ static constexpr auto GetDstCoordinateResetStep()
...@@ -514,7 +451,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5 ...@@ -514,7 +451,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
private: private:
DstCoord dst_coord_; DstCoord dst_coord_;
Dst0Coord dst0_coord_; Dst0Coord dst0_coord_;
Dst1Coord dst1_coord_;
const DstElementwiseOperation dst_element_op_; const DstElementwiseOperation dst_element_op_;
}; // namespace ck }; // namespace ck
......
...@@ -145,7 +145,6 @@ ...@@ -145,7 +145,6 @@
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1 #define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1
#endif #endif
namespace ck { namespace ck {
enum InMemoryDataOperationEnum_t enum InMemoryDataOperationEnum_t
......
...@@ -15,7 +15,7 @@ template <ck::index_t... Is> ...@@ -15,7 +15,7 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple< using device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off // clang-format off
...@@ -23,24 +23,24 @@ using device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple< ...@@ -23,24 +23,24 @@ using device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple<
//################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
void add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances( void add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances(
std::vector<DeviceConvFwdBiasActivationPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationPtr<PassThrough, PassThrough, AddRelu>>&
instance_container) instance_container)
{ {
using Instances = device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances; using Instances = device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances;
......
...@@ -17,7 +17,7 @@ namespace tensor_operation { ...@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device { namespace device {
// out[N, Ho, Wo, K] = // out[N, Ho, Wo, K] =
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] // activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K])
template <typename InDataType, template <typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
...@@ -209,14 +209,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -209,14 +209,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const auto bias_grid_desc_gemmm_gemmn = const auto bias_grid_desc_gemmm_gemmn =
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
// C1: residual tensor: assume same layout as output tensor
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc, out_gemmm_gemmn_grid_desc,
bias_grid_desc_gemmm_gemmn, bias_grid_desc_gemmm_gemmn);
resi_grid_desc_gemmm_gemmn);
} }
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
...@@ -226,7 +222,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -226,7 +222,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using C0GridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I3])>; using C0GridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I3])>;
using C1GridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I4])>;
// TODO remove these hacks // TODO remove these hacks
static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple( static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(
...@@ -279,7 +274,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -279,7 +274,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
C0GridDesc_M_N, C0GridDesc_M_N,
C1GridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -325,9 +319,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -325,9 +319,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{}));
using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
// Argument // Argument
...@@ -337,7 +328,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -337,7 +328,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const WeiDataType* p_wei_grid, const WeiDataType* p_wei_grid,
OutDataType* p_out_grid, OutDataType* p_out_grid,
const OutDataType* p_bias_grid, const OutDataType* p_bias_grid,
const OutDataType* p_resi_grid,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -357,15 +347,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -357,15 +347,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
p_b_grid_{p_wei_grid}, p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid}, p_c_grid_{p_out_grid},
p_c0_grid_{p_bias_grid}, p_c0_grid_{p_bias_grid},
p_c1_grid_{p_resi_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
c0_grid_desc_m_n_{}, c0_grid_desc_m_n_{},
c1_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
...@@ -389,7 +376,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -389,7 +376,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
b_grid_desc_k0_n_k1_ = descs[I1]; b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2]; c_grid_desc_m_n_ = descs[I2];
c0_grid_desc_m_n_ = descs[I3]; c0_grid_desc_m_n_ = descs[I3];
c1_grid_desc_m_n_ = descs[I4];
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
...@@ -400,9 +386,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -400,9 +386,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c0_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c0_grid_desc_m_n_);
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c1_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -412,15 +395,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -412,15 +395,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
const CDataType* p_c0_grid_; const CDataType* p_c0_grid_;
const CDataType* p_c1_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_; C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -450,9 +430,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -450,9 +430,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
...@@ -483,7 +460,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -483,7 +460,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceOp::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<DeviceOp::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceOp::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -499,12 +475,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -499,12 +475,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_, arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.in_element_op_, arg.in_element_op_,
arg.wei_element_op_, arg.wei_element_op_,
arg.out_element_op_, arg.out_element_op_,
...@@ -520,7 +494,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -520,7 +494,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<DeviceOp::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceOp::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<DeviceOp::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceOp::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -536,12 +509,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -536,12 +509,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_, arg.p_c0_grid_,
arg.p_c1_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.in_element_op_, arg.in_element_op_,
arg.wei_element_op_, arg.wei_element_op_,
arg.out_element_op_, arg.out_element_op_,
...@@ -581,7 +552,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -581,7 +552,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const WeiDataType* p_wei_grid, const WeiDataType* p_wei_grid,
OutDataType* p_out_grid, OutDataType* p_out_grid,
const OutDataType* p_bias_grid, const OutDataType* p_bias_grid,
const OutDataType* p_resi_grid,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -600,7 +570,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -600,7 +570,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
p_wei_grid, p_wei_grid,
p_out_grid, p_out_grid,
p_bias_grid, p_bias_grid,
p_resi_grid,
N, N,
K, K,
C, C,
...@@ -625,7 +594,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -625,7 +594,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const void* p_wei_grid, const void* p_wei_grid,
void* p_out_grid, void* p_out_grid,
const void* p_bias_grid, const void* p_bias_grid,
const void* p_resi_grid,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -644,7 +612,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -644,7 +612,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static_cast<const WeiDataType*>(p_wei_grid), static_cast<const WeiDataType*>(p_wei_grid),
static_cast<OutDataType*>(p_out_grid), static_cast<OutDataType*>(p_out_grid),
static_cast<const OutDataType*>(p_bias_grid), static_cast<const OutDataType*>(p_bias_grid),
static_cast<const OutDataType*>(p_resi_grid),
N, N,
K, K,
C, C,
......
...@@ -18,7 +18,6 @@ struct DeviceConvFwdBiasActivation : public BaseOperator ...@@ -18,7 +18,6 @@ struct DeviceConvFwdBiasActivation : public BaseOperator
const void* p_wei, const void* p_wei,
void* p_out, void* p_out,
const void* p_bias, const void* p_bias,
const void* p_resi,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
......
...@@ -14,6 +14,34 @@ struct PassThrough ...@@ -14,6 +14,34 @@ struct PassThrough
} }
}; };
struct AddRelu
{
template <typename T1>
__host__ constexpr float operator()(float v0, T1 v1) const
{
float b = v0 + v1;
float c = b > 0 ? b : 0;
return c;
}
template <typename T1>
__device__ constexpr float operator()(float v0, T1 v1) const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
return b;
#else
float b = v1 + v0;
float c = b > 0 ? b : 0;
return c;
#endif
}
};
struct AddReluAdd struct AddReluAdd
{ {
template <typename T1, typename T2> template <typename T1, typename T2>
...@@ -44,6 +72,79 @@ struct AddReluAdd ...@@ -44,6 +72,79 @@ struct AddReluAdd
} }
}; };
struct AddLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
// this use not too many registers, but use fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this spill register
float a = v0 + v1;
float b = float(0.1) * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this use lots of registers (but no spill)
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = b > 0 ? b : 0;
float d = alpha * (a + c);
return d;
#elif 1
// this use lots of registers (but no spill), 89 Tflops
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * (a + c);
return d;
#elif 1
// this spill registers, 89 Tflops
float a = v0 + v1;
float alpha = 0.1;
float b;
asm volatile("\n \
v_mul_f32_e32 %0, %1, %2 \n \
"
: "=v"(b)
: "s"(alpha), "v"(a));
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#endif
}
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
......
...@@ -13,24 +13,7 @@ ...@@ -13,24 +13,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
struct PassThrough
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
struct Relu
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v > 0 ? v : 0;
}
};
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -44,9 +27,9 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; ...@@ -44,9 +27,9 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor; using CLayout = ck::tensor_layout::gemm::RowMajor;
using AOp = PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BOp = PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using COp = Relu; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for NT problem // Compilation parameters for NT problem
// clang-format off // clang-format off
...@@ -55,7 +38,7 @@ using DeviceGemmInstance = ...@@ -55,7 +38,7 @@ using DeviceGemmInstance =
//#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
// clang-format on // clang-format on
template <typename AType, template <typename AType,
...@@ -189,9 +172,9 @@ int main(int argc, char* argv[]) ...@@ -189,9 +172,9 @@ int main(int argc, char* argv[])
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
AOp{}, AElementOp{},
BOp{}, BElementOp{},
COp{}); CElementOp{});
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -217,7 +200,7 @@ int main(int argc, char* argv[]) ...@@ -217,7 +200,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
host_verify(a_m_k, b_k_n, c_m_n_host_result, AOp{}, BOp{}, COp{}); host_verify(a_m_k, b_k_n, c_m_n_host_result, AElementOp{}, BElementOp{}, CElementOp{});
check_error(c_m_n_host_result, c_m_n_device_result); check_error(c_m_n_host_result, c_m_n_device_result);
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "example/2_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp" #include "example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp"
// C[m, n] = Relu(A[m, k] * B[k, n] + C0[m]) + C1[m, n] // C[m, n] = Relu(A[m, k] * B[k, n] + C0[m]) + C1[m, n]
// assume C0 is contiguous in memory // assume C0 is contiguous in memory
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define DEVICE_GEMM_XDL_TWO_EXTRA_SOURCE_REDUCE_HPP #define DEVICE_GEMM_XDL_TWO_EXTRA_SOURCE_REDUCE_HPP
#include <iostream> #include <iostream>
#include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm.hpp" #include "device_gemm.hpp"
...@@ -560,6 +561,23 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator ...@@ -560,6 +561,23 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
} }
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdl_two_extra_source_reduce"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
}; };
} // namespace device } // namespace device
......
...@@ -12,25 +12,7 @@ ...@@ -12,25 +12,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
struct PassThrough
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
struct Relu
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
T tmp = 0.1 * v;
return tmp > 0 ? tmp : 0;
}
};
using InDataType = ck::half_t; using InDataType = ck::half_t;
using WeiDataType = ck::half_t; using WeiDataType = ck::half_t;
...@@ -44,9 +26,9 @@ using InLayout = ck::tensor_layout::convolution::NHWC; ...@@ -44,9 +26,9 @@ using InLayout = ck::tensor_layout::convolution::NHWC;
using WeiLayout = ck::tensor_layout::convolution::KYXC; using WeiLayout = ck::tensor_layout::convolution::KYXC;
using OutLayout = ck::tensor_layout::convolution::NHWK; using OutLayout = ck::tensor_layout::convolution::NHWK;
using InElementOp = PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = Relu; using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using DeviceConvFwdInstance = using DeviceConvFwdInstance =
ck::tensor_operation::device::DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ck::tensor_operation::device::DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
......
...@@ -28,7 +28,7 @@ using OutLayout = ck::tensor_layout::convolution::NHWK; ...@@ -28,7 +28,7 @@ using OutLayout = ck::tensor_layout::convolution::NHWK;
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
// clang-format off // clang-format off
using DeviceConvFwdInstance = ck::tensor_operation::device:: using DeviceConvFwdInstance = ck::tensor_operation::device::
...@@ -50,7 +50,6 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi, ...@@ -50,7 +50,6 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
const Tensor<TWei>& wei_k_c_y_x, const Tensor<TWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo, Tensor<TOut>& out_n_k_ho_wo,
const Tensor<TOut>& bias_k, const Tensor<TOut>& bias_k,
const Tensor<TOut>& resi_n_k_ho_wo,
const std::vector<ck::index_t>& conv_strides, const std::vector<ck::index_t>& conv_strides,
const std::vector<ck::index_t>& conv_dilations, const std::vector<ck::index_t>& conv_dilations,
const std::vector<ck::index_t>& in_left_pads, const std::vector<ck::index_t>& in_left_pads,
...@@ -79,7 +78,7 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi, ...@@ -79,7 +78,7 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
} }
} }
out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k), resi_n_k_ho_wo(n, k, ho, wo)); out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k));
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
...@@ -198,14 +197,10 @@ int main(int argc, char* argv[]) ...@@ -198,14 +197,10 @@ int main(int argc, char* argv[])
Tensor<OutDataType> bias_k( Tensor<OutDataType> bias_k(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(K)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(K)})));
// residual: assume same layout as output tensor
Tensor<OutDataType> resi_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "bias_k: " << bias_k.mDesc << std::endl; std::cout << "bias_k: " << bias_k.mDesc << std::endl;
std::cout << "resi_n_k_ho_wo: " << resi_n_k_ho_wo.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -214,13 +209,11 @@ int main(int argc, char* argv[]) ...@@ -214,13 +209,11 @@ int main(int argc, char* argv[])
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
bias_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); bias_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break; break;
default: default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
bias_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0}); bias_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
...@@ -228,12 +221,10 @@ int main(int argc, char* argv[]) ...@@ -228,12 +221,10 @@ int main(int argc, char* argv[])
DeviceMem out_device_buf(sizeof(OutDataType) * DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
bias_device_buf.ToDevice(bias_k.mData.data()); bias_device_buf.ToDevice(bias_k.mData.data());
resi_device_buf.ToDevice(resi_n_k_ho_wo.mData.data());
auto conv = DeviceConvFwdInstance{}; auto conv = DeviceConvFwdInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
...@@ -242,7 +233,6 @@ int main(int argc, char* argv[]) ...@@ -242,7 +233,6 @@ int main(int argc, char* argv[])
static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()), static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(resi_device_buf.GetDeviceBuffer()),
N, N,
K, K,
C, C,
...@@ -270,8 +260,7 @@ int main(int argc, char* argv[]) ...@@ -270,8 +260,7 @@ int main(int argc, char* argv[])
std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
sizeof(WeiDataType) * (K * C * Y * X) + sizeof(WeiDataType) * (K * C * Y * X) +
sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) + sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K);
sizeof(OutDataType) * (N * K * Ho * Wo);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -286,7 +275,6 @@ int main(int argc, char* argv[]) ...@@ -286,7 +275,6 @@ int main(int argc, char* argv[])
wei_k_c_y_x, wei_k_c_y_x,
out_n_k_ho_wo_host_result, out_n_k_ho_wo_host_result,
bias_k, bias_k,
resi_n_k_ho_wo,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -14,80 +14,6 @@ ...@@ -14,80 +14,6 @@
#include "device_conv2d_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp" #include "device_conv2d_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
struct AddLeakyReluAdd
{
template <typename T1, typename T2>
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
}
template <typename T1, typename T2>
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
#if 0
// this use not too many registers, but use fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this spill register
float a = v0 + v1;
float b = float(0.1) * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif 0
// this use lots of registers (but no spill)
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = b > 0 ? b : 0;
float d = alpha * (a + c);
return d;
#elif 1
// this use lots of registers (but no spill), 89 Tflops
constexpr float alpha = 0.1;
constexpr float alpha_inv = 1.0 / alpha;
float a = v2 * alpha_inv;
float b = v1 + v0;
float c = max(b, float(0));
float d = alpha * (a + c);
return d;
#elif 1
// this spill registers, 89 Tflops
float a = v0 + v1;
float alpha = 0.1;
float b;
asm volatile("\n \
v_mul_f32_e32 %0, %1, %2 \n \
"
: "=v"(b)
: "s"(alpha), "v"(a));
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#endif
}
};
using InDataType = ck::half_t; using InDataType = ck::half_t;
using WeiDataType = ck::half_t; using WeiDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
......
...@@ -44,6 +44,17 @@ target_compile_features(device_conv2d_fwd_instance PUBLIC) ...@@ -44,6 +44,17 @@ target_compile_features(device_conv2d_fwd_instance PUBLIC)
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
# device_conv2d_fwd_bias_relu_instance
set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC)
set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
# device_conv2d_fwd_bias_relu_add_instance # device_conv2d_fwd_bias_relu_add_instance
set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
...@@ -56,10 +67,11 @@ set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITI ...@@ -56,10 +67,11 @@ set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITI
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
# ck_profiler # ck_profiler
set(PROFILER_SOURCE profiler.cpp profile_gemm.cpp profile_conv_fwd.cpp profile_conv_fwd_bias_relu_add.cpp) set(PROFILER_SOURCE profiler.cpp profile_gemm.cpp profile_conv_fwd.cpp profile_conv_fwd_bias_relu.cpp profile_conv_fwd_bias_relu_add.cpp)
add_executable(ckProfiler ${PROFILER_SOURCE}) add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
#pragma once
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_conv_fwd_bias_activation.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_bias_activation_instance {
using DeviceConvFwdBiasReluPtr =
DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddRelu>;
void add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances(
std::vector<DeviceConvFwdBiasReluPtr>&);
} // namespace device_conv2d_fwd_bias_activation_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
template <typename TIn,
typename TWei,
typename TOut,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
const Tensor<TWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
const Tensor<TOut>& bias_k,
const std::vector<ck::index_t>& conv_strides,
const std::vector<ck::index_t>& conv_dilations,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& /* in_right_pads */,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y)
{
int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0];
for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x)
{
int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1];
if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_n_c_hi_wi.mDesc.GetLengths()[3])
{
v += in_element_op(static_cast<const double>(in_n_c_hi_wi(n, c, hi, wi))) *
wei_element_op(static_cast<const double>(wei_k_c_y_x(k, c, y, x)));
}
}
}
}
out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k));
};
make_ParallelTensorFunctor(f_nchw,
out_n_k_ho_wo.mDesc.GetLengths()[0],
out_n_k_ho_wo.mDesc.GetLengths()[1],
out_n_k_ho_wo.mDesc.GetLengths()[2],
out_n_k_ho_wo.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
template <int NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
void profile_conv_fwd_bias_relu_impl(int do_verification,
int init_method,
bool do_log,
int nrepeat,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
const ck::index_t Y = filter_spatial_lengths[0];
const ck::index_t X = filter_spatial_lengths[1];
const ck::index_t Hi = input_spatial_lengths[0];
const ck::index_t Wi = input_spatial_lengths[1];
const ck::index_t Ho = output_spatial_lengths[0];
const ck::index_t Wo = output_spatial_lengths[1];
auto f_host_tensor_descriptor =
[](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) {
if constexpr(is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value ||
is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
}
else if constexpr(is_same<decltype(layout), tensor_layout::convolution::NHWC>::value ||
is_same<decltype(layout), tensor_layout::convolution::KYXC>::value ||
is_same<decltype(layout), tensor_layout::convolution::NHWK>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
}
};
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
Tensor<OutDataType> out_n_k_ho_wo_host_result(
f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
Tensor<OutDataType> out_n_k_ho_wo_device_result(
f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
// bias: assume contiguous 1d vector
Tensor<OutDataType> bias_k(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(K)})));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "bias_k: " << bias_k.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
bias_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break;
default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
bias_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
}
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
if(do_verification)
{
host_reference_calculation(in_n_c_hi_wi,
wei_k_c_y_x,
out_n_k_ho_wo_host_result,
bias_k,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
bias_device_buf.ToDevice(bias_k.mData.data());
using DeviceConvFwdBiasReluPtr = ck::tensor_operation::device::
DeviceConvFwdBiasActivationPtr<InElementOp, WeiElementOp, OutElementOp>;
// add device operator instances
std::vector<DeviceConvFwdBiasReluPtr> op_ptrs;
if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{
ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance::
add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances(op_ptrs);
}
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
std::string best_conv_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device Conv instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string conv_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
std::size_t num_btype =
sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(WeiDataType) * (K * C * Y * X) +
sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << conv_name << std::endl;
if(tflops > best_tflops)
{
best_conv_name = conv_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
if(do_log)
{
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "wei: ", wei_k_c_y_x.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",")
<< std::endl;
}
}
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
}
} // namespace profiler
} // namespace ck
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp"
enum ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
};
int profile_conv_fwd_bias_relu(int argc, char* argv[])
{
if(argc != 25)
{
printf("arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n");
printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n");
printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n");
printf("arg6: verification (0: no; 1: yes)\n");
printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
const int nrepeat = std::stoi(argv[9]);
const ck::index_t N = std::stoi(argv[10]);
const ck::index_t K = std::stoi(argv[11]);
const ck::index_t C = std::stoi(argv[12]);
const ck::index_t Y = std::stoi(argv[13]);
const ck::index_t X = std::stoi(argv[14]);
const ck::index_t Hi = std::stoi(argv[15]);
const ck::index_t Wi = std::stoi(argv[16]);
const ck::index_t conv_stride_h = std::stoi(argv[17]);
const ck::index_t conv_stride_w = std::stoi(argv[18]);
const ck::index_t conv_dilation_h = std::stoi(argv[19]);
const ck::index_t conv_dilation_w = std::stoi(argv[20]);
const ck::index_t in_left_pad_h = std::stoi(argv[21]);
const ck::index_t in_left_pad_w = std::stoi(argv[22]);
const ck::index_t in_right_pad_h = std::stoi(argv[23]);
const ck::index_t in_right_pad_w = std::stoi(argv[24]);
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_fwd_bias_relu_impl<2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>(
do_verification,
init_method,
do_log,
nrepeat,
N,
K,
C,
std::vector<ck::index_t>{Hi, Wi},
std::vector<ck::index_t>{Y, X},
std::vector<ck::index_t>{Ho, Wo},
std::vector<ck::index_t>{conv_stride_h, conv_stride_w},
std::vector<ck::index_t>{conv_dilation_h, conv_dilation_w},
std::vector<ck::index_t>{in_left_pad_h, in_left_pad_w},
std::vector<ck::index_t>{in_right_pad_h, in_right_pad_w});
}
else
{
throw std::runtime_error("wrong! data_type & layout for this operator is not implemented");
}
return 1;
}
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
int profile_gemm(int, char*[]); int profile_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]); int profile_conv_fwd(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]);
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -19,6 +20,10 @@ int main(int argc, char* argv[]) ...@@ -19,6 +20,10 @@ int main(int argc, char* argv[])
{ {
return profile_conv_fwd(argc, argv); return profile_conv_fwd(argc, argv);
} }
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{
return profile_conv_fwd_bias_relu(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
{ {
return profile_conv_fwd_bias_relu_add(argc, argv); return profile_conv_fwd_bias_relu_add(argc, argv);
...@@ -28,6 +33,7 @@ int main(int argc, char* argv[]) ...@@ -28,6 +33,7 @@ int main(int argc, char* argv[])
printf( printf(
"arg1: tensor operation (gemm: GEMM;\n" "arg1: tensor operation (gemm: GEMM;\n"
" conv_fwd: ForwardConvolution;\n" " conv_fwd: ForwardConvolution;\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU)\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add)\n"); " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add)\n");
return 0; return 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment