Commit 4f2c8bce authored by Chao Liu

adding bias add: fuse two extra C-shaped source tensors (c0, c1) into the GEMM output write-out

parent 165e30cd
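The commit fuses a bias add into the GEMM epilogue: GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 now passes two extra read-only source tensors, c0 and c1, with the same m0/n0/m1/n1/m2/m3/m4/n2 layout as the output C, down into the threadwise output copy. A minimal host-side sketch of the intended semantics follows; it is illustrative only, not code from this commit, and all names in it are invented:

#include <array>
#include <cstddef>

// Illustrative sketch: what "bias add" fusion means for the epilogue.
// c_acc is the accumulator tile, c0/c1 are bias sources shaped like C,
// and c_grid is the output buffer; one pass over the tile combines all
// three instead of launching separate elementwise kernels.
template <typename T, std::size_t N>
void write_out_with_bias_add(const std::array<T, N>& c_acc,
                             const std::array<T, N>& c0,
                             const std::array<T, N>& c1,
                             std::array<T, N>& c_grid)
{
    for(std::size_t i = 0; i < N; ++i)
    {
        c_grid[i] = c_acc[i] + c0[i] + c1[i];
    }
}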
@@ -614,6 +614,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
             FloatC,
             decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
             decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+            decltype(c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+            decltype(c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
             CElementwiseOperation,
             Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
             CThreadTransferSrcDstAccessOrder,
@@ -623,6 +625,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
             1,
             true>{
             c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+            c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+            c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
             make_multi_index(m_thread_data_on_grid_idx[I0],
                              n_thread_data_on_grid_idx[I0],
                              m_thread_data_on_grid_idx[I1],
@@ -638,7 +642,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
             c_thread_buf,
             c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
             c_grid_buf,
-            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
+            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks,
+            c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+            c0_grid_buf,
+            c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+            c1_grid_buf);
         }
     }
 }; // namespace ck
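The second file extends ThreadwiseTensorSliceTransfer_v1r4 to carry the new sources. Dst0Desc and Dst1Desc are destination-shaped descriptors for c0/c1, so the transfer keeps three coordinates and generates forward/backward steps for each; note the repeated WARNING/TODO comments in the hunks below: the dst0/dst1 steps reuse dst_scalar_per_access, which is only element-accurate when DstScalarPerVector = 1. A simplified stand-in for the lockstep coordinate bookkeeping, using plain linear offsets in place of CK tensor coordinates (all names invented for illustration):

#include <cstddef>

// Illustrative stand-in: one running offset per destination-shaped tensor,
// all advanced by the same logical move, each through its own stride.
struct LockstepCoords
{
    std::ptrdiff_t dst  = 0; // output C
    std::ptrdiff_t dst0 = 0; // bias source c0 (same shape, own strides)
    std::ptrdiff_t dst1 = 0; // bias source c1 (same shape, own strides)

    // mirrors moving dst_coord_, dst0_coord_ and dst1_coord_ together
    void move(std::ptrdiff_t dst_step, std::ptrdiff_t dst0_step, std::ptrdiff_t dst1_step)
    {
        dst += dst_step;
        dst0 += dst0_step;
        dst1 += dst1_step;
    }
};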
@@ -15,19 +15,23 @@ namespace ck {
 //    tensor coordinate instead
 // 3. Don't use a pointer to VGPR buffer, use vector instead
+// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+// TODO: fix this
 // Assume:
 // 1. src:
-//    1. Src0Desc is known at compile-time
-//    2. Src0Buffer is StaticBuffer
+//    1. SrcDesc is known at compile-time
+//    2. SrcBuffer is StaticBuffer
 //    3. SrcSliceOriginIdx is known at compile-time
 // 2. dst:
 //    1. DstDesc is not known at compile-time
 //    2. DstBuffer is DynamicBuffer
 //    3. DstSliceOriginIdx is not known at compile-time
-template <typename Src0Data,
+template <typename SrcData,
           typename DstData,
-          typename Src0Desc,
+          typename SrcDesc,
           typename DstDesc,
+          typename Dst0Desc, // really one of the sources, but it has the same shape as DstDesc
+          typename Dst1Desc, // really one of the sources, but it has the same shape as DstDesc
           typename SrcElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
@@ -36,26 +40,34 @@ template <typename SrcData,
           InMemoryDataOperationEnum_t DstInMemOp,
           index_t DstScalarStrideInVector,
           bool DstResetCoordinateAfterRun,
-          typename enable_if<Src0Desc::IsKnownAtCompileTime(), bool>::type = false>
+          typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_v1r4
 {
     static constexpr index_t nDim = SliceLengths::Size();
 
     using Index = MultiIndex<nDim>;
 
     using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
+    using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
+    using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{}));
 
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
+    using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
+    using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{}));
 
     __device__ constexpr ThreadwiseTensorSliceTransfer_v1r4(
         const DstDesc& dst_desc,
+        const Dst0Desc& dst0_desc,
+        const Dst1Desc& dst1_desc,
         const Index& dst_slice_origin_idx,
         const SrcElementwiseOperation src_element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
+          dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
+          dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin_idx)),
           src_element_op_{src_element_op}
     {
-        static_assert(Src0Desc::IsKnownAtCompileTime(),
-                      "wrong! Src0Desc need to known at compile-time");
+        static_assert(SrcDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc needs to be known at compile-time");
     }
 
     __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
@@ -64,26 +76,36 @@ struct ThreadwiseTensorSliceTransfer_v1r4
     }
 
     template <typename SrcSliceOriginIdx,
-              typename Src0Buffer,
+              typename SrcBuffer,
               typename DstBuffer,
-              typename DstStepHacks>
-    __device__ void Run(const Src0Desc&,
+              typename Dst0Buffer,
+              typename Dst1Buffer,
+              typename DstStepHacks,
+              typename Dst0StepHacks,
+              typename Dst1StepHacks>
+    __device__ void Run(const SrcDesc&,
                         const SrcSliceOriginIdx&,
-                        const Src0Buffer& src0_buf,
+                        const SrcBuffer& src_buf,
                         const DstDesc& dst_desc,
                         DstBuffer& dst_buf,
-                        const DstStepHacks& dst_step_hacks)
+                        const DstStepHacks& dst_step_hacks,
+                        const Dst0Desc& dst0_desc,
+                        const Dst0Buffer& dst0_buf,
+                        const Dst0StepHacks& dst0_step_hacks,
+                        const Dst1Desc& dst1_desc,
+                        const Dst1Buffer& dst1_buf,
+                        const Dst1StepHacks& dst1_step_hacks)
     {
-        static_assert(Src0Desc::IsKnownAtCompileTime(),
-                      "wrong! Src0Desc need to known at compile-time");
+        static_assert(SrcDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc needs to be known at compile-time");
 
         static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
                       "wrong! SrcSliceOrigin needs to be known at compile-time");
 
-        static_assert(Src0Buffer::IsStaticBuffer(), "wrong! Src0Buffer need to be StaticBuffer");
+        static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer needs to be a StaticBuffer");
 
-        // Src0Desc and src_slice_origin_idx are known at compile-time
-        constexpr auto src_desc = remove_cvref_t<Src0Desc>{};
+        // SrcDesc and src_slice_origin_idx are known at compile-time
+        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
 
         constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
 
         constexpr auto I0 = Number<0>{};
@@ -104,7 +126,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
         constexpr auto ordered_access_lengths =
             container_reorder_given_new2old(access_lengths, dim_access_order);
 
-        // make forward steps
+        // make forward steps: dst
         const auto dst_forward_steps = generate_tuple(
             [&](auto i) {
                 Index forward_step_idx;
@@ -118,7 +140,39 @@ struct ThreadwiseTensorSliceTransfer_v1r4
             },
             Number<nDim>{});
 
-        // make backward steps
+        // make forward steps: dst0
+        // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+        // TODO: fix this
+        const auto dst0_forward_steps = generate_tuple(
+            [&](auto i) {
+                Index forward_step_idx;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                });
+
+                return make_tensor_coordinate_step(
+                    dst0_desc, forward_step_idx, dst0_step_hacks[I0][i]);
+            },
+            Number<nDim>{});
+
+        // make forward steps: dst1
+        // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+        // TODO: fix this
+        const auto dst1_forward_steps = generate_tuple(
+            [&](auto i) {
+                Index forward_step_idx;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                });
+
+                return make_tensor_coordinate_step(
+                    dst1_desc, forward_step_idx, dst1_step_hacks[I0][i]);
+            },
+            Number<nDim>{});
+
+        // make backward steps: dst
         const auto dst_backward_steps = generate_tuple(
             [&](auto i) {
                 Index backward_step_idx;
@@ -132,6 +186,38 @@ struct ThreadwiseTensorSliceTransfer_v1r4
             },
             Number<nDim>{});
 
+        // make backward steps: dst0
+        // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+        // TODO: fix this
+        const auto dst0_backward_steps = generate_tuple(
+            [&](auto i) {
+                Index backward_step_idx;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                });
+
+                return make_tensor_coordinate_step(
+                    dst0_desc, backward_step_idx, dst0_step_hacks[I1][i]);
+            },
+            Number<nDim>{});
+
+        // make backward steps: dst1
+        // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+        // TODO: fix this
+        const auto dst1_backward_steps = generate_tuple(
+            [&](auto i) {
+                Index backward_step_idx;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                });
+
+                return make_tensor_coordinate_step(
+                    dst1_desc, backward_step_idx, dst1_step_hacks[I1][i]);
+            },
+            Number<nDim>{});
+
         // loop over tensor and copy
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
             // judge move forward or move backward
@@ -172,14 +258,14 @@ struct ThreadwiseTensorSliceTransfer_v1r4
             using dst_vector_t =
                 typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
 
-            // copy data from src0_buf into dst_vector
+            // copy data from src_buf into dst_vector
             static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
 
                 // apply element-wise operation and type convert
                 dst_vector.template AsType<DstData>()(i) =
-                    type_convert<DstData>(src_element_op_(src0_buf[Number<src_offset>{}]));
+                    type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
             });
 
             const bool is_dst_valid =
@@ -261,22 +347,47 @@ struct ThreadwiseTensorSliceTransfer_v1r4
         }
     }
 
-    template <typename SrcSliceOriginIdx, typename Src0Buffer, typename DstBuffer>
-    __device__ void Run(const Src0Desc&,
+    template <typename SrcSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer,
+              typename Dst0Buffer,
+              typename Dst1Buffer,
+              typename DstStepHacks>
+    __device__ void Run(const SrcDesc&,
                         const SrcSliceOriginIdx&,
-                        const Src0Buffer& src0_buf,
+                        const SrcBuffer& src_buf,
                         const DstDesc& dst_desc,
-                        DstBuffer& dst_buf)
+                        DstBuffer& dst_buf,
+                        const DstStepHacks& dst_step_hacks,
+                        const Dst0Desc& dst0_desc,
+                        const Dst0Buffer& dst0_buf,
+                        const Dst1Desc& dst1_desc,
+                        const Dst1Buffer& dst1_buf)
     {
-        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
-
-        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
-
-        constexpr auto dst_step_hacks =
-            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-        Run(Src0Desc{}, SrcSliceOriginIdx{}, src0_buf, dst_desc, dst_buf, dst_step_hacks);
+        auto f_step_hacks = [&](auto desc) {
+            constexpr index_t ntransform = decltype(desc)::GetNumOfTransform();
+
+            constexpr auto zeros = typename uniform_sequence_gen<ntransform, 0>::type{};
+
+            constexpr auto step_hacks =
+                make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
+                           generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
+
+            return step_hacks;
+        };
+
+        Run(SrcDesc{},
+            SrcSliceOriginIdx{},
+            src_buf,
+            dst_desc,
+            dst_buf,
+            dst_step_hacks,
+            dst0_desc,
+            dst0_buf,
+            f_step_hacks(dst0_desc),
+            dst1_desc,
+            dst1_buf,
+            f_step_hacks(dst1_desc));
     }
 
     __device__ static constexpr auto GetDstCoordinateResetStep()
@@ -356,6 +467,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
 
     private:
     DstCoord dst_coord_;
+    Dst0Coord dst0_coord_;
+    Dst1Coord dst1_coord_;
     SrcElementwiseOperation src_element_op_;
 }; // namespace ck
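The final Run overload replaces the hard-coded dst_step_hacks construction with an f_step_hacks lambda, so the same zero-filled hint structure can be generated for any descriptor and forwarded to the full overload for dst0 and dst1. Below is a rough standalone analog of what it builds, assuming (as the zeros in the diff suggest) that a zero entry means "no step-hack hint" for that transform; std::array and std::pair stand in for CK's Sequence and Tuple:

#include <array>
#include <cstddef>
#include <utility>

// Rough analog of f_step_hacks: for a descriptor with NTransform transforms,
// build a {forward, backward} pair of per-dimension rows of zeros, i.e. no
// special-case hints, so the generic coordinate-update path is taken.
template <std::size_t NDim, std::size_t NTransform>
constexpr auto make_zero_step_hacks()
{
    using Row  = std::array<int, NTransform>; // one zero per transform
    using Side = std::array<Row, NDim>;       // one row per dimension
    return std::pair<Side, Side>{};           // {forward hacks, backward hacks}
}

// usage sketch: zero hacks for an 8-dimensional slice over 5 transforms
[[maybe_unused]] constexpr auto zero_hacks = make_zero_step_hacks<8, 5>();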