Commit 4f2c8bce authored by Chao Liu's avatar Chao Liu
Browse files

adding bias add

parent 165e30cd
......@@ -614,6 +614,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
......@@ -623,6 +625,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
1,
true>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
......@@ -638,7 +642,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c0_grid_buf,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c1_grid_buf);
}
}
}; // namespace ck
......
......@@ -15,19 +15,23 @@ namespace ck {
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
// Assume:
// 1. src:
// 1. Src0Desc is known at compile-time
// 2. Src0Buffer is StaticBuffer
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is StaticBuffer
// 3. SrcSliceOriginIdx is known at compile-time
// 2. dst:
// 1. DstDesc is not known at compile-time
// 2. DstBuffer is DynamicBuffer
// 3. DstSliceOriginIdx is not known at compile-time
template <typename Src0Data,
template <typename SrcData,
typename DstData,
typename Src0Desc,
typename SrcDesc,
typename DstDesc,
typename Dst0Desc, // this is really one of sources, but it has same shape as DstDesc
typename Dst1Desc, // this is really one of sources, but it has same shape as DstDesc
typename SrcElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
......@@ -36,7 +40,7 @@ template <typename Src0Data,
InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename enable_if<Src0Desc::IsKnownAtCompileTime(), bool>::type = false>
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r4
{
static constexpr index_t nDim = SliceLengths::Size();
......@@ -44,18 +48,26 @@ struct ThreadwiseTensorSliceTransfer_v1r4
using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r4(
const DstDesc& dst_desc,
const Dst0Desc& dst0_desc,
const Dst1Desc& dst1_desc,
const Index& dst_slice_origin_idx,
const SrcElementwiseOperation src_element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin_idx)),
src_element_op_{src_element_op}
{
static_assert(Src0Desc::IsKnownAtCompileTime(),
"wrong! Src0Desc need to known at compile-time");
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
......@@ -64,26 +76,36 @@ struct ThreadwiseTensorSliceTransfer_v1r4
}
template <typename SrcSliceOriginIdx,
typename Src0Buffer,
typename SrcBuffer,
typename DstBuffer,
typename DstStepHacks>
__device__ void Run(const Src0Desc&,
typename Dst0Buffer,
typename Dst1Buffer,
typename DstStepHacks,
typename Dst0StepHacks,
typename Dst1StepHacks>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const Src0Buffer& src0_buf,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const DstStepHacks& dst_step_hacks)
const DstStepHacks& dst_step_hacks,
const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf,
const Dst0StepHacks& dst0_step_hacks,
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf,
const Dst1StepHacks& dst1_step_hacks)
{
static_assert(Src0Desc::IsKnownAtCompileTime(),
"wrong! Src0Desc need to known at compile-time");
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(Src0Buffer::IsStaticBuffer(), "wrong! Src0Buffer need to be StaticBuffer");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// Src0Desc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<Src0Desc>{};
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto I0 = Number<0>{};
......@@ -104,7 +126,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward steps
// make forward steps: dst
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
......@@ -118,7 +140,39 @@ struct ThreadwiseTensorSliceTransfer_v1r4
},
Number<nDim>{});
// make backward steps
// make forward steps: dst0
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const auto dst0_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst0_desc, forward_step_idx, dst0_step_hacks[I0][i]);
},
Number<nDim>{});
// make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const auto dst1_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst1_desc, forward_step_idx, dst1_step_hacks[I0][i]);
},
Number<nDim>{});
// make backward steps: dst
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
......@@ -132,6 +186,38 @@ struct ThreadwiseTensorSliceTransfer_v1r4
},
Number<nDim>{});
// make backward steps: dst0
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const auto dst0_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst0_desc, backward_step_idx, dst0_step_hacks[I1][i]);
},
Number<nDim>{});
// make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const auto dst1_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(
dst1_desc, backward_step_idx, dst1_step_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
......@@ -172,14 +258,14 @@ struct ThreadwiseTensorSliceTransfer_v1r4
using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
// copy data from src0_buf into dst_vector
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
// apply element-wise operation and type convert
dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_element_op_(src0_buf[Number<src_offset>{}]));
type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
});
const bool is_dst_valid =
......@@ -261,22 +347,47 @@ struct ThreadwiseTensorSliceTransfer_v1r4
}
}
template <typename SrcSliceOriginIdx, typename Src0Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc&,
template <typename SrcSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer,
typename Dst0Buffer,
typename Dst1Buffer,
typename DstStepHacks>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const Src0Buffer& src0_buf,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
DstBuffer& dst_buf,
const DstStepHacks& dst_step_hacks,
const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf,
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
auto f_step_hacks = [&](auto desc) {
constexpr index_t ntransform = decltype(desc)::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto zeros = typename uniform_sequence_gen<ntransform, 0>::type{};
constexpr auto dst_step_hacks =
constexpr auto step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(Src0Desc{}, SrcSliceOriginIdx{}, src0_buf, dst_desc, dst_buf, dst_step_hacks);
return step_hacks;
};
Run(SrcDesc{},
SrcSliceOriginIdx{},
src_buf,
dst_desc,
dst_buf,
dst_step_hacks,
dst0_desc,
dst0_buf,
f_step_hacks(dst0_desc),
dst1_desc,
dst1_buf,
f_step_hacks(dst1_desc));
}
__device__ static constexpr auto GetDstCoordinateResetStep()
......@@ -356,6 +467,8 @@ struct ThreadwiseTensorSliceTransfer_v1r4
private:
DstCoord dst_coord_;
Dst0Coord dst0_coord_;
Dst1Coord dst1_coord_;
SrcElementwiseOperation src_element_op_;
}; // namespace ck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment