Unverified Commit f221c2b0 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #203 from ROCm/merge_from_public

Merge from public
parents 140d2fa6 e1cd4121
......@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
......
......@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
return __shfl(v_local, src_lane);
#elif 1
if constexpr(sizeof(int32_t) > sizeof(T))
{
union packet
{
int32_t x;
T v;
};
packet p;
p.v = v_local;
packet p_remote;
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
return p_remote.v;
}
else if constexpr(sizeof(int32_t) == sizeof(T))
{
const int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
}
else
{
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
using vector_type = thread_buffer<int32_t, elm>;
auto vs = bit_cast<vector_type>(v_local);
auto vs_remote = vector_type{};
static_for<0, elm, 1>{}([&](auto i_e) {
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
vs_remote(i_e) = tmp;
});
return bit_cast<T>(vs_remote);
}
#endif
}
} // namespace ck_tile
......@@ -32,11 +32,13 @@
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
......@@ -157,8 +159,11 @@
#endif
#endif
// workaround for ROCm 6.2 and later
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
......
......@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
namespace impl {
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
sequence<m, ms...>,
sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths =
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths = sequence<slice_length>;
using dim_slices = sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
} // namespace impl
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
number<sliced_type::split_idx>{});
}
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
constexpr auto r =
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
return make_tuple(r[number<0>{}].reverse(),
r[number<1>{}].reverse(),
number<Seq::size() - r[number<2>{}] - 1>{});
}
} // namespace ck_tile
......@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
namespace detail {
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
{
return concat_tuple(f(x.at(number<Is>{}))...);
}
} // namespace detail
// make sure F return at least a tuple
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
// this function will return
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
{
return detail::embed_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the flatten
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
......
......@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
}
// this function used inside span loop over
template <typename YLengths, index_t XUnpacks>
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
{
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
constexpr auto y_packs = number<XUnpacks>{};
static_assert(y_size % y_packs == 0);
constexpr auto y_slice_size = y_size / y_packs;
constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
constexpr auto unpacks = slice_info[number<1>{}];
return unpacks;
}
} // namespace ck_tile
......@@ -8,6 +8,7 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
......@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
});
}
// unpacked span, this version support span with unpack(multi-arg) functor
//
template <
typename TileDistributedSpan_, // tile_distributed_span<...>
typename F, // signature: F(tile_distributed_index<...>)
typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
{
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
static_uford<typename DstrSpan::Impl, Unpacks>{}(
[&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
}
namespace impl {
template <typename, typename, typename>
struct sweep_tile_impl;
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
return y_unpacks;
}
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
return u.get_num_of_access() *
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
}
template <typename F, typename SpanIdx>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
sweep_tile_uspan(
spans[number<I>{}],
[&](auto... i_idx) {
const auto next_span_idx = embed_tuples(
[&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
span_idx);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx);
},
get_y_unpacks());
}
template <typename F, typename SpanIdx, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
constexpr auto access_stride =
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
constexpr auto curr_i_access = number<i_access / access_stride>{};
constexpr auto next_i_access = number<i_access % access_stride>{};
u(
[&](auto... i_idx) {
const auto next_span_idx = embed_tuples(
[&](auto si) {
return make_tuple(concat_tuple(
si, make_tuple(detail::make_tile_distributed_index(i_idx)))...);
},
span_idx);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx, next_i_access);
},
curr_i_access);
}
};
template <typename DistributedTensor, typename UnpacksPerXDim>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
{
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }
template <typename F, typename SpanIdx>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
{
unpack(f, span_idx);
}
template <typename F, typename SpanIdx, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
{
unpack(f, span_idx);
}
};
template <typename, typename, typename>
struct sweep_tile_impl_0;
// TODO: support empty tuple to remove this "entry-point" like function
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
return y_unpacks;
}
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
return u.get_num_of_access() *
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
}
template <typename F>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
sweep_tile_uspan(
spans[number<I>{}],
[&](auto... i_idx) {
constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx);
},
get_y_unpacks());
}
template <typename F, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
constexpr auto access_stride =
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
constexpr auto curr_i_access = number<i_access / access_stride>{};
constexpr auto next_i_access = number<i_access % access_stride>{};
u(
[&](auto... i_idx) {
constexpr auto next_span_idx =
make_tuple(make_tuple(detail::make_tile_distributed_index(i_idx))...);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx, next_i_access);
},
curr_i_access);
}
};
} // namespace impl
/*
* Enhanced sweep-tile utility, can control unpacks along each X-dim
* the lambda function argument is the distributed-idx, which can directly
* plugged into the distributed tensor as setter/getter
*
* e.g. below function, y with the type DistributedTensor, r is row scale
*
* // sweep tile 1 by 1
* sweep_tile<DistributedTensor>([&](auto idx) {
* constexpr auto row_id = make_tuple(idx[number<0>{}]);
* y(idx) = y(idx) * r(row_id);
* });
*
* // sweep tile with 2 pixel from last dim each function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_0, auto idx_1) {
* constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
* y(idx_0) = y(idx_0) * r(row_id);
* y(idx_1) = y(idx_1) * r(row_id);
* },
* sequence<1, 2>{});
*
* // sweep tile with 2x2 pixel each function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
* constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
* constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
* y(idx_00) = y(idx_00) * r(row_id0);
* y(idx_01) = y(idx_01) * r(row_id0);
* y(idx_10) = y(idx_10) * r(row_id1);
* y(idx_11) = y(idx_11) * r(row_id1);
* },
* sequence<2, 2>{});
*
* TODO: do we need constexpr? lambda function could be non-constexpr
*/
template <typename DistributedTensor,
typename F,
typename UnpacksPerXDim =
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
impl::sweep_tile_impl_0<DistributedTensor,
UnpacksPerXDim,
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
}
template <typename DistributedTensor,
typename F,
typename UnpacksPerXDim =
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void
sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
{
sweep_tile<DistributedTensor, F, UnpacksPerXDim>(f, UnpacksPerXDim{});
}
/*
* construct a sweep tile instance, which support issue the lambda one by one
* Note that this struct will hold the lambda functor, but will not hold the distributed tensor
* the functionality is the same as sweep_tile()
*/
template <typename DistributedTensor_,
typename F_,
typename UnpacksPerXDim_ =
typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
struct tile_sweeper
{
using DistributedTensor = remove_cvref_t<DistributedTensor_>;
using F = remove_cvref_t<F_>;
using UnpacksPerXDim = remove_cvref_t<UnpacksPerXDim_>;
CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {}
CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {})
: f(f_)
{
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto tmp =
impl::sweep_tile_impl_0<DistributedTensor,
UnpacksPerXDim,
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
return tmp.get_num_of_access();
}
CK_TILE_HOST_DEVICE void operator()() const
{
sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
}
template <index_t i_access>
CK_TILE_HOST_DEVICE void operator()(number<i_access>) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
impl::sweep_tile_impl_0<DistributedTensor,
UnpacksPerXDim,
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
f, number<i_access>{});
}
F f;
};
// partial deduction is not allowed
// template <typename T, typename F, typename U>
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
// deduction guide
template <typename T,
typename F,
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
} // namespace ck_tile
......@@ -17,6 +17,14 @@
namespace ck_tile {
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
return Distribution::_get_partition_index();
}
} // namespace detail
// distributed span
template <index_t... PartialHsLengths>
struct tile_distributed_span
......@@ -83,6 +91,21 @@ struct tile_distribution
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static auto _get_partition_index()
{
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
if constexpr(NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{
#if 0
......@@ -149,6 +172,16 @@ struct tile_distribution
}
#endif
template <typename PartitionIndex = decltype(_get_partition_index())>
CK_TILE_HOST_DEVICE auto
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
{
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
const auto window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
return window_adaptor_thread_coord_tmp.get_bottom_index();
}
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
......@@ -421,6 +454,7 @@ struct tile_distribution_detail
} // namespace detail
#if 0
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
......@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
#endif
// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
......@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
//***********************************************************************************
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
// only support warp-tile and block-tile
static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
if constexpr(Distribution::NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(Distribution::NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
sequence<m, ms...>,
sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths =
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths = sequence<slice_length>;
using dim_slices = sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
using sliced_type =
reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
number<sliced_type::split_idx>{});
}
//
// slice tensor from x_dim, result in split in y_dim, not p_dim.
// We don't support slice cross p_dim (aka, slice different threads)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// This file should not be included inside tuple.hpp!
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_impl
{
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
{
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
}
template <class F, class CurrentUnpackIds>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
{
constexpr index_t pack_len = RamainUnpacks::front();
static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
constexpr auto new_pack = generate_tuple(
[&](auto idx_) {
constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
},
number<CurrentUnpackIds::size() * pack_len>{});
static_uford_impl<decltype(RemainLengths::pop_front()),
decltype(RamainUnpacks::pop_front()),
Orders>{}(f, new_pack);
});
}
};
template <class Orders>
struct static_uford_impl<sequence<>, sequence<>, Orders>
{
template <class F, class PackedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
{
constexpr auto origin_packs = transform_tuples(
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
unpack(f, origin_packs);
}
};
template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_one_shot_impl
{
template <class F, class CurrentUnpackIds, index_t current_acc>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
{
constexpr auto r_lens_stride =
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
constexpr auto r_upks_stride =
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
constexpr index_t pack_len = RamainUnpacks::front();
constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
constexpr auto new_pack = generate_tuple(
[&](auto idx_) {
constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
},
number<CurrentUnpackIds::size() * pack_len>{});
static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
decltype(RamainUnpacks::pop_front()),
Orders>{}(f, new_pack, number<current_acc % current_stride>{});
}
};
template <class Orders>
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
{
template <class F, class PackedId, index_t current_acc>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
{
constexpr auto origin_packs = transform_tuples(
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
unpack(f, origin_packs);
}
};
} // namespace detail
// TODO: we may unify static_ford/static_uford in the future
//
// loop over nd space(sequence) with packs
// you must make sure the function passed in has same number of argument
//
// e.g.
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
//
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
// ...
template <class Lengths,
class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_uford
{
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
CK_TILE_HOST_DEVICE constexpr static_uford()
{
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
static_for<0, Lengths::size(), 1>{}(
[&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
using L_ = decltype(Lengths{} / Unpacks{});
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
}
// F signature: F(sequence<...> multi_id...)
// multi_id is the unordered multi-index
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
f, make_tuple(sequence<>{}));
}
// this version is friendly for issue function one by one
template <class F, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
{
static_assert(i_access < get_num_of_access());
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
detail::static_uford_one_shot_impl<decltype(ordered_lengths),
decltype(ordered_unpacks),
Orders>{}(
f, make_tuple(sequence<>{}), number<i_access>{});
}
};
} // namespace ck_tile
......@@ -21,7 +21,7 @@
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
......
......@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
......@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
......@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
......@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
auto v_reg_tensor = load_tile(v_lds_read_window);
block_sync_lds();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
......@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
......@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
......@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
......@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
......@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
......
......@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
......@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
......@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
......@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
auto v_reg_tensor = load_tile(v_lds_read_window);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto q_dram_window =
......@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
......@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
......@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
......@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
......@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
......
......@@ -5,9 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
......@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
......@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimV,
Problem::kPadHeadDimV,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
......@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>,
TileGemmTraits<Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......@@ -222,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
......@@ -241,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
......@@ -260,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
......@@ -280,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
......@@ -341,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
......@@ -353,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentK<Problem>();
......@@ -364,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
......@@ -402,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -425,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -448,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -471,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -842,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{
......@@ -891,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
......@@ -916,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kVPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{
......@@ -966,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
......@@ -992,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -1074,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
......@@ -1118,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -1281,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
// Hold full block data
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
......@@ -1325,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
......@@ -1885,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t WarpGemmM =
......@@ -1899,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
......@@ -1929,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t OGrad_LDS_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write
......
......@@ -12,6 +12,16 @@ namespace detail {
template <index_t N>
struct log2;
template <>
struct log2<4> : std::integral_constant<index_t, 2>
{
};
template <>
struct log2<8> : std::integral_constant<index_t, 3>
{
};
template <>
struct log2<16> : std::integral_constant<index_t, 4>
{
......@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
if constexpr(kHeadDimV <= 32)
{
constexpr std::array<int, 4> occupancy{3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 128)
{
constexpr std::array<int, 4> occupancy{3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 256)
{
constexpr std::array<int, 4> occupancy{2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
}
}();
......@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce.
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
......
......@@ -10,11 +10,26 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template <index_t BlockSize, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{
constexpr index_t PixelsPerThread = (M * N) / BlockSize;
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
return NPerThread;
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
return 16 / sizeof(LSEDataType);
return GetVectorSizeForTile<Problem::kBlockSize,
Problem::kMaxSplits,
Problem::kM0,
typename Problem::LSEDataType>();
}
template <typename Problem>
......@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
}
// shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t NPerThread = 16 / sizeof(LSEDataType);
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t TotalWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp);
constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock);
static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>,
tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
......@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence<0, 1>>{});
}
// 3d + padding, [kMaxSplits, kM0]
// 3d + padding, shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
......@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_lds_block_desc;
}
// 3d + padding, [kM0, kMaxSplits]
// 3d + padding, shape=[kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
......@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t NThreads = get_warp_size();
constexpr index_t NThreads = 4;
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MThreads * MPerThread == kMPerBlock);
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2>>,
tuple<sequence<0>, sequence<0>>,
tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
sequence<1, 2>,
sequence<1, 1>>{});
sequence<2, 1>>{});
}
template <typename Problem>
......
......@@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4);
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kHeadDimV = HeadDimV_;
......
......@@ -5,9 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
......@@ -77,38 +77,44 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
......@@ -207,20 +213,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -968,20 +969,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using GemmProblem =
GemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
......@@ -28,6 +28,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
// UniversalGemm Policy
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
struct UniversalGemmPipelineAgBgCrPolicy
{
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
if constexpr(std::is_same<tensor_layout::gemm::RowMajor, LayoutA>::value)
{
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(ADataType);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(K0 * number<MLdsLayer>{}, number<MPerBlock / MLdsLayer>{}, K1),
make_tuple(K1, number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
number<K0 * MLdsLayer>{})),
make_pass_through_transform(K1)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(K0, number<MLdsLayer>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_ak0_kMLdsLayer_m_ak1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return a_lds_block_desc_m_k;
}
else // ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = Problem::kBlockSize / M0;
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
constexpr auto KThreadRead = 64 / WarpGemm::kM;
constexpr auto K0PerThreadRead = K0 / KThreadRead;
constexpr auto kfold =
(K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=kN0
constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128)
? 1
: ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0
? M0
: 128 / (K1 * WarpGemm::kM * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * M1>{},
number<kfold * M0 / mpair>{},
number<mpair>{},
K1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * M1>{}, number<kfold * M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(K1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return a_lds_block_desc_m_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, LayoutB>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(BDataType);
;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(K0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, K1),
make_tuple(K1, number<KPerBlock * NLdsLayer>{}, I1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
number<K0 * NLdsLayer>{})),
make_pass_through_transform(K1)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(K0, number<NLdsLayer>{})),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc_n_k;
}
else // RowMajor B
{
constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = Problem::kBlockSize / N0;
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
constexpr auto KThreadRead = 64 / WarpGemm::kN;
constexpr auto K0PerThreadRead = K0 / KThreadRead;
constexpr auto kfold =
(K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=kN0
constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128)
? 1
: ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0
? N0
: 128 / (K1 * WarpGemm::kN * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
K1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(K1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc_n_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
index_t smem_size = 0;
smem_size += smem_size_a + smem_size_b;
return smem_size;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment