Unverified Commit 0e92deb7 authored by Chao Liu, committed by GitHub

Tile Program init bulk PR (#4)

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>

parent 0077eeb3
@@ -178,7 +178,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
             using src_vector_t = typename decltype(src_vector)::type;

             const bool is_src_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+                coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);

             // copy data from src_buf to src_vector
             src_vector.template AsType<src_vector_t>()(I0) =
@@ -361,7 +361,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
             // copy data from dst_vector to dst_buf
             const bool is_dst_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+                coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);

             dst_buf.template Set<dst_vector_t>(
                 dst_coord_.GetOffset(),
@@ -94,7 +94,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
         using dst_vector_t = typename dst_vector_type::type;

         const bool is_src_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);

         // copy data from src_buf into src_vector_container
         auto src_vector_container = src_vector_type{
@@ -114,7 +114,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
         });

         const bool is_dst_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);

         // copy data from dst_vector into dst_buf
         dst_buf.template Update<DstInMemOp, dst_vector_t>(
@@ -126,28 +126,21 @@ struct ThreadwiseTensorSliceTransfer_v6r1
             if constexpr(idx_1d.value != num_access - 1)
             {
                 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

-                move_tensor_coordinate(
-                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
-                move_tensor_coordinate(
-                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
+                move_tensor_coordinate(src_desc, src_coord_, forward_step);
+                move_tensor_coordinate(dst_desc, dst_coord_, forward_step);
             }
         });

         // move coordinate back to slice origin (or not)
         if constexpr(SrcResetCoordinateAfterRun)
         {
-            const auto src_reset_step =
-                make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());
-
-            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
+            move_tensor_coordinate(src_desc, src_coord_, GetCoordinateResetStep());
         }

         if constexpr(DstResetCoordinateAfterRun)
         {
-            const auto dst_reset_step =
-                make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());
-
-            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
+            move_tensor_coordinate(dst_desc, dst_coord_, GetCoordinateResetStep());
         }
     }
@@ -198,10 +191,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
                 ? dst_slice_origin_step_idx
                 : dst_slice_origin_step_idx + GetCoordinateResetStep();

-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
-
-        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step_idx);
     }

     private:
@@ -113,10 +113,10 @@ struct ThreadwiseTensorSliceTransfer_v6r2
         using dst_vector_t = typename dst_vector_type::type;

         const bool is_src0_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(src0_desc, src0_coord_);

         const bool is_src1_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(src1_desc, src1_coord_);

         // copy data from src0_buf into src0_vector_container
         auto src0_vector_container = src0_vector_type{
@@ -135,7 +135,7 @@ struct ThreadwiseTensorSliceTransfer_v6r2
         });

         const bool is_dst_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);

         // copy data from dst_vector into dst_buf
         dst_buf.template Update<DstInMemOp, dst_vector_t>(
@@ -131,13 +131,13 @@ struct ThreadwiseTensorSliceTransfer_v6r3
         using dst_vector_t = typename dst_vector_type::type;

         const bool is_src0_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(src0_desc, src0_coord_);

         const bool is_src1_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(src1_desc, src1_coord_);

         const bool is_src2_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(src2_desc, src2_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(src2_desc, src2_coord_);

         // copy data from src0_buf into src0_vector_container
         auto src0_vector_container = src0_vector_type{
@@ -160,7 +160,7 @@ struct ThreadwiseTensorSliceTransfer_v6r3
         });

         const bool is_dst_valid =
-            coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+            coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);

         dst_buf.template Update<DstInMemOp, dst_vector_t>(
             dst_coord_.GetOffset(),
@@ -138,9 +138,8 @@ struct ThreadwiseTensorSliceTransfer_v7
         static_for<0, nSrc, 1>{}([&](auto i) {
             using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;

-            const bool is_src_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
-                                                                            src_coords_[i]);
+            const bool is_src_valid = coordinate_has_valid_offset_assuming_top_index_is_valid(
+                src_descs[i], src_coords_[i]);

             src_vectors(i).template AsType<src_vector_t>()(I0) =
                 src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
@@ -184,9 +183,8 @@ struct ThreadwiseTensorSliceTransfer_v7
         static_for<0, nDst, 1>{}([&](auto i) {
             using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;

-            const bool is_dst_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
-                                                                            dst_coords_[i]);
+            const bool is_dst_valid = coordinate_has_valid_offset_assuming_top_index_is_valid(
+                dst_descs[i], dst_coords_[i]);

             constexpr InMemoryDataOperationEnum DstInMemOp =
                 static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
@@ -203,15 +201,11 @@ struct ThreadwiseTensorSliceTransfer_v7
             constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess);

             static_for<0, nSrc, 1>{}([&](auto i) {
-                move_tensor_coordinate(src_descs[i],
-                                       src_coords_(i),
-                                       make_tensor_coordinate_step(src_descs[i], forward_step));
+                move_tensor_coordinate(src_descs[i], src_coords_(i), forward_step);
             });

             static_for<0, nDst, 1>{}([&](auto i) {
-                move_tensor_coordinate(dst_descs[i],
-                                       dst_coords_(i),
-                                       make_tensor_coordinate_step(dst_descs[i], forward_step));
+                move_tensor_coordinate(dst_descs[i], dst_coords_(i), forward_step);
             });
         }
     });
@@ -220,20 +214,14 @@ struct ThreadwiseTensorSliceTransfer_v7
         static_for<0, nSrc, 1>{}([&](auto i) {
             if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
             {
-                const auto src_reset_step =
-                    make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep());
-
-                move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
+                move_tensor_coordinate(src_descs[i], src_coords_(i), GetCoordinateResetStep());
             }
         });

         static_for<0, nDst, 1>{}([&](auto i) {
             if constexpr(DstResetCoordinateAfterRunFlags::At(i))
             {
-                const auto dst_reset_step =
-                    make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep());
-
-                move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
+                move_tensor_coordinate(dst_descs[i], dst_coords_(i), GetCoordinateResetStep());
             }
         });
     }
@@ -266,10 +254,7 @@ struct ThreadwiseTensorSliceTransfer_v7
             ? src_slice_origin_step_idx
             : src_slice_origin_step_idx + GetCoordinateResetStep();

-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
-
-        move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
+        move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step_idx);
     }

     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
@@ -283,10 +268,7 @@ struct ThreadwiseTensorSliceTransfer_v7
             ? dst_slice_origin_step_idx
             : dst_slice_origin_step_idx + GetCoordinateResetStep();

-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
-
-        move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
+        move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step_idx);
     }

     private:
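Every hunk above applies the same two simplifications: coordinate_has_valid_offset_assuming_visible_index_is_valid is renamed to coordinate_has_valid_offset_assuming_top_index_is_valid, and move_tensor_coordinate now accepts the step index directly and builds the coordinate step internally. A minimal sketch of the call-site change (desc, coord, and step_idx stand in for any descriptor, coordinate, and step index):

// before: the caller pre-builds the step object
move_tensor_coordinate(desc, coord, make_tensor_coordinate_step(desc, step_idx));

// after: the step is constructed inside move_tensor_coordinate
move_tensor_coordinate(desc, coord, step_idx);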
@@ -24,7 +24,7 @@ struct TransformConvFwdToGemm
               typename std::enable_if<NDimSpatial == 1 &&
                                           is_same_v<ALayout, tensor_layout::convolution::GNWC>,
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                         const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
                         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
@@ -119,7 +119,7 @@ struct TransformConvFwdToGemm
               typename std::enable_if<NDimSpatial == 2 &&
                                           is_same_v<ALayout, tensor_layout::convolution::GNHWC>,
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                         const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
                         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
@@ -230,7 +230,7 @@ struct TransformConvFwdToGemm
               typename std::enable_if<NDimSpatial == 3 &&
                                           is_same_v<ALayout, tensor_layout::convolution::GNDHWC>,
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                         const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
                         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
@@ -363,7 +363,7 @@ struct TransformConvFwdToGemm
                                       (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
                                        is_same_v<ALayout, tensor_layout::convolution::NWGC>),
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                         const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
@@ -475,7 +475,7 @@ struct TransformConvFwdToGemm
               NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
                                    is_same_v<ALayout, tensor_layout::convolution::NHWGC>),
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                         const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
@@ -603,7 +603,7 @@ struct TransformConvFwdToGemm
               NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
                                    is_same_v<ALayout, tensor_layout::convolution::NDHWGC>),
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                         const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
@@ -754,7 +754,7 @@ struct TransformConvFwdToGemm
                                           is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                                           is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                         const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
     {
@@ -779,8 +779,9 @@ struct TransformConvFwdToGemm
                                           is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                                           is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                                       bool>::type = false>
-    static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
+    __host__ __device__ static auto
+    MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
+                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
     {
         const index_t K = b_g_k_c_xs_lengths[1];
         const index_t C = b_g_k_c_xs_lengths[2];
@@ -809,7 +810,7 @@ struct TransformConvFwdToGemm
                                           is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                                           is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                         const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
     {
@@ -834,8 +835,9 @@ struct TransformConvFwdToGemm
                                           is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                                           is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                                       bool>::type = false>
-    static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+    __host__ __device__ static auto
+    MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
+                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
     {
         const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
@@ -858,7 +860,7 @@ struct TransformConvFwdToGemm
               typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GK> ||
                                           is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                       bool>::type = false>
-    static auto
+    __host__ __device__ static auto
     MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                         const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
     {
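The hunks above all make the same one-line change: the descriptor factories gain __host__ __device__, so MakeADescriptor_M_K, MakeBDescriptor_N_K, and MakeCDescriptor_M_N can now be evaluated inside a kernel as well as during host-side argument setup. Schematically (illustrative only; under HIP/CUDA an unannotated function is host-only):

static auto MakeDesc(...);                     // callable from host code only
__host__ __device__ static auto MakeDesc(...); // compiled for both host and device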
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_encoding_helper.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Problem Description for BlockGemmARegBSmemCRegV1
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmARegBSmemCRegV1Problem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.GetLengths()[Number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.GetLengths()[Number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template At<0>())>;
constexpr index_t MWarp = config.template At<1>();
constexpr index_t NWarp = config.template At<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<1, 0>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
static_assert(is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::GetTileDistribution()
.GetStaticTileDistributionEncoding())>>,
"wrong!");
#if 0
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent distribution
static_assert(
is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(
ABlockTensorTmp::GetBlockDistribution().GetStaticTensorDistributionEncoding())>>,
"wrong!");
#endif
// construct A-block-tensor from A-Block-tensor-tmp
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.GetThreadBuffer() = a_block_tensor_tmp.GetThreadBuffer();
// construct B-block-window from B-block-distribution
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetSlicedThreadData(
merge_sequences(Sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
});
});
});
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
__device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.GetLengths()[Number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.GetLengths()[Number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template At<0>())>;
constexpr index_t MWarp = config.template At<1>();
constexpr index_t NWarp = config.template At<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<1, 0>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
#if 0
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent distribution
static_assert(
is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(
ABlockTensorTmp::GetBlockDistribution().GetStaticTensorDistributionEncoding())>>,
"wrong!");
#endif
// construct A-block-tensor from A-Block-tensor-tmp
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.GetThreadBuffer() = a_block_tensor_tmp.GetThreadBuffer();
// construct B-block-window from B-block-distribution
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
// Construct C-Block-Tensor and clear it, since the K-loop below accumulates into it
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
tile_elementwise_inout([](auto& c) { c = type_convert<CDataType>(0); }, c_block_tensor);
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetSlicedThreadData(
merge_sequences(Sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
});
});
});
return c_block_tensor;
}
// FIXME: remove: dummy host function for tile programming
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
__host__ void operator()(CBlockTensor&, const ABlockTensorTmp&, const BBlockWindowTmp&) const
{
}
// FIXME: remove: dummy host function for tile programming
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
__host__ auto operator()(const ABlockTensorTmp&, const BBlockWindowTmp&) const
{
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.GetLengths()[Number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN, "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template At<0>())>;
constexpr index_t MWarp = config.template At<1>();
constexpr index_t NWarp = config.template At<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
#if 0
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent distribution
static_assert(
is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(
ABlockTensorTmp::GetBlockDistribution().GetStaticTensorDistributionEncoding())>>,
"wrong!");
#endif
// Construct C-Block-Tensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
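A minimal instantiation sketch for the block gemm above (not part of this commit; the fp16 inputs, fp32 accumulator, 256-thread block, and 128x128x32 tile shape are assumed for illustration):

using Shape   = ck::tile_program::TileGemmShape<128, 128, 32>;
using Problem = ck::tile_program::block::
    BlockGemmARegBSmemCRegV1Problem<ck::half_t, ck::half_t, float, 256, Shape>;

// inside a kernel: A is a block distributed tensor, B is a block window on LDS
auto block_gemm     = ck::tile_program::block::BlockGemmARegBSmemCRegV1<Problem>{};
auto c_block_tensor = block_gemm(a_block_tensor, b_block_window); // C = A * B
block_gemm(c_block_tensor, a_block_tensor, b_block_window);       // C += A * B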
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Default policy for BlockGemmARegBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmARegBSmemCRegV1DefaultPolicy
{
template <typename Problem>
__host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp()
{
using namespace ck::tile_program::warp;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
// FIXME
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 && kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
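Because the policy class is not templated (only its member function is, per the comment above), an alternative warp-gemm mapping can be swapped in without re-specializing the block gemm itself. A hypothetical custom policy sketch (MyBlockGemmPolicy is illustrative; WarpGemmMfmaF16F16F32M32N32K8 is the warp gemm used by the default policy):

struct MyBlockGemmPolicy
{
    template <typename Problem>
    __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp()
    {
        // e.g. a 2x2 warp grid instead of the default 4x1
        return make_tuple(ck::tile_program::warp::WarpGemmMfmaF16F16F32M32N32K8{}, 2, 2);
    }
};

using BlockGemm = ck::tile_program::block::BlockGemmARegBSmemCRegV1<Problem, MyBlockGemmPolicy>;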
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_encoding_helper.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Problem Description for BlockGemmASmemBSmemCRegV1
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmASmemBSmemCRegV1Problem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BlockGemmASmemBSmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockWindowTmp& a_block_window_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template At<0>())>;
constexpr index_t MWarp = config.template At<1>();
constexpr index_t NWarp = config.template At<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<1, 0>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
static_assert(is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::GetTileDistribution()
.GetStaticTileDistributionEncoding())>>,
"wrong!");
// construct A/B-block-window from A/B-block-distribution
auto a_block_window = make_tile_window(a_block_window_tmp.GetBottomTensorView(),
a_block_window_tmp.GetWindowLengths(),
a_block_window_tmp.GetWindowOrigin(),
a_block_dstr);
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = detail::load_sliced_thread_data_from_tile_window(
a_block_window,
MultiIndex<2 + AWarpDstr::NDimY>{mIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
});
});
});
}
// C = A * B
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
__device__ auto operator()(const ABlockWindowTmp& a_block_window_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template At<0>())>;
constexpr index_t MWarp = config.template At<1>();
constexpr index_t NWarp = config.template At<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<1, 0>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// construct A/B-block-window from A/B-block-distribution
auto a_block_window = make_tile_window(a_block_window_tmp.GetBottomTensorView(),
a_block_window_tmp.GetWindowLengths(),
a_block_window_tmp.GetWindowOrigin(),
a_block_dstr);
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
static_assert(is_same_v<CDataType, typename WG::CDataType>, "wrong!");
// Construct C-Block-Tensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = detail::load_sliced_thread_data_from_tile_window(
a_block_window,
MultiIndex<2 + AWarpDstr::NDimY>{mIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
CWarpTensor c_warp_tensor;
// warp GEMM
if constexpr(kIter.value == 0)
{
    // first K iteration: c = a * b
    c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor);
}
else
{
    // later K iterations: read C warp tensor from C block tensor,
    // then c += a * b
    c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
        merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
        merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
});
});
});
return c_block_tensor;
}
// FIXME: remove: dummy host function for tile programming
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
__host__ void operator()(CBlockTensor&, const ABlockWindowTmp&, const BBlockWindowTmp&) const
{
}
// FIXME: remove: dummy host function for tile programming
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
__host__ auto operator()(const ABlockWindowTmp&, const BBlockWindowTmp&) const
{
static_assert(is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN, "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template At<0>())>;
constexpr index_t MWarp = config.template At<1>();
constexpr index_t NWarp = config.template At<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
static_assert(is_same_v<CDataType, typename WG::CDataType>, "wrong!");
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Default policy for BlockGemmASmemBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{
template <typename Problem>
__host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp()
{
using namespace ck::tile_program::warp;
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 && kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_encoding_helper.hpp"
#include "ck/tile_program/tile/distributed_tile_sweep.hpp"
namespace ck {
namespace tile_program {
namespace block {
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc>
__device__ void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::Detail;
constexpr index_t NDimP = Dstr::GetNumOfDimensionP();
constexpr index_t NDimR = Dstr::GetNumOfDimensionR();
constexpr index_t idim_p_lane = NDimP - 1;
const auto ps_idx = make_array<index_t>(get_block_id(), get_lane_id());
const auto rs_idx = acc_tensor.GetTileDistribution().CalculateRsIndexFromPsIndex(ps_idx);
constexpr index_t thread_buf_size = AccDistributedTensor_::GetThreadBufferSize();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = acc_tensor.GetThreadBuffer()[i];
// cross-lane reduce for replication
// only reduce on the R dimension that corresponds to the lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(math::is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = math::integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
constexpr index_t lid_delta =
lid_over_rid_derivative * (1 << (nstage - istage - 1));
// pull data from remote lane
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
// cross-lane broadcast for replication
// only broadcast on the R dimension that corresponds to the lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
const index_t r_id = rs_idx[idim_r];
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
static_assert(math::is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = math::integer_log2_floor(r_length);
// broadcast sweep backward
static_for<0, nstage, 1>{}([&](auto istage) {
// do I hold reduced data?
const bool do_i_hold_reduced_data = r_id < (1 << istage);
constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);
// pull data from remote lane
const auto v_remote = warp_shuffle_up(v_local, lid_delta);
// decide whether to update local data with remote data
v_local = do_i_hold_reduced_data ? v_local : v_remote;
});
}
});
acc_tensor.GetThreadBuffer()(i) = v_local;
});
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template <typename AccDistributedTensor_,
typename InDistributedTensor_,
index_t... InReduceDims,
typename ReduceFunc>
__device__ void block_tile_reduce(AccDistributedTensor_& acc_tensor,
const InDistributedTensor_& in_tensor,
Sequence<InReduceDims...>,
const ReduceFunc& reduce_func)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
#if 0
constexpr auto in_reduce_dims = Sequence<InReduceDims...>{};
constexpr index_t ndim_in = InDistributedTensor_::GetNumOfDimension();
constexpr index_t ndim_in_reduce = in_reduce_dims.Size();
constexpr index_t ndim_in_free = ndim_in - ndim_in_reduce;
constexpr auto in_free_dims_arr = [&] {
Array<bool, ndim_in> is_free_dims{true};
for(index_t i = 0; i < ndim_in_reduce; i++)
{
is_free_dims(in_reduce_dims[i]) = false;
}
Array<index_t, ndim_in_free> in_free_dims{-1};
index_t cnt = 0;
for(index_t i = 0; i < ndim_in; i++)
{
if(is_free_dims[i])
{
in_free_dims(cnt) = i;
cnt++;
}
}
return in_free_dims;
}();
constexpr auto in_free_dims = TO_SEQUENCE(in_free_dims_arr, ndim_in_free);
#else
constexpr auto spans = InDistributedTensor_::GetDistributedSpans();
// in-thread reduction
// FIXME: hard coded to be 2D to 1D reduction
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
auto acc = acc_tensor.GetElementFromTileDistributedIndices(acc_dstr_idx);
// FIXME
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
const auto in = in_tensor.GetElementFromTileDistributedIndices(in_dstr_idx);
acc = reduce_func(acc, in);
});
acc_tensor.SetElementFromTileDistributedIndices(acc_dstr_idx, acc);
});
#endif
}
template <typename AccDataType_,
typename InDistributedTensor_,
index_t... InReduceDims,
typename ReduceFunc,
typename InDataType_>
__host__ __device__ auto block_tile_reduce(const InDistributedTensor_& in_tensor,
Sequence<InReduceDims...> in_reduce_dims,
const ReduceFunc& reduce_func,
const InDataType_& reduce_init)
{
using InDataType = typename InDistributedTensor_::DataType;
using AccDataType = remove_cvref_t<AccDataType_>;
static_assert(is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
// declare acc_tensor
constexpr auto acc_dstr = make_static_tile_distribution(
ck::tile_program::detail::make_reduce_tile_distribution_encoding(
InDistributedTensor_::GetTileDistribution().GetStaticTileDistributionEncoding(),
Sequence<InReduceDims...>{}));
auto acc_tensor = make_static_distributed_tensor<AccDataType>(acc_dstr);
// init acc_tensor
tile_elementwise_inout([&](auto& acc) { acc = type_convert<AccDataType>(reduce_init); },
acc_tensor);
// warp reduce
block_tile_reduce(acc_tensor, in_tensor, in_reduce_dims, reduce_func);
return acc_tensor;
}
// FIXME: dummy host function for tile program
template <typename AccDistributedTensor_,
typename InDistributedTensor_,
index_t... InReduceDims,
typename ReduceFunc>
__host__ void block_tile_reduce(AccDistributedTensor_&,
const InDistributedTensor_&,
Sequence<InReduceDims...>,
const ReduceFunc&)
{
}
// FIXME: dummy host function for tile program
template <typename AccDistributedTensor_, typename ReduceFunc>
__host__ void block_tile_reduce_sync(AccDistributedTensor_&, const ReduceFunc&)
{
}
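// Usage sketch (illustrative only; assumes "in_tensor" is a 2D static distributed
// tensor of float and that max is the desired reduction -- none of this is fixed API):
//
//   const auto f_max = [](auto a, auto b) { return a > b ? a : b; };
//
//   // in-thread reduce over dim 1, accumulating in float
//   auto row_max = block_tile_reduce<float>(
//       in_tensor, Sequence<1>{}, f_max, NumericLimits<float>::Lowest());
//
//   // cross-lane reduce + broadcast over the replicated dimension
//   block_tile_reduce_sync(row_max, f_max);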
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct BlockGemmPipelineAGmemBGmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
__host__ __device__ static constexpr ck::index_t GetStaticLdsSize()
{
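// A's LDS footprint is padded up to a multiple of 16 bytes before B is appended;
// e.g. an A tile of 8,300 bytes is padded to 8,304. This matches the
// a_lds_block_space_size_aligned offset used to carve p_smem in operator() below.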
return ck::math::integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().GetElementSpaceSize(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().GetElementSpaceSize();
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
static_assert(
is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}],
"wrong!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<AddressSpaceEnum::Lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
math::integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.GetElementSpaceSize(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<AddressSpaceEnum::Lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<kMPerBlock>{}, Number<kKPerBlock>{}),
a_dram_block_window_tmp.GetWindowOrigin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(Number<kMPerBlock>{}, Number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.GetTileDistribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}),
b_dram_block_window_tmp.GetWindowOrigin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.GetTileDistribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(Number<kMPerBlock>{}, Number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write 0
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
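// NOTE: the do/while below executes at least once, so the pipeline appears to
// assume num_loop >= 2 (global read 0 is issued above, the last GEMM runs in the tail)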
index_t iCounter = num_loop - 1;
do
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
ProgramServer::block_sync_lds();
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
ProgramServer::block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
iCounter--;
} while(iCounter > 0);
// tail
{
ProgramServer::block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
return c_block_tile;
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem);
}
};
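// Instantiation sketch (hypothetical sizes; any consistent combination should work):
//
//   using Problem = BlockGemmPipelineProblem<ck::half_t,   // ADataType
//                                            ck::half_t,   // BDataType
//                                            float,        // CDataType (accumulator)
//                                            256,          // kBlockSize
//                                            TileGemmShape<128, 128, 32>>;
//   using Pipeline = BlockGemmPipelineAGmemBGmemCRegV1<Problem>;
//
//   __shared__ char p_smem[Pipeline::GetStaticLdsSize()];
//   auto c_tile = Pipeline{}(a_dram_window, b_dram_window, K / 32 /* num_loop */, p_smem);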
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
// A default policy class should not be templated; put templates on member functions instead
struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
template <typename Problem>
__host__ __device__ static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), Number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d + padding
template <typename Problem>
__host__ __device__ static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
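// (K0, M, K1) = (kKPerBlock / 8, kMPerBlock, 8) layout; each K0 slab is padded by
// one extra 8-element vector (row pitch (kMPerBlock + 1) * 8), so consecutive slabs
// start at shifted LDS bank offsets, which is intended to reduce bank conflicts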
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / 8>{}, Number<kMPerBlock>{}, Number<8>{}),
make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
Number<8>{},
Number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return a_lds_block_desc;
}
// 3d + padding
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / 8>{}, Number<kNPerBlock>{}, Number<8>{}),
make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
Number<8>{},
Number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_block_desc;
}
#elif 1
// fake XOR
template <typename Problem>
__host__ __device__ static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(kMPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(ADataType);
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
a_lds_block_desc_d1_d2_d3,
make_tuple(make_xor_transform(make_tuple(kMPerBlock / 2, kKPerBlock), kK1),
make_pass_through_transform(2)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(kMPerBlock / 2, 2)),
make_pass_through_transform(kKPerBlock)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return a_lds_block_desc_m_k;
}
// fake XOR
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_block_desc_d1_d2_d3,
make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1),
make_pass_through_transform(2)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)),
make_pass_through_transform(kKPerBlock)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_block_desc_n_k;
}
#endif
template <typename Problem>
__host__ __device__ static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesce reading for each block
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
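// worked example, assuming fp16 A, wave64, kBlockSize = 256, kMPerBlock = 128,
// kKPerBlock = 32: K1 = 16 / 2 = 8, K0 = 32 / 8 = 4, M2 = 64 / 4 = 16,
// M1 = 256 / 64 = 4, M0 = 128 / (16 * 4) = 2; each lane loads one contiguous
// 16-byte (8 x fp16) vector and a warp covers a 16x32 (M x K) sub-tile per access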
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<M0, M1, M2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
#else // coalesce reading for each warp
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M1 = kMPerBlock / (M2 * M0);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<M0, M1, M2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<0>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<1, 1>>{});
#endif
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesce reading for each block
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
#else // coalesce reading for each warp
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N1 = kNPerBlock / (N2 * N0);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<0>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<1, 1>>{});
#endif
}
template <typename Problem>
__host__ __device__ static constexpr auto GetBlockGemm()
{
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct BlockGemmPipelineAGmemBGmemCRegV2
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
__host__ __device__ static constexpr ck::index_t GetStaticLdsSize()
{
return ck::math::integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().GetElementSpaceSize(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().GetElementSpaceSize();
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
static_assert(
is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}],
"wrong!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<AddressSpaceEnum::Lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
math::integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.GetElementSpaceSize(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<AddressSpaceEnum::Lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<kMPerBlock>{}, Number<kKPerBlock>{}),
a_dram_block_window_tmp.GetWindowOrigin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(Number<kMPerBlock>{}, Number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.GetTileDistribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}),
b_dram_block_window_tmp.GetWindowOrigin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.GetTileDistribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(Number<kMPerBlock>{}, Number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// global read 1
a_block_tile = load_tile(a_copy_dram_window);
// LDS write 0
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
// global read 1
b_block_tile = load_tile(b_copy_dram_window);
}
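// NOTE: the do/while below executes at least once, so this double-prefetch
// pipeline appears to assume num_loop >= 3 (reads 0 and 1 are issued above,
// the last two GEMMs run in the tail)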
index_t iCounter = num_loop - 2;
do
{
ProgramServer::block_sync_lds();
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
ProgramServer::block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// global read i + 2
a_block_tile = load_tile(a_copy_dram_window);
// LDS write i + 1
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
// global read i + 2
b_block_tile = load_tile(b_copy_dram_window);
iCounter--;
} while(iCounter > 0);
// tail
{
ProgramServer::block_sync_lds();
// GEMM num_loop - 2
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
ProgramServer::block_sync_lds();
// LDS write num_loop - 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
ProgramServer::block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
return c_block_tile;
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem);
}
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Default policy for BlockGemmPipelineAGmemBGmemCRegV2
// A default policy class should not be templated; put templates on member functions instead
// NOTE: a policy should be bound to its corresponding operation. It is just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy =
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy;
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace ck {
namespace tile_program {
namespace block {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tile_program {
namespace grid {
template <typename Problem, typename Policy>
struct GridGemm
{
using ADataType = typename Problem::ADataType;
using BDataType = typename Problem::BDataType;
using CDataType = typename Problem::CDataType;
using AElementFunction = typename Problem::AElementFunction;
using BElementFunction = typename Problem::BElementFunction;
using CElementFunction = typename Problem::CElementFunction;
static constexpr auto kMPerBlock = Policy::kMPerBlock;
static constexpr auto kNPerBlock = Policy::kNPerBlock;
static constexpr auto kKPerBlock = Policy::kKPerBlock;
using BlockGemmPipeline = typename Policy::template BlockGemmPipeline<Problem>;
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
__host__ __device__ void operator()(ProgramServer& ps,
const AGridTensorView& a_grid,
const BGridTensorView& b_grid,
CGridTensorView& c_grid,
const AElementFunction& a_element_func,
const BElementFunction& b_element_func,
const CElementFunction& c_element_func) const
{
using namespace ck;
using namespace ck::tile_program;
using namespace ck::tile_program::block;
const auto M = a_grid.desc_.GetLength(Number<0>{});
const auto N = c_grid.desc_.GetLength(Number<1>{});
const auto K = a_grid.desc_.GetLength(Number<1>{});
// divide problem
const auto id_block = ps.get_block_id();
const auto num_tile_m = M / kMPerBlock;
const auto num_tile_n = N / kNPerBlock;
const auto block2tile = ps(Policy::MakeBlock2TileMap(num_tile_m, num_tile_n));
const auto id_tile = block2tile(id_block);
const auto iM = ps.read_first_lane(id_tile.template At<0>() * kMPerBlock);
const auto iN = ps.read_first_lane(id_tile.template At<1>() * kNPerBlock);
// A block window
auto a_block_window = make_tile_window(
a_grid, make_tuple(Number<kMPerBlock>{}, Number<kKPerBlock>{}), {iM, 0});
// B block window
auto b_block_window = make_tile_window(
b_grid, make_tuple(Number<kNPerBlock>{}, Number<kKPerBlock>{}), {iN, 0});
// Block GEMM pipeline
constexpr auto block_gemm_pipeline = BlockGemmPipeline{};
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
const auto acc_block_tile = block_gemm_pipeline(a_block_window,
a_element_func,
b_block_window,
b_element_func,
K / kKPerBlock,
p_smem_char);
// cast to CDataType and apply CElementFunction
const auto c_block_tile = tile_elementwise_in(
[&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
acc_block_tile);
// store C
auto c_window = make_tile_window(
c_grid, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, iN});
store_tile(c_window, c_block_tile);
}
};
} // namespace grid
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <utility>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/utility/multi_index.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
namespace ck {
namespace detail {
template <typename Descriptor>
class DescToBlock2TileMapAdaptor
{
static_assert(std::is_same_v<
MultiIndex<2>,
remove_cvref_t<decltype(std::declval<const Descriptor&>().CalculateBottomIndex(
std::declval<MultiIndex<1>>()))>>);
Descriptor descriptor_;
public:
explicit constexpr DescToBlock2TileMapAdaptor(Descriptor descriptor)
: descriptor_(std::move(descriptor))
{
}
__host__ __device__ MultiIndex<2> operator()(index_t block_id) const
{
return descriptor_.CalculateBottomIndex(make_multi_index(block_id));
}
};
template <typename Descriptor>
__host__ __device__ static auto make_desc_to_block2tile_map_adaptor(Descriptor&& descriptor)
{
return DescToBlock2TileMapAdaptor<remove_cvref_t<Descriptor>>{
std::forward<Descriptor>(descriptor)};
}
} // namespace detail
namespace tile_program {
namespace grid {
struct Block2TileMapNFast
{
__host__ __device__ static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
index_t NumTilesN)
{
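// row-major cluster: the N tile index varies fastest with block id,
// e.g. with NumTilesN = 4: block 0 -> (0, 0), block 1 -> (0, 1), block 4 -> (1, 0)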
return ck::detail::make_desc_to_block2tile_map_adaptor(
make_cluster_descriptor(make_tuple(NumTilesM, NumTilesN)));
}
};
struct Block2TileMapMFast
{
__host__ __device__ static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
index_t NumTilesN)
{
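// the M tile index varies fastest with block id,
// e.g. with NumTilesM = 4: block 0 -> (0, 0), block 1 -> (1, 0), block 4 -> (0, 1)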
const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM));
return [unmerge](index_t block_id) {
MultiIndex<2> unmerged;
unmerge.CalculateLowerIndex(unmerged, make_multi_index(block_id));
return make_multi_index(unmerged.At<1>(), unmerged.At<0>());
};
}
};
/// NOTICE: this map compiles into a considerable number of instructions.
/// Use with caution, or replace it with a more efficient implementation.
template <index_t MaxCols = 8>
struct Block2TileMapNAdapt
{
__host__ __device__ static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
index_t NumTilesN)
{
return [=](index_t block_id) {
index_t idx_M0 = block_id % NumTilesM;
index_t idx_N0 = block_id / NumTilesM;
const auto LastCols =
(idx_N0 < NumTilesN - NumTilesN % MaxCols) ? MaxCols : NumTilesN % MaxCols;
index_t idx_N00 = idx_N0 / MaxCols;
index_t idx_N01 = idx_N0 % MaxCols;
index_t idx_M0_N01_local = idx_M0 + idx_N01 * NumTilesM;
return make_multi_index(idx_M0_N01_local / LastCols,
idx_M0_N01_local % LastCols + idx_N00 * MaxCols);
};
}
};
/// NOTICE: this map compiles into a considerable number of instructions.
/// Use with caution, or replace it with a more efficient implementation.
template <index_t MaxRows = 8>
struct Block2TileMapMAdapt
{
__host__ __device__ static constexpr auto MakeBlock2TileMap(index_t NumTilesM,
index_t NumTilesN)
{
return [=](index_t block_id) {
index_t idx_N0 = block_id % NumTilesN;
index_t idx_M0 = block_id / NumTilesN;
const auto LastRows =
(idx_M0 < NumTilesM - NumTilesM % MaxRows) ? MaxRows : NumTilesM % MaxRows;
index_t idx_M00 = idx_M0 / MaxRows;
index_t idx_M01 = idx_M0 % MaxRows;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * NumTilesN;
return make_multi_index(idx_N0_M01_local % LastRows + idx_M00 * MaxRows,
idx_N0_M01_local / LastRows);
};
}
};
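// Both *Adapt maps walk the grid in MaxCols-/MaxRows-wide panels so that blocks
// scheduled close together revisit the same row/column of tiles (better L2 reuse);
// the LastCols/LastRows logic handles the final, possibly narrower, panel.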
using DefaultBlock2TileMap = Block2TileMapMFast;
namespace detail {
template <typename TupleOfBaseTypes>
struct InheritFromBaseTypes;
template <typename... BaseTypes>
struct InheritFromBaseTypes<Tuple<BaseTypes...>> : remove_cvref_t<BaseTypes>...
{
};
} // namespace detail
template <index_t kBlockSize_,
index_t kMPerBlock_,
index_t kNPerBlock_,
index_t kKPerBlock_,
template <typename /* BlockGemmPipelineProblem */, typename /* BlockGemmPipelinePolicy */>
class BlockGemmPipeline_,
typename TupleOfExtraPolicies>
struct GridGemmPolicy : detail::InheritFromBaseTypes<TupleOfExtraPolicies>
{
static constexpr auto kBlockSize = kBlockSize_;
static constexpr auto kMPerBlock = kMPerBlock_;
static constexpr auto kNPerBlock = kNPerBlock_;
static constexpr auto kKPerBlock = kKPerBlock_;
template <typename GridGemmProblem>
using BlockGemmPipelineProblem =
block::BlockGemmPipelineProblem<typename GridGemmProblem::ADataType,
typename GridGemmProblem::BDataType,
typename GridGemmProblem::AccDataType,
kBlockSize,
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
template <typename GridGemmProblem>
using BlockGemmPipeline =
BlockGemmPipeline_<BlockGemmPipelineProblem<GridGemmProblem>, GridGemmPolicy>;
};
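// Hypothetical instantiation (names from this PR; the tile sizes are examples only):
//
//   using Policy = GridGemmPolicy<256,          // kBlockSize
//                                 128, 128, 32, // kM/kN/kKPerBlock
//                                 block::BlockGemmPipelineAGmemBGmemCRegV1,
//                                 Tuple<DefaultBlock2TileMap>>;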
} // namespace grid
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tile_program {
namespace grid {
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename AElementFunction_,
typename BElementFunction_,
typename CElementFunction_>
struct GridGemmProblem
{
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CDataType = CDataType_;
using AElementFunction = AElementFunction_;
using BElementFunction = BElementFunction_;
using CElementFunction = CElementFunction_;
};
} // namespace grid
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
namespace ck {
namespace tile_program {
// sweep over a span of a distributed tile and apply the lambda function F
template <typename TileDistributedSpan_, // TileDistributedSpan<...>
typename F // signature: F(TileDistributedIndex<...>)
>
__host__ __device__ void sweep_tile_span(TileDistributedSpan_, const F& f)
{
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
f(dstr_idx);
});
}
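// Usage sketch (assumes "t" is a 2D static distributed tensor; mirrors the nested
// sweep used by block_tile_reduce):
//
//   constexpr auto spans = decltype(t)::GetDistributedSpans();
//   sweep_tile_span(spans[Number<0>{}], [&](auto idx0) {
//       sweep_tile_span(spans[Number<1>{}], [&](auto idx1) {
//           constexpr auto idx = make_tuple(idx0, idx1);
//           auto v = t.GetElementFromTileDistributedIndices(idx);
//           // ... use v ...
//       });
//   });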
} // namespace tile_program
} // namespace ck