Unverified Commit fe488bf2 authored by rocking, committed by GitHub

Merge branch 'develop' into layernorm/instance_support

parents 8bed8529 0394f8a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename ADataType,
typename AccDataType,
typename BDataType,
index_t kBlockSize,
typename BlockWarps, // num warps along seq<M, N>
typename BlockTile, // block size, seq<M, N>
typename WarpTile, // warp size, seq<M, N>
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
struct Reduce
{
static constexpr index_t Block_M = BlockTile::at(number<0>{});
static constexpr index_t Block_N = BlockTile::at(number<1>{});
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
static constexpr index_t Thread_M = ThreadTile::at(number<0>{});
static constexpr index_t Thread_N = ThreadTile::at(number<1>{});
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
static constexpr index_t ThreadPerWarp_M = Warp_M / Thread_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Thread_N;
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
__device__ static constexpr auto MakeABlockTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M, ThreadPerWarp_M, Thread_M>,
sequence<Repeat_N, WarpPerBlock_N, ThreadPerWarp_N, Thread_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
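// Informal reading of the encoding above (not asserted by this change): P0 indexes the warp
// over (WarpPerBlock_M, WarpPerBlock_N), P1 indexes the lane over
// (ThreadPerWarp_M, ThreadPerWarp_N), and the per-thread Y dims that remain are
// (Repeat_M, Thread_M, Repeat_N, Thread_N).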
__device__ void operator()(const ADataType* p_a, BDataType* p_b, index_t M, index_t N) const
{
const auto a_m_n = make_naive_tensor_view<address_space_enum::global>(
p_a, make_tuple(M, N), make_tuple(N, 1), number<Thread_N>{}, number<1>{});
const auto iM = get_block_id() * Block_M;
// A window
auto a_block_window = make_tile_window(a_m_n,
make_tuple(number<Block_M>{}, number<Block_N>{}),
{iM, 0},
MakeABlockTileDistribution());
const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
const ADataType reduce_init_value = 0;
constexpr auto reduce_dims = sequence<1>{};
// Acc tile
// TODO: support cross warp reduction
auto acc_block_tensor = decltype(block_tile_reduce<AccDataType>(
load_tile(a_block_window), reduce_dims, f_reduce, reduce_init_value)){};
// init Acc tile
tile_elementwise_inout(
[&](auto& acc) { acc = type_convert<AccDataType>(reduce_init_value); },
acc_block_tensor);
// loop
index_t iN = 0;
do
{
const auto a_block_tensor = load_tile(a_block_window);
// FIXME: support cross warp reduction
block_tile_reduce(acc_block_tensor, a_block_tensor, reduce_dims, f_reduce);
move_tile_window(a_block_window, {0, Block_N});
iN += Block_N;
} while(iN < N);
// FIXME: support cross warp reduction
block_tile_reduce_sync(acc_block_tensor, f_reduce);
// convert acc_block_tensor to b_block_tensor
const auto b_block_tensor = tile_elementwise_in(
[](const auto& acc) { return type_convert<BDataType>(acc); }, acc_block_tensor);
// B
const auto b_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_b, make_tuple(M), number<32>{});
// B window
auto b_block_window = make_tile_window(b_m, make_tuple(number<Block_M>{}), {iM});
// store B tile
store_tile(b_block_window, b_block_tensor);
}
};
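// Illustrative instantiation sketch (the parameter values below are assumptions, not taken
// from this change): one workgroup handles Block_M rows and computes b[m] = sum over n of
// a[m][n].
//
// using RowSum = ck_tile::Reduce<float, float, float,
//                                /*kBlockSize=*/256,
//                                ck_tile::sequence<2, 2>,    // warps along M, N
//                                ck_tile::sequence<64, 128>, // block tile  M, N
//                                ck_tile::sequence<32, 32>,  // warp tile   M, N
//                                ck_tile::sequence<2, 8>>;   // thread tile M, N
// // Repeat_M = 64 / (2 * 32) = 1, Repeat_N = 128 / (2 * 32) = 2,
// // ThreadPerWarp_M = 16, ThreadPerWarp_N = 4 (a 64-lane wavefront).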
} // namespace ck_tile
......@@ -6,3 +6,4 @@ add_subdirectory(01_fmha)
add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm)
add_subdirectory(04_img2col)
add_subdirectory(05_reduce)
......@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
......
......@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
return __shfl(v_local, src_lane);
#elif 1
if constexpr(sizeof(int32_t) > sizeof(T))
{
union packet
{
int32_t x;
T v;
};
packet p;
p.v = v_local;
packet p_remote;
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
return p_remote.v;
}
else if constexpr(sizeof(int32_t) == sizeof(T))
{
const int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
}
else
{
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
using vector_type = thread_buffer<int32_t, elm>;
auto vs = bit_cast<vector_type>(v_local);
auto vs_remote = vector_type{};
static_for<0, elm, 1>{}([&](auto i_e) {
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
vs_remote(i_e) = tmp;
});
return bit_cast<T>(vs_remote);
}
#endif
}
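// Note (informational, not part of this diff): ds_bpermute takes a byte address, hence the
// `src_lane << 2`. Types smaller than 4 bytes are packed into an int32_t first; larger types
// are shuffled one dword at a time. E.g. calling warp_shuffle(v, 0) from every lane
// broadcasts lane 0's value across the wavefront.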
} // namespace ck_tile
......@@ -32,11 +32,13 @@
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
......
......@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
namespace impl {
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
sequence<m, ms...>,
sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths =
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx where the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths = sequence<slice_length>;
using dim_slices = sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
// the first idx where the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
} // namespace impl
// clang-format off
// input: a sequence (with an optional mask) and SliceSize, the size of each slice
// output: the per-dimension slice lengths and the number of slices along each dimension
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, where slice_index is the index at
// which the slices start to split (scanning right -> left), i.e. the first index whose
// sliced length differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
number<sliced_type::split_idx>{});
}
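// Minimal usage sketch, mirroring one row of the table above (values taken from that table):
//   constexpr auto r = reverse_slice_sequence(sequence<4, 2, 8>{}, number<8>{});
//   // r[number<0>{}] is sequence<1, 1, 8>, r[number<1>{}] is sequence<4, 2, 1>,
//   // r[number<2>{}] is number<1>, i.e. 8 slices in total.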
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
constexpr auto r =
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
return make_tuple(r[number<0>{}].reverse(),
r[number<1>{}].reverse(),
number<Seq::size() - r[number<2>{}] - 1>{});
}
} // namespace ck_tile
......@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
namespace detail {
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
{
return concat_tuple(f(x.at(number<Is>{}))...);
}
} // namespace detail
// make sure F returns at least a tuple
// e.g. x : tuple<X, Y>, f(X) -> tuple<Z...>, f(Y) -> tuple<W...>
// this function returns the concatenation tuple<Z..., W...>
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
{
return detail::embed_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
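// Illustrative example (derived from the implementation above, not part of this change):
//   embed_tuples([](auto e) { return make_tuple(e, e); }, make_tuple(a, b))
//   returns make_tuple(a, a, b, b), i.e. f is applied per element and the resulting tuples
//   are concatenated.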
// By default, unroll fully into a flat tuple
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
......
......@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
}
// this function is used inside the span sweep loops
template <typename YLengths, index_t XUnpacks>
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
{
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
constexpr auto y_packs = number<XUnpacks>{};
static_assert(y_size % y_packs == 0);
constexpr auto y_slice_size = y_size / y_packs;
constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
constexpr auto unpacks = slice_info[number<1>{}];
return unpacks;
}
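// Worked example (assumed shapes, derived from slice_sequence above): with
// YLengths = sequence<2, 8> and XUnpacks = 2, y_slice_size = 16 / 2 = 8 and
// slice_sequence(sequence<2, 8>{}, number<8>{}) should give nums sequence<1, 2>, so each
// sweep call would cover 2 adjacent pixels along the last Y dimension.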
} // namespace ck_tile
......@@ -8,6 +8,7 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
......@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
});
}
// unpacked span: this version supports sweeping a span with an unpacked (multi-argument) functor
//
template <
typename TileDistributedSpan_, // tile_distributed_span<...>
typename F, // signature: F(tile_distributed_index<...>)
typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
{
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
static_uford<typename DstrSpan::Impl, Unpacks>{}(
[&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
}
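// Usage sketch (hypothetical span shape): for a span whose Impl is sequence<2, 8> and
// Unpacks = sequence<1, 2>, the functor f is invoked 8 times, each time with 2
// tile_distributed_index arguments that differ only in the last dimension.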
namespace impl {
template <typename, typename, typename>
struct sweep_tile_impl;
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
return y_unpacks;
}
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
return u.get_num_of_access() *
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
}
template <typename F, typename SpanIdx>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
sweep_tile_uspan(
spans[number<I>{}],
[&](auto... i_idx) {
const auto next_span_idx = embed_tuples(
[&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
span_idx);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx);
},
get_y_unpacks());
}
template <typename F, typename SpanIdx, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
constexpr auto access_stride =
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
constexpr auto curr_i_access = number<i_access / access_stride>{};
constexpr auto next_i_access = number<i_access % access_stride>{};
u(
[&](auto... i_idx) {
const auto next_span_idx = embed_tuples(
[&](auto si) {
return make_tuple(concat_tuple(
si, make_tuple(detail::make_tile_distributed_index(i_idx)))...);
},
span_idx);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx, next_i_access);
},
curr_i_access);
}
};
template <typename DistributedTensor, typename UnpacksPerXDim>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
{
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }
template <typename F, typename SpanIdx>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
{
unpack(f, span_idx);
}
template <typename F, typename SpanIdx, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
{
unpack(f, span_idx);
}
};
template <typename, typename, typename>
struct sweep_tile_impl_0;
// TODO: support empty tuple to remove this "entry-point" like function
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
return y_unpacks;
}
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
return u.get_num_of_access() *
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
}
template <typename F>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
sweep_tile_uspan(
spans[number<I>{}],
[&](auto... i_idx) {
constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx);
},
get_y_unpacks());
}
template <typename F, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto u =
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
constexpr auto access_stride =
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
.get_num_of_access();
constexpr auto curr_i_access = number<i_access / access_stride>{};
constexpr auto next_i_access = number<i_access % access_stride>{};
u(
[&](auto... i_idx) {
constexpr auto next_span_idx =
make_tuple(make_tuple(detail::make_tile_distributed_index(i_idx))...);
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
f, next_span_idx, next_i_access);
},
curr_i_access);
}
};
} // namespace impl
/*
* Enhanced sweep-tile utility that can control the unpacks along each X-dim.
* The lambda argument is the distributed index, which can be plugged directly
* into the distributed tensor as a setter/getter.
*
* e.g. in the snippets below, y is of type DistributedTensor and r is a row scale
*
* // sweep tile 1 by 1
* sweep_tile<DistributedTensor>([&](auto idx) {
* constexpr auto row_id = make_tuple(idx[number<0>{}]);
* y(idx) = y(idx) * r(row_id);
* });
*
* // sweep tile with 2 pixels from the last dim per function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_0, auto idx_1) {
* constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
* y(idx_0) = y(idx_0) * r(row_id);
* y(idx_1) = y(idx_1) * r(row_id);
* },
* sequence<1, 2>{});
*
* // sweep tile with 2x2 pixels per function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
* constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
* constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
* y(idx_00) = y(idx_00) * r(row_id0);
* y(idx_01) = y(idx_01) * r(row_id0);
* y(idx_10) = y(idx_10) * r(row_id1);
* y(idx_11) = y(idx_11) * r(row_id1);
* },
* sequence<2, 2>{});
*
* TODO: do we need constexpr? lambda function could be non-constexpr
*/
template <typename DistributedTensor,
typename F,
typename UnpacksPerXDim =
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
impl::sweep_tile_impl_0<DistributedTensor,
UnpacksPerXDim,
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
}
template <typename DistributedTensor,
typename F,
typename UnpacksPerXDim =
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void
sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
{
sweep_tile<DistributedTensor, F, UnpacksPerXDim>(f, UnpacksPerXDim{});
}
/*
* construct a sweep-tile instance, which supports issuing the lambda one access at a time.
* Note that this struct holds the lambda functor, but does not hold the distributed tensor;
* the functionality is otherwise the same as sweep_tile()
*/
template <typename DistributedTensor_,
typename F_,
typename UnpacksPerXDim_ =
typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
struct tile_sweeper
{
using DistributedTensor = remove_cvref_t<DistributedTensor_>;
using F = remove_cvref_t<F_>;
using UnpacksPerXDim = remove_cvref_t<UnpacksPerXDim_>;
CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {}
CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {})
: f(f_)
{
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
constexpr auto tmp =
impl::sweep_tile_impl_0<DistributedTensor,
UnpacksPerXDim,
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
return tmp.get_num_of_access();
}
CK_TILE_HOST_DEVICE void operator()() const
{
sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
}
template <index_t i_access>
CK_TILE_HOST_DEVICE void operator()(number<i_access>) const
{
constexpr auto spans = DistributedTensor::get_distributed_spans();
impl::sweep_tile_impl_0<DistributedTensor,
UnpacksPerXDim,
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
f, number<i_access>{});
}
F f;
};
// partial deduction is not allowed
// template <typename T, typename F, typename U>
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
// deduction guide
template <typename T,
typename F,
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
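// Usage sketch (hypothetical tensor/lambda, relying on the deduction guide above):
//   auto ts = tile_sweeper{y_tensor, [&](auto idx) { y_tensor(idx) = 0; }};
//   static_for<0, decltype(ts)::get_num_of_access(), 1>{}([&](auto i) { ts(i); });
// which issues the same accesses as a single ts() call, but one access at a time.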
} // namespace ck_tile
......@@ -17,6 +17,14 @@
namespace ck_tile {
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
return Distribution::_get_partition_index();
}
} // namespace detail
// distributed span
template <index_t... PartialHsLengths>
struct tile_distributed_span
......@@ -83,6 +91,21 @@ struct tile_distribution
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static auto _get_partition_index()
{
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
if constexpr(NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{
#if 0
......@@ -149,6 +172,16 @@ struct tile_distribution
}
#endif
template <typename PartitionIndex = decltype(_get_partition_index())>
CK_TILE_HOST_DEVICE auto
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
{
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
const auto window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
return window_adaptor_thread_coord_tmp.get_bottom_index();
}
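// i.e. starting from this thread's (warp id, lane id) partition index and a zero Y index,
// map through ps_ys_to_xs_ to the X (bottom) coordinate where this thread's data begins.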
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
......@@ -421,6 +454,7 @@ struct tile_distribution_detail
} // namespace detail
#if 0
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
......@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
#endif
// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
......@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
//***********************************************************************************
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
// only support warp-tile and block-tile
static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
if constexpr(Distribution::NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(Distribution::NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
sequence<m, ms...>,
sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths =
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths = sequence<slice_length>;
using dim_slices = sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
// the first idx that sliced length not equal to original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
using sliced_type =
reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
number<sliced_type::split_idx>{});
}
//
// slice the tensor along x_dim, which results in a split along y_dim, not p_dim.
// We don't support slicing across p_dim (i.e., slicing across different threads)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// This file should not be included inside tuple.hpp!
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
#include <utility>
namespace ck_tile {
namespace detail {
// RemainLengths: sequence<...>
// RamainUnpacks: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_impl
{
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
{
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
}
template <class F, class CurrentUnpackIds>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
{
constexpr index_t pack_len = RamainUnpacks::front();
static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
constexpr auto new_pack = generate_tuple(
[&](auto idx_) {
constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
},
number<CurrentUnpackIds::size() * pack_len>{});
static_uford_impl<decltype(RemainLengths::pop_front()),
decltype(RamainUnpacks::pop_front()),
Orders>{}(f, new_pack);
});
}
};
template <class Orders>
struct static_uford_impl<sequence<>, sequence<>, Orders>
{
template <class F, class PackedId>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
{
constexpr auto origin_packs = transform_tuples(
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
unpack(f, origin_packs);
}
};
template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_one_shot_impl
{
template <class F, class CurrentUnpackIds, index_t current_acc>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
{
constexpr auto r_lens_stride =
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
constexpr auto r_upks_stride =
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
constexpr index_t pack_len = RamainUnpacks::front();
constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
constexpr auto new_pack = generate_tuple(
[&](auto idx_) {
constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
},
number<CurrentUnpackIds::size() * pack_len>{});
static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
decltype(RamainUnpacks::pop_front()),
Orders>{}(f, new_pack, number<current_acc % current_stride>{});
}
};
template <class Orders>
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
{
template <class F, class PackedId, index_t current_acc>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
{
constexpr auto origin_packs = transform_tuples(
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
unpack(f, origin_packs);
}
};
} // namespace detail
// TODO: we may unify static_ford/static_uford in the future
//
// loop over an nd space (sequence) with packs
// you must make sure the functor passed in takes the same number of arguments (packs)
//
// e.g.
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
//
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
// ...
template <class Lengths,
class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_uford
{
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
CK_TILE_HOST_DEVICE constexpr static_uford()
{
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
static_for<0, Lengths::size(), 1>{}(
[&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
using L_ = decltype(Lengths{} / Unpacks{});
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
}
// F signature: F(sequence<...> multi_id...)
// multi_id is the unordered multi-index
template <class F>
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
f, make_tuple(sequence<>{}));
}
// this version is friendly for issuing the accesses one by one
template <class F, index_t i_access>
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
{
static_assert(i_access < get_num_of_access());
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
detail::static_uford_one_shot_impl<decltype(ordered_lengths),
decltype(ordered_unpacks),
Orders>{}(
f, make_tuple(sequence<>{}), number<i_access>{});
}
};
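// One-shot usage sketch (mirrors the Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2> example above):
// static_uford<sequence<2, 3, 4>, sequence<1, 1, 2>>{} has get_num_of_access() == 12, and
// calling it with number<0>{} issues only loop #0, i.e. f(seq<0, 0, 0>, seq<0, 0, 1>).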
} // namespace ck_tile
......@@ -21,7 +21,7 @@
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
......
......@@ -12,6 +12,16 @@ namespace detail {
template <index_t N>
struct log2;
template <>
struct log2<4> : std::integral_constant<index_t, 2>
{
};
template <>
struct log2<8> : std::integral_constant<index_t, 3>
{
};
template <>
struct log2<16> : std::integral_constant<index_t, 4>
{
......@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
if constexpr(kHeadDimV <= 32)
{
constexpr std::array<int, 4> occupancy{3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 128)
{
constexpr std::array<int, 4> occupancy{3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
else if constexpr(kHeadDimV <= 256)
{
constexpr std::array<int, 4> occupancy{2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
}
}();
......@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce.
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
......
......@@ -10,11 +10,26 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template <index_t BlockSize, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{
constexpr index_t PixelsPerThread = (M * N) / BlockSize;
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
return NPerThread;
}
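// e.g. (illustrative numbers only): BlockSize=256, M=16, N=64, DataType=float gives
// PixelsPerThread = (16 * 64) / 256 = 4 and MaxNPerThread = 16 / 4 = 4, so NPerThread = 4.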
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
return 16 / sizeof(LSEDataType);
return GetVectorSizeForTile<Problem::kBlockSize,
Problem::kMaxSplits,
Problem::kM0,
typename Problem::LSEDataType>();
}
template <typename Problem>
......@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
}
// shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t NPerThread = 16 / sizeof(LSEDataType);
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t TotalWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp);
constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock);
static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>,
tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
......@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence<0, 1>>{});
}
// 3d + padding, [kMaxSplits, kM0]
// 3d + padding, shape=[kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
......@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_lds_block_desc;
}
// 3d + padding, [kM0, kMaxSplits]
// 3d + padding, shape=[kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
......@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t NThreads = get_warp_size();
constexpr index_t NThreads = 4;
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MThreads * MPerThread == kMPerBlock);
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2>>,
tuple<sequence<0>, sequence<0>>,
tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
sequence<1, 2>,
sequence<1, 1>>{});
sequence<2, 1>>{});
}
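// e.g. (illustrative numbers only): with kBlockSize=256, kMaxSplits=16, kM0=64 this gives
// NThreads=4, NPerThread=4, MWarps=4, MThreadPerWarp=16, MPerThread=1, so each thread holds
// 4 of the kMaxSplits values of a single row.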
template <typename Problem>
......
......@@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4);
static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
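// e.g. assuming kM0_ = 64 and a 64-lane wavefront: kNumWarps = 64 / 16 = 4 and
// kBlockSize = 4 * 64 = 256, matching the previously fixed block size of 256.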
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kHeadDimV = HeadDimV_;
......
......@@ -88,22 +88,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 16
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard-coded here; otherwise it may produce an incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
......
......@@ -23,12 +23,12 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
......
......@@ -4,6 +4,9 @@
#pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
......@@ -5,37 +5,57 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
namespace ck_tile {
// TODO: Extract some type to wrapper class
template <typename Problem_>
struct Layernorm2dFwd
// host side args
struct Layernorm2dFwdHostArgs
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
const void* p_x;
const void* p_gamma;
const void* p_beta;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
void* p_y;
void* p_mean;
void* p_invStd;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = !std::is_same_v<MeanDataType, ck_tile::null_type>;
static constexpr bool kSaveInvStd = !std::is_same_v<InvStdDataType, ck_tile::null_type>;
float epsilon;
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
index_t m;
index_t n;
index_t stride; // row_stride
};
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Layernorm2dFwd
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, null_type>;
static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
......@@ -52,400 +72,177 @@ struct Layernorm2dFwd
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
index_t m;
index_t n;
index_t stride; // row_stride
};
using Hargs = Layernorm2dFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y,
void* p_mean,
void* p_invStd,
float epsilon,
ck_tile::index_t M,
ck_tile::index_t N)
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N};
return Kargs{hargs.p_x,
hargs.p_gamma,
hargs.p_beta,
hargs.p_y,
hargs.p_mean,
hargs.p_invStd,
hargs.epsilon,
hargs.m,
hargs.n,
hargs.stride};
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; }
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
return (hargs.m + Block_M - 1) / Block_M;
}
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::kMWarpPerBlock, S::kMThreadPerWarp>,
tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1>,
sequence<2>>{});
}
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
{
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
int thread_id_n = get_thread_id() % kNThreadPerBlock;
int max_count =
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
if(n_per_block_tail_loop > 0)
{
int thread_max_n = (thread_id_n + 1) * kNPerThread;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return max_count;
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
template <typename DistributedTensor>
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
const ComputeDataType epsilon)
{
// TODO: Investigate fast inverse square root algorithm with epsilon
constexpr auto spans = DistributedTensor::get_distributed_spans();
DistributedTensor out_dstr_tensor;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
out_dstr_tensor(i_idx) = type_convert<ComputeDataType>(1.0f) /
ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon);
});
return out_dstr_tensor;
}
// in bytes
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
CK_TILE_HOST static std::string GetName()
{
// TODO - Optimize tail loop to reduce move_tile_window()
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
move_tile_window(x_block_window, {0, kNPerBlock});
}
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// read x in reverse order to reuse the cache
ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window});
move_tile_window(beta_block_window, {stride_to_right_most_window});
move_tile_window(y_block_window, {0, stride_to_right_most_window});
// Normalization
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {-kNPerBlock});
move_tile_window(beta_block_window, {-kNPerBlock});
move_tile_window(y_block_window, {0, -kNPerBlock});
}
}
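// The two-pass scheme above streams each row twice: pass 1 walks the N tiles left to right and
// accumulates per-thread Welford mean/variance (merged across the warp afterwards); pass 2
// walks the tiles right to left, so the most recently loaded (right-most) tiles are likely
// still cache-resident, and applies y = (x - mean) * inv_std * gamma + beta.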
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// normalize
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
    }

    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        using S_ = typename Problem::BlockShape;
        auto suffix = [&]() {
            std::string n;
            if (kPadN) n += "_pn";
            if (kSaveMeanInvStd) n += "_mv";
            if (kTwoPass) n += "_2p";
            return n;
        }();

        #define _SS_ std::string
        #define _TS_ std::to_string
        return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
               _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
               _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
               _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
               _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
               _SS_(Pipeline::name) + suffix;
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
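// Example of a generated name (purely illustrative values): an fp16 instance with a 1x1024
// block tile, 1x4 warps, a 1x256 warp tile, 1x8 vectors, padded N and the two-pass path
// enabled would print roughly "layernorm2d_fwd_fp16_1x1024_1x4_1x256_1x8_<Pipeline::name>_pn_2p".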
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto x_m_n = [&]() {
    const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
        static_cast<const XDataType*>(kargs.p_x),
        make_tuple(kargs.M, kargs.N),
        make_tuple(kargs.N, 1),
        number<kNPerThread>{},
        number<1>{});

    return pad_tensor_view(x_dram_naive,
                           make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                           sequence<kPadM, kPadN>{});
}();
const auto gamma_n = [&]() {
    const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
        static_cast<const GammaDataType*>(kargs.p_gamma),
        make_tuple(kargs.N),
        make_tuple(1),
        number<kNPerThread>{},
        number<1>{});

    return pad_tensor_view(
        gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto beta_n = [&]() {
    const auto beta_dram_naive = make_naive_tensor_view<address_space_enum::global>(
        static_cast<const BetaDataType*>(kargs.p_beta),
        make_tuple(kargs.N),
        make_tuple(1),
        number<kNPerThread>{},
        number<1>{});

    return pad_tensor_view(
        beta_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() {
    const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
        static_cast<YDataType*>(kargs.p_y),
        make_tuple(kargs.M, kargs.N),
        make_tuple(kargs.N, 1),
        number<kNPerThread>{},
        number<1>{});

    return pad_tensor_view(y_dram_naive,
                           make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                           sequence<kPadM, kPadN>{});
}();
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window =
    make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
auto mean_block_window = [&]() {
    if constexpr(kSaveMean)
    {
        const auto mean_m = [&]() {
            const auto mean_dram_naive =
                make_naive_tensor_view_packed<address_space_enum::global>(
                    static_cast<MeanDataType*>(kargs.p_mean),
                    make_tuple(kargs.M),
                    number<1>{});

            return pad_tensor_view(
                mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
        }();

        return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
    }
    else
        return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
auto inv_std_block_window = [&]() {
    if constexpr(kSaveInvStd)
    {
        const auto inv_std_m = [&]() {
            const auto inv_std_dram_naive =
                make_naive_tensor_view_packed<address_space_enum::global>(
                    static_cast<InvStdDataType*>(kargs.p_invStd),
                    make_tuple(kargs.M),
                    number<1>{});

            return pad_tensor_view(
                inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
        }();

        return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
    }
    else
        return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
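// Dispatch: if a single N tile covers the whole row, every value is already resident after
// one load, so the one-pass path is used; otherwise the row is streamed twice.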
if(kargs.N <= kNPerBlock)
OnePassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<ComputeDataType>(kargs.epsilon),
kargs.N);
else
TwoPassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<ComputeDataType>(kargs.epsilon),
kargs.N);
}
};
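// Illustration only, not part of ck_tile: a minimal host-side scalar reference of the math the
// kernel above implements (mean/variance per row, inv_std = 1/sqrt(var + epsilon), then
// y = (x - mean) * inv_std * gamma + beta). The function name and parameters are hypothetical;
// such a reference would normally live in a standalone test file, which is also where the
// <cmath> include below belongs.
#include <cmath>

template <typename T>
void layernorm2d_fwd_reference(const T* x,     // [M, N], row-major
                               const T* gamma, // [N]
                               const T* beta,  // [N]
                               T* y,           // [M, N], row-major
                               int M, int N, float epsilon)
{
    for(int m = 0; m < M; ++m)
    {
        // mean of row m
        float mean = 0.f;
        for(int n = 0; n < N; ++n)
            mean += static_cast<float>(x[m * N + n]);
        mean /= static_cast<float>(N);

        // (population) variance of row m
        float var = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float d = static_cast<float>(x[m * N + n]) - mean;
            var += d * d;
        }
        var /= static_cast<float>(N);

        const float inv_std = 1.f / std::sqrt(var + epsilon);

        // normalize, then scale and shift
        for(int n = 0; n < N; ++n)
        {
            const float xv = static_cast<float>(x[m * N + n]);
            y[m * N + n]   = static_cast<T>((xv - mean) * inv_std * static_cast<float>(gamma[n]) +
                                            static_cast<float>(beta[n]));
        }
    }
}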