Commit 54d3e2f1 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into ck_tile/moe

parents 199f7f71 b8addae2
@@ -155,7 +155,12 @@ struct HostTensorDescriptor
return space;
}
+ std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
const std::vector<std::size_t>& get_lengths() const { return mLens; }
+ std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
const std::vector<std::size_t>& get_strides() const { return mStrides; }
template <typename... Is>
@@ -325,8 +330,12 @@ struct HostTensor
{
}
+ std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
+ std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
decltype(auto) get_strides() const { return mDesc.get_strides(); }
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <thread>
namespace ck_tile {
template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
const HostTensor<DataType>& cos_sd,
const HostTensor<DataType>& sin_sd,
bool interleaved,
HostTensor<DataType>& output_bsd,
bool use_1_row_sin_cos = false)
{
assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
cos_sd.get_length(1) == sin_sd.get_length(1));
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
output_bsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[2];
if(rotary_dim <= i_d)
{
self(i) = input_bsd(i);
return;
}
assert(i_d < rotary_dim);
const index_t i_s = i[1];
const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
: cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
: sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
if(interleaved)
{
const bool is_even = (i_d % 2 == 0);
const index_t pos = i_d + (is_even ? 1 : -1);
const ComputeDataType sign = (is_even ? -1 : 1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
else
{
const index_t half_rdim = (rotary_dim / 2);
const index_t pos = (i_d + half_rdim) % rotary_dim;
const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
}();
ComputeDataType result =
type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
self(i) = type_convert<DataType>(result);
});
}
} // namespace ck_tile
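For reference, the per-element computation above reduces to the standard interleaved and half-rotated RoPE formulas. A minimal standalone sketch of those formulas follows; it uses plain std::vector so it assumes nothing about the HostTensor API, and the single-row framing and names are illustrative only.

// Standalone illustration of the two RoPE variants computed by the reference above,
// applied to a single row of length rotary_dim (cos_r/sin_r have length rotary_dim / 2).
#include <vector>

std::vector<float> rope_row(const std::vector<float>& x,
                            const std::vector<float>& cos_r,
                            const std::vector<float>& sin_r,
                            bool interleaved)
{
    const int r = static_cast<int>(x.size()); // rotary_dim
    std::vector<float> y(r);
    if(interleaved)
    {
        // pairs (0,1), (2,3), ... share cos/sin index d/2
        for(int d = 0; d < r; d += 2)
        {
            y[d]     = x[d] * cos_r[d / 2] - x[d + 1] * sin_r[d / 2];
            y[d + 1] = x[d + 1] * cos_r[d / 2] + x[d] * sin_r[d / 2];
        }
    }
    else
    {
        // half-rotated: pairs (d, d + r/2) share cos/sin index d
        for(int d = 0; d < r / 2; ++d)
        {
            y[d]         = x[d] * cos_r[d] - x[d + r / 2] * sin_r[d];
            y[d + r / 2] = x[d + r / 2] * cos_r[d] + x[d] * sin_r[d];
        }
    }
    return y;
}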
@@ -7,7 +7,11 @@
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
+ #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+ #include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
+ #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
+ #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
@@ -21,11 +25,11 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
+ #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
+ #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
- #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
- #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
...
@@ -43,9 +43,12 @@ enum struct AlibiMode
FROM_BOTTOM_RIGHT = 2,
};
- template <typename DataType, bool RowMajor = true>
+ template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
struct Alibi
{
+ static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
+ "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
// RowMajor here means if pixel within the same thread are along the row, or col
// this may impact the performance of update(), while the result are the same.
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
@@ -79,6 +82,19 @@ struct Alibi
mode = mode_;
}
+ CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
+ CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+ {
+ if constexpr(LogMaxSadOprndSize <= 16)
+ {
+ return sad_u16(
+ static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
+ }
+ return sad_u32(x, y, acc);
+ }
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
{
if constexpr(RowMajor)
@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >=0 or right_size >=0
- template <typename DataType, bool RowMajor = true>
+ template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
index_t window_left_size,
index_t window_right_size,
@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
AlibiMode alibi_mode =
is_causal ? AlibiMode::VERTICAL
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
- return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
+ return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
}
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
...
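For context on the sad() helper added above: the SAD instructions compute an absolute difference plus an accumulator, sad(x, y, acc) = |x - y| + acc, and the 16-bit variant suffices whenever the operands (row/column indices here) fit in 16 bits, which is what the LogMaxSadOprndSize parameter and the static_assert express. A host-side stand-in with that assumed semantics, for illustration only:

// Host-side stand-in with the same assumed semantics as the device SAD used above:
// sad(x, y, acc) == |x - y| + acc. Illustration only; the kernel relies on the
// hardware sad_u16 / sad_u32 helpers instead.
#include <cstdint>

inline uint32_t sad_reference(uint32_t x, uint32_t y, uint32_t acc)
{
    return (x > y ? x - y : y - x) + acc;
}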
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
NONE = 0,
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
static constexpr const char* name = "";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
static constexpr const char* name = "inter";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
static constexpr const char* name = "half";
};
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
struct BlockRotaryEmbedding
{
template <typename DistributedTensor,
typename OtherDramBlockWindow,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
OtherDramBlockWindow other_window,
RotaryCosDramBlockWindow rotary_cos_window,
RotarySinDramBlockWindow rotary_sin_window,
index_t rotary_dim,
index_t thread_end)
{
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
if(thread_end <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
});
}
}
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
if(thread_end <= rotary_dim)
{
const bool is_left = (thread_end <= (rotary_dim / 2));
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto other_tile = load_tile(other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_cos_tile = load_tile(rotary_cos_window);
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
tile.thread_buf_[idx] =
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
});
}
}
}
};
} // namespace ck_tile
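To make the HALF_ROTATED window arithmetic above concrete: the partner elements sit half a rotary dimension away, so the "other" window is shifted by +rotary_dim/2 for a thread covering the left half and by -rotary_dim/2 for one covering the right half, while the cos/sin windows are shifted only for the right half. A small sketch of those column offsets, assuming rotary_dim = 8 purely for illustration:

// Column offsets applied by the HALF_ROTATED branch above (illustrative helper,
// not part of this commit). For rotary_dim = 8: a thread whose columns end at
// thread_end <= 4 loads partners from +4 columns away and reads cos/sin at its own
// column; a right-half thread loads partners from -4 away and reads cos/sin at column - 4.
#include <utility>

std::pair<int, int> half_rotated_offsets(int thread_end, int rotary_dim)
{
    const bool is_left           = (thread_end <= rotary_dim / 2);
    const int other_col_offset   = is_left ? rotary_dim / 2 : -(rotary_dim / 2);
    const int cos_sin_col_offset = is_left ? 0 : -(rotary_dim / 2);
    return {other_col_offset, cos_sin_col_offset};
}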
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace ck_tile {
// assume that we have only 1 page-block/tensor view
template <typename TensorView>
struct TrivialPageBlockNavigator
{
using DataType = typename TensorView::DataType;
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
: tensor_view(tensor_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
return make_tuple(/*block_index=*/0,
ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE constexpr auto
make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
return make_tuple(
/*block_index=*/0,
ck_tile::make_tile_window(
tensor_view, window_lengths, window_origin, tile_distribution));
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE static index_t
move_tile_window(index_t /*block_index*/,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
{
ck_tile::move_tile_window(tile_window, step);
return /*block_index=*/0;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin)
{
return global_window_origin;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
{
return local_window_origin;
}
private:
TensorView tensor_view;
};
// default page-block navigator; assumes the tensor view size equals the page-block size,
// or is smaller when the tile window is on the last page-block
template <typename DataType_, index_t VirtualDim, typename TensorView>
struct PageBlockNavigator
{
using DataType = DataType_;
static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
long_index_t block_stride_,
long_index_t fixed_offset_,
const int32_t* physical_block_indices_,
index_t num_blocks_,
index_t page_block_size_,
const TensorView& complete_view_,
const TensorView& last_view_)
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
block_stride(block_stride_),
fixed_offset(fixed_offset_),
physical_block_indices(physical_block_indices_),
num_blocks(num_blocks_),
page_block_size(page_block_size_),
complete_view(complete_view_),
last_view(last_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin,
tile_distribution);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t block_index,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
ck_tile::move_tile_window(tile_window, step);
const WindowOrigin global_window_origin =
to_global_window_origin(block_index, tile_window.get_window_origin());
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
const index_t new_block_index = get_block_index(global_window_origin);
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(local_window_origin);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
return new_block_index;
}
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
{
return block_index == num_blocks - 1;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
const TileWindow& tile_window) const
{
const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
return (block_index < num_blocks - 1) && (page_block_size < origin + length);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE void
move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
{
const multi_index<2> step = [&]() {
const index_t origin_diff = (block_index - new_block_index) * page_block_size;
if constexpr(VirtualDim == 0)
{
return make_multi_index(origin_diff, 0);
}
else
{
return make_multi_index(0, origin_diff);
}
}();
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(tile_window.get_window_origin() + step);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
}
CK_TILE_HOST_DEVICE WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin) const
{
if constexpr(VirtualDim == 0)
{
const index_t length = global_window_origin.at(number<0>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(length - page_block_size * num_complete_blocks,
global_window_origin.at(number<1>{}));
}
else
{
const index_t length = global_window_origin.at(number<1>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(global_window_origin.at(number<0>{}),
length - page_block_size * num_complete_blocks);
}
}
CK_TILE_HOST_DEVICE WindowOrigin
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
{
if constexpr(VirtualDim == 0)
{
return make_multi_index(block_index * page_block_size +
local_window_origin.at(number<0>{}),
local_window_origin.at(number<1>{}));
}
else
{
return make_multi_index(local_window_origin.at(number<0>{}),
block_index * page_block_size +
local_window_origin.at(number<1>{}));
}
}
private:
CK_TILE_HOST_DEVICE
DataType* get_block_ptr(index_t block_index) const
{
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
}
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
{
return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
}
DataType* physical_blocks;
long_index_t block_stride;
long_index_t fixed_offset;
const int32_t* physical_block_indices;
index_t num_blocks;
index_t page_block_size;
TensorView complete_view;
TensorView last_view;
};
template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
return TrivialPageBlockNavigator<TensorView>(tensor_view);
}
template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
long_index_t block_stride,
long_index_t fixed_offset,
const int32_t* physical_block_indices,
index_t num_blocks,
index_t page_block_size,
const TensorView& complete_view,
const TensorView& last_view)
{
return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
block_stride,
fixed_offset,
physical_block_indices,
num_blocks,
page_block_size,
complete_view,
last_view);
}
} // namespace ck_tile
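The heart of PageBlockNavigator is the mapping from a global position along the virtual (sequence) dimension to a logical block, a local offset inside that block, and the physical location of that block inside the pooled KV cache. A condensed sketch of the address math, with illustrative names that are not part of the class interface:

// Condensed illustration of the address math PageBlockNavigator performs along the
// virtual (sequence) dimension. Illustrative only; the real class works on 2D window
// origins and tensor views.
#include <cstdint>

struct PageLookup
{
    int32_t logical_block;  // index into the block table
    int32_t local_offset;   // offset inside that page block
    int64_t element_offset; // offset of the page start inside the physical pool
};

inline PageLookup lookup_page(int32_t global_pos,
                              int32_t page_block_size,
                              const int32_t* physical_block_indices,
                              int64_t block_stride,
                              int64_t fixed_offset)
{
    PageLookup out;
    out.logical_block  = global_pos / page_block_size;                     // cf. get_block_index()
    out.local_offset   = global_pos - out.logical_block * page_block_size; // cf. to_local_window_origin()
    out.element_offset =
        physical_block_indices[out.logical_block] * block_stride + fixed_offset; // cf. get_block_ptr()
    return out;
}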
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_>
struct FmhaFwdAppendKVKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
static constexpr bool kIsPagedKV = FmhaPipeline::kIsPagedKV;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadSeqLenK) n += "sk";
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name))
+ (kIsPagedKV ? "_pagedkv" : "" );
#undef _SS_
#undef _TS_
// clang-format on
}
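// As a worked example of the name assembled above (hypothetical configuration values,
// not taken from this commit): with kK0 = 128, fp16 Q, kM0 = 64, kN0 = 64, kN1 = 128,
// kBlockPerCuInput == -1 (so the "o" part is skipped), row-major V, no padding,
// interleaved RoPE and paged KV, GetName() evaluates to
//   "fmha_fwd_appendkv_d128_fp16_b64x64x128x128_vr_inter_pagedkv".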
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template arg
struct EmptyKargs
{
};
// kargs use the aggregate initializer, so no constructor is provided;
// use inheritance to minimize karg size.
// users need to use the MakeKargs() function to create kargs.
struct BasicKargs
{
void* q_ptr;
void* k_ptr;
const void* knew_ptr;
void* v_ptr;
const void* vnew_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t seqlen_knew;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
// if this param is larger than 1, it indicates the MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
ck_tile::index_t stride_v;
ck_tile::index_t stride_vnew;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_knew;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
};
struct RoPEKargs
{
const void* rotary_cos_ptr;
const void* rotary_sin_ptr;
ck_tile::index_t rotary_dim;
bool has_mask;
};
struct PageBlockTableKargs
{
const int32_t* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
};
struct CacheBatchIdxKargs
{
const int32_t* cache_batch_idx;
};
struct Kargs : BasicKargs,
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
{
};
__host__ static constexpr Kargs MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
ck_tile::index_t seqlen_q,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
bool has_mask,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
const void* cache_batch_idx,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
ck_tile::index_t stride_v,
ck_tile::index_t stride_vnew,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_knew,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_vnew,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_knew,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_vnew)
{
Kargs kargs{
{q_ptr,
k_ptr,
knew_ptr,
v_ptr,
vnew_ptr,
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
seqlen_q,
-1, // seqlen_k will be updated from the content of seqlen_k_ptr
seqlen_knew,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew}, // args for common karg
{}, // placeholder for rope
{} // placeholder for paged-block table or cache_batch_idx
};
if constexpr(kApplyRoPE)
{
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
kargs.has_mask = has_mask;
}
if constexpr(kIsPagedKV)
{
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.batch_stride_block_table = batch_stride_block_table;
kargs.page_block_size = page_block_size;
}
else
{
kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
}
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// divide problem
const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
const index_t i_cache_batch = [&, i_batch_ = i_batch] {
if constexpr(kIsPagedKV)
{
return i_batch_;
}
else
{
return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
: i_batch_);
}
}();
const long_index_t batch_offset_q =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
const long_index_t batch_offset_k =
static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
const long_index_t batch_offset_knew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
const long_index_t batch_offset_v =
static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
const long_index_t batch_offset_vnew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
// for simplicity, we apply the batch stride by adjusting the pointer directly
QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
KDataType* k_ptr =
reinterpret_cast<KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const KDataType* knew_ptr =
reinterpret_cast<const KDataType*>(kargs.knew_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew +
batch_offset_knew;
VDataType* v_ptr =
reinterpret_cast<VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
const VDataType* vnew_ptr =
reinterpret_cast<const VDataType*>(kargs.vnew_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_vnew +
batch_offset_vnew;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
const auto make_k_dram = [&](KDataType* data, index_t height) {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_k_dram(nullptr, kargs.page_block_size);
}
else
{
return make_k_dram(k_ptr, kargs.seqlen_k + kargs.seqlen_knew);
}
}();
const auto knew_dram = [&]() {
const auto knew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
knew_ptr,
make_tuple(kargs.seqlen_knew, kargs.hdim_q),
make_tuple(kargs.stride_knew, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
return pad_tensor_view(
knew_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
const auto make_v_dram = [&](VDataType* data, index_t length) {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(length, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(length)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(kargs.hdim_v, length),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
};
const auto v_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_v_dram(nullptr, kargs.page_block_size);
}
else
{
return make_v_dram(v_ptr, kargs.seqlen_k + kargs.seqlen_knew);
}
}();
const auto vnew_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
vnew_ptr,
make_tuple(kargs.seqlen_knew, kargs.hdim_v),
make_tuple(kargs.stride_vnew, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto vnew_dram_transposed = transform_tensor_view(
vnew_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_knew)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return pad_tensor_view(
vnew_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
vnew_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_knew),
make_tuple(kargs.stride_vnew, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(
vnew_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
}();
constexpr auto q_rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0 / 2>{});
const auto q_rotary_cos_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
number<8>{},
number<1>{});
const auto rotary_cos_dram = [&]() {
return pad_tensor_view(rotary_cos_dram_native,
q_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
}
else
{
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
}
}();
const auto q_rotary_sin_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
number<8>{},
number<1>{});
const auto rotary_sin_dram = [&]() {
return pad_tensor_view(rotary_sin_dram_native,
q_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
}
else
{
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
}
}();
constexpr auto knew_rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0 / 2>{});
const auto knew_rotary_cos_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_cos_dram = [&]() {
return pad_tensor_view(rotary_cos_dram_native,
knew_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
}
else
{
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
}
}();
const auto knew_rotary_sin_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_sin_dram = [&]() {
return pad_tensor_view(rotary_sin_dram_native,
knew_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
}
else
{
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
}
}();
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k;
return make_page_block_navigator<KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
k_dram,
make_k_dram(nullptr,
(kargs.seqlen_k + kargs.seqlen_knew) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(k_dram);
}
}();
auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v;
return make_page_block_navigator<VDataType, 1>(
kargs.v_ptr,
kargs.batch_stride_v,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
v_dram,
make_v_dram(nullptr,
(kargs.seqlen_k + kargs.seqlen_knew) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(v_dram);
}
}();
auto q_dram_window =
make_tile_window(q_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
{i_m0, 0});
const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
// window origin = (0, 0) if there is no work to do for the current block
auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{!skip_append_kv * (kargs.seqlen_k + i_n0), 0});
auto knew_dram_window =
make_tile_window(knew_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{i_n0, 0});
// window origin = (0, 0) if there is no work to do for the current block
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, !skip_append_kv * (kargs.seqlen_k + i_n0)});
auto vnew_dram_window =
make_tile_window(vnew_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, i_n0});
if constexpr(kApplyRoPE)
{
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
kargs.rotary_dim,
kargs.seqlen_q <= i_m0,
skip_append_kv);
}
else
{
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
0, // rotary_dim not used
kargs.seqlen_q <= i_m0,
skip_append_kv);
}
}
};
} // namespace ck_tile
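One detail of the kernel above worth spelling out: for MQA/GQA, each query head is mapped onto its KV head by dividing by nhead_ratio_qk (= nhead_q / nhead_k), as in the i_nhead / kargs.nhead_ratio_qk offsets. A tiny sketch with made-up head counts:

// Illustration of the MQA/GQA head mapping used above. With hypothetical nhead_q = 32
// and nhead_k = 8, nhead_ratio_qk = 4, so query heads 0..3 use KV head 0, heads 4..7
// use KV head 1, and so on.
#include <cstdint>

inline int32_t kv_head_for_query_head(int32_t i_nhead_q, int32_t nhead_ratio_qk)
{
    return i_nhead_q / nhead_ratio_qk;
}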
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static_assert(kK0 == kN1);
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
// TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
nhead,
batch_size);
}
CK_TILE_DEVICE auto operator()()
{
const index_t i_tile = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
}
};
} // namespace ck_tile
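A quick worked example of the grid computed above, with hypothetical tile sizes: for kM0 = kN0 = 64, seqlen_q = 1 (single-token decode) and seqlen_knew = 130, grid.x = max(ceil(1/64), ceil(130/64)) = max(1, 3) = 3, while grid.y and grid.z remain nhead and batch_size. A host-side restatement of that rule:

// Host-side restatement of the grid.x rule above, with hypothetical default tile sizes.
#include <algorithm>

inline unsigned appendkv_grid_x(unsigned seqlen_q, unsigned seqlen_knew,
                                unsigned kM0 = 64, unsigned kN0 = 64)
{
    const unsigned q_tiles    = (seqlen_q + kM0 - 1) / kM0;
    const unsigned knew_tiles = (seqlen_knew + kN0 - 1) / kN0;
    return std::max(q_tiles, knew_tiles); // e.g. seqlen_q = 1, seqlen_knew = 130 -> max(1, 3) = 3
}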
@@ -32,8 +32,6 @@ struct FmhaFwdSplitKVKernel
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
- using RandValOutputDataType =
- ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
@@ -46,8 +44,10 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
- static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
+ static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
+ static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
+ "paged-kvcache only supported by batch mode kernels");
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -85,8 +85,8 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
- (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
+ (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -110,7 +110,6 @@ struct FmhaFwdSplitKVKernel
void* o_acc_ptr;
ck_tile::index_t batch;
- ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -136,6 +135,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
+ ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
@@ -173,32 +173,16 @@ struct FmhaFwdSplitKVKernel
float scale_p;
};
- struct CommonDropoutKargs
- {
- void init_dropout(const float p_drop,
- const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
- {
- float p_undrop = 1.0 - p_drop;
- p_undrop_in_uint8_t =
- uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
- rp_undrop = 1.0 / p_undrop;
- drop_seed = std::get<0>(drop_seed_offset);
- drop_offset = std::get<1>(drop_seed_offset);
- }
- float rp_undrop = 1;
- uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
- bool is_store_randval = false;
- uint64_t drop_seed = 1;
- uint64_t drop_offset = 0;
- void* rand_val_ptr = nullptr;
- ck_tile::index_t stride_randval = 0;
- ck_tile::index_t nhead_stride_randval = 0;
- };
- struct BatchModeDropoutKargs : CommonDropoutKargs
- {
- ck_tile::index_t batch_stride_randval = 0;
- };
+ struct PageBlockTableKargs
+ {
+ const int32_t* block_table_ptr;
+ ck_tile::index_t batch_stride_block_table;
+ ck_tile::index_t page_block_size;
+ };
+ struct CacheBatchIdxKargs
+ {
+ const int32_t* cache_batch_idx;
+ };
struct BatchModeKargs
@@ -210,12 +194,13 @@ struct FmhaFwdSplitKVKernel
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
- std::conditional_t<kHasDropout, BatchModeDropoutKargs, EmptyKargs<3>>
+ std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
{
+ const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
- ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
@@ -226,12 +211,14 @@ struct FmhaFwdSplitKVKernel
AlibiKargs,
EmptyKargs<0>>>,
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
- std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
- std::conditional_t<kHasDropout, CommonDropoutKargs, EmptyKargs<3>>
+ std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
+ ck_tile::index_t batch_stride_k;
+ ck_tile::index_t batch_stride_v;
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -242,48 +229,45 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
- void* rand_val_ptr,
void* lse_acc_ptr,
void* o_acc_ptr,
ck_tile::index_t batch,
- ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
- ck_tile::index_t seqlen_k,
+ ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
+ const void* seqlen_k_ptr, // only used for (paged-) kvcache
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
+ const void* block_table_ptr,
+ ck_tile::index_t batch_stride_block_table,
+ ck_tile::index_t page_block_size,
+ const void* cache_batch_idx,
float scale_s,
float scale_p,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
- ck_tile::index_t stride_randval,
ck_tile::index_t stride_o_acc,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
- ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
- ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
- ck_tile::index_t mask_type,
- float p_drop,
- bool s_randval,
- const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+ ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -291,7 +275,6 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr,
o_acc_ptr,
batch,
- max_seqlen_q,
seqlen_q,
seqlen_k,
hdim_q,
@@ -313,17 +296,18 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
+ batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
- {}, // placeholder for dropout
+ {}, // placeholder for paged-block table or cache_batch_idx
+ reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
- batch_stride_v,
- batch_stride_lse_acc};
+ batch_stride_v};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -347,14 +331,15 @@ struct FmhaFwdSplitKVKernel
{
kargs.scale_p = scale_p;
}
- if constexpr(kHasDropout)
- {
- kargs.init_dropout(p_drop, drop_seed_offset);
- kargs.rand_val_ptr = rand_val_ptr;
- kargs.stride_randval = stride_randval;
- kargs.nhead_stride_randval = nhead_stride_randval;
- kargs.batch_stride_randval = batch_stride_randval;
- kargs.is_store_randval = s_randval;
- }
+ if constexpr(kIsPagedKV)
+ {
+ kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
+ kargs.batch_stride_block_table = batch_stride_block_table;
+ kargs.page_block_size = page_block_size;
+ }
+ else
+ {
+ kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
+ }
return kargs;
@@ -366,11 +351,9 @@ struct FmhaFwdSplitKVKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
- void* rand_val_ptr,
void* lse_acc_ptr,
void* o_acc_ptr,
ck_tile::index_t batch,
- ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
@@ -385,24 +368,22 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
- ck_tile::index_t stride_randval,
ck_tile::index_t stride_o_acc,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
- ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
+ ck_tile::index_t batch_stride_k,
+ ck_tile::index_t batch_stride_v,
+ ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
- ck_tile::index_t mask_type,
- float p_drop,
- bool s_randval,
- const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+ ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -410,9 +391,8 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr,
o_acc_ptr,
batch,
- max_seqlen_q,
- -1, // seqlen will be updated by another pointer
- -1, //
+ -1, // seqlen_q will be updated by another pointer
+ -1, // seqlen_k will be updated by another pointer
hdim_q,
hdim_v,
num_head_q,
@@ -432,16 +412,18 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
+ batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
- {}, // placeholder for dropout
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
- reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
+ reinterpret_cast<const int32_t*>(seqlen_k_ptr),
+ batch_stride_k,
+ batch_stride_v};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
...@@ -464,14 +446,6 @@ struct FmhaFwdSplitKVKernel ...@@ -464,14 +446,6 @@ struct FmhaFwdSplitKVKernel
{ {
kargs.scale_p = scale_p; kargs.scale_p = scale_p;
} }
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.is_store_randval = s_randval;
}
return kargs; return kargs;
} }
...@@ -508,7 +482,6 @@ struct FmhaFwdSplitKVKernel ...@@ -508,7 +482,6 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_k = 0; long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0; long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0; long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
...@@ -534,14 +507,9 @@ struct FmhaFwdSplitKVKernel ...@@ -534,14 +507,9 @@ struct FmhaFwdSplitKVKernel
{ {
batch_offset_bias = query_start * kargs.stride_bias + key_start; batch_offset_bias = query_start * kargs.stride_bias + key_start;
} }
if constexpr(kHasDropout)
{
batch_offset_randval = query_start * kargs.stride_randval;
}
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// # of required blocks is different in each group, terminate unnecessary blocks // # of required blocks is different in each group, terminate unnecessary blocks
// earlier // earlier
...@@ -556,24 +524,36 @@ struct FmhaFwdSplitKVKernel ...@@ -556,24 +524,36 @@ struct FmhaFwdSplitKVKernel
} }
else else
{ {
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
} }
} }
else else
{ {
const index_t i_cache_batch = [&, i_batch_ = i_batch] {
if constexpr(kIsPagedKV)
{
return i_batch_;
}
else
{
return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
: i_batch_);
}
}();
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
} }
if constexpr(kHasDropout)
if(kargs.seqlen_k_ptr != nullptr)
{ {
batch_offset_randval = kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
} }
} }
...@@ -589,6 +569,7 @@ struct FmhaFwdSplitKVKernel ...@@ -589,6 +569,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const VDataType*>(kargs.v_ptr) + reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v; batch_offset_v;
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) + OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc; batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
...@@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel ...@@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
}(); }();
const auto k_dram = [&]() {
const auto make_k_dram = [&](const KDataType* data, index_t height) {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr, data, // will update this pointer if using paged-kvcache
make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1), make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{}, number<FmhaPipeline::kAlignmentK>{},
number<1>{}); number<1>{});
...@@ -628,13 +610,24 @@ struct FmhaFwdSplitKVKernel ...@@ -628,13 +610,24 @@ struct FmhaFwdSplitKVKernel
k_dram_naive, k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{}); sequence<kPadSeqLenK, kPadHeadDimQ>{});
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_k_dram(nullptr, kargs.page_block_size);
}
else
{
return make_k_dram(k_ptr, kargs.seqlen_k);
}
}(); }();
const auto v_dram = [&]() {
const auto make_v_dram = [&](const VDataType* data, index_t length) {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr, data, // will update this pointer if using paged-kvcache
make_tuple(kargs.seqlen_k, kargs.hdim_v), make_tuple(length, kargs.hdim_v),
make_tuple(kargs.stride_v, 1), make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
...@@ -642,7 +635,7 @@ struct FmhaFwdSplitKVKernel ...@@ -642,7 +635,7 @@ struct FmhaFwdSplitKVKernel
const auto v_dram_transposed = const auto v_dram_transposed =
transform_tensor_view(v_dram_naive, transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v), make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)), make_pass_through_transform(length)),
make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -654,8 +647,8 @@ struct FmhaFwdSplitKVKernel ...@@ -654,8 +647,8 @@ struct FmhaFwdSplitKVKernel
else else
{ {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr, data, // will update this pointer if using paged-kvcache
make_tuple(kargs.hdim_v, kargs.seqlen_k), make_tuple(kargs.hdim_v, length),
make_tuple(kargs.stride_v, 1), make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
...@@ -665,6 +658,76 @@ struct FmhaFwdSplitKVKernel ...@@ -665,6 +658,76 @@ struct FmhaFwdSplitKVKernel
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
};
const auto v_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_v_dram(nullptr, kargs.page_block_size);
}
else
{
return make_v_dram(v_ptr, kargs.seqlen_k);
}
}();
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k;
return make_page_block_navigator<const KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
k_dram,
make_k_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(k_dram);
}
}();
auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v;
return make_page_block_navigator<const VDataType, 1>(
kargs.v_ptr,
kargs.batch_stride_v,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
v_dram,
make_v_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(v_dram);
}
}(); }();
auto q_dram_window = make_tile_window( auto q_dram_window = make_tile_window(
...@@ -678,13 +741,11 @@ struct FmhaFwdSplitKVKernel ...@@ -678,13 +741,11 @@ struct FmhaFwdSplitKVKernel
}(), }(),
{i_m0, 0}); {i_m0, 0});
auto k_dram_window = make_tile_window( auto k_dram_window_lengths =
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0}); make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
auto v_dram_window_lengths =
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
/// FIXME: Before C++20, capturing structured binding variables is not supported. Remove /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20 /// following copy capture of the 'i_nhead' if in C++20
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
...@@ -741,62 +802,6 @@ struct FmhaFwdSplitKVKernel ...@@ -741,62 +802,6 @@ struct FmhaFwdSplitKVKernel
return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
}(); }();
// dropout
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
bool is_store_randval = false;
if constexpr(kHasDropout)
{
rp_undrop = kargs.rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
is_store_randval = kargs.is_store_randval;
}
BlockDropout dropout(i_batch,
i_nhead,
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t,
is_store_randval);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout)
{
RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
batch_offset_randval;
const auto randval_dram = [&]() {
const auto randval_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
rand_val_ptr,
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
make_tuple(kargs.stride_randval, 1),
number<1>{},
number<1>{});
return pad_tensor_view(randval_dram_naive,
randval_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}();
return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
}
else
{
return make_null_tile_window(randval_dram_window_lengths);
}
}();
FmhaMask mask = [&]() { FmhaMask mask = [&]() {
if constexpr(kHasMask) if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>( return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
...@@ -823,16 +828,16 @@ struct FmhaFwdSplitKVKernel ...@@ -823,16 +828,16 @@ struct FmhaFwdSplitKVKernel
#endif #endif
if constexpr(kHasMask) if constexpr(kHasMask)
{ {
return make_alibi_from_lr_mask<SaccDataType, true>(slope, return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
kargs.window_size_left, kargs.window_size_left,
kargs.window_size_right, kargs.window_size_right,
kargs.seqlen_q, kargs.seqlen_q,
kargs.seqlen_k, kargs.seqlen_k,
kargs.mask_type); kargs.mask_type);
} }
else else
{ {
return Alibi<SaccDataType, true>{ return Alibi<SaccDataType, true, 32>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
} }
} }
...@@ -847,13 +852,14 @@ struct FmhaFwdSplitKVKernel ...@@ -847,13 +852,14 @@ struct FmhaFwdSplitKVKernel
{ {
return FmhaPipeline{}(q_dram_window, return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func identity{}, // q_element_func
k_dram_window, k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func identity{}, // k_element_func
v_dram_window, v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func identity{}, // v_element_func
bias_dram_window, bias_dram_window,
identity{}, // bias_element_func identity{}, // bias_element_func
randval_dram_window,
lse_acc_dram_window, lse_acc_dram_window,
identity{}, // lse_element_func identity{}, // lse_element_func
identity{}, // s_acc_element_func identity{}, // s_acc_element_func
...@@ -864,24 +870,23 @@ struct FmhaFwdSplitKVKernel ...@@ -864,24 +870,23 @@ struct FmhaFwdSplitKVKernel
mask, mask,
position_encoding, position_encoding,
kargs.scale_s, kargs.scale_s,
smem_ptr, smem_ptr);
dropout);
} }
else else
{ {
return FmhaPipeline{}(q_dram_window, return FmhaPipeline{}(q_dram_window,
k_dram_window, k_dram_window_lengths,
v_dram_window, k_page_block_navigator,
v_dram_window_lengths,
v_page_block_navigator,
bias_dram_window, bias_dram_window,
randval_dram_window,
lse_acc_dram_window, lse_acc_dram_window,
kargs.num_splits, kargs.num_splits,
i_split_, i_split_,
mask, mask,
position_encoding, position_encoding,
kargs.scale_s, kargs.scale_s,
smem_ptr, smem_ptr);
dropout);
} }
}(); }();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaFwdAppendKVPipelineDefaultPolicy>
struct BlockFmhaFwdAppendKVPipeline
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = typename Problem::QDataType;
using KDataType = typename Problem::KDataType;
using VDataType = typename Problem::VDataType;
using VLayout = typename Problem::VLayout;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kK0 = Problem::kK0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr auto RotaryEnum = Problem::RotaryEnum;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
// last dimension vector length used to create tensor view (and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to override this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
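// Reading of the guard above (informal): the vector alignment follows V's memory-contiguous
// dimension. For a row-major (seqlen x hdim_v) layout, loads vectorize along hdim_v, so head-dim
// padding forces element-wise access; for the column-major layout they vectorize along seqlen_k,
// so seq-k padding does instead.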
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kK0 <= 32)
{
return 2;
}
else if constexpr(kK0 <= 64)
{
return 3;
}
else if constexpr(kK0 <= 128)
{
return 2;
}
else if constexpr(kK0 <= 256)
{
return 1;
}
}
}();
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QElementFunction,
typename KnewElementFunction,
typename VnewElementFunction,
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
const KnewElementFunction& knew_element_func,
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
const VnewElementFunction& vnew_element_func,
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
index_t rotary_dim,
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
if(!skip_rotate_append_kv)
{
// append Knew to K
auto knew_window = make_tile_window(
knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
auto knew_tile = [&]() {
auto knew = load_tile(knew_window);
return tile_elementwise_in(knew_element_func, knew);
}();
// optionally apply rotary embedding to Knew
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window =
make_tile_window(knew_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
auto rotary_sin_window =
make_tile_window(knew_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
// We assume that each thread owns contiguous elements along the head dimension, and we
// use the distribution to enable/disable threads so that only part of the
// knew_tile content is overwritten
auto [thread_start, thread_end] =
Policy::template GetKnewThreadRangeAlongK<Problem>();
ignore = thread_start;
BlockRotaryEmbedding<RotaryEnum>::apply(knew_tile,
knew_window,
rotary_cos_window,
rotary_sin_window,
rotary_dim,
thread_end);
}
store_tile(k_dram_block_window, knew_tile);
// write tile to another block if necessary
if constexpr(kIsPagedKV)
{
if(k_page_block_navigator.is_cross_block(i_page_block_k, k_dram_block_window))
{
k_page_block_navigator.move_to_block(
i_page_block_k, k_dram_block_window, i_page_block_k + 1);
store_tile(k_dram_block_window, knew_tile);
}
}
// append Vnew to V
auto vnew_window = make_tile_window(
vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
auto vnew_tile = [&]() {
auto vnew = load_tile(vnew_window);
return tile_elementwise_in(vnew_element_func, vnew);
}();
store_tile(v_dram_block_window, vnew_tile);
// write tile to another block if necessary
if constexpr(kIsPagedKV)
{
if(v_page_block_navigator.is_cross_block(i_page_block_v, v_dram_block_window))
{
v_page_block_navigator.move_to_block(
i_page_block_v, v_dram_block_window, i_page_block_v + 1);
store_tile(v_dram_block_window, vnew_tile);
}
}
}
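// Informal picture of the paged-KV path above: the new K/V rows are first stored through the
// current window inside one physical page block; if the destination window straddles a page
// boundary, the navigator advances the window to the next physical block and the tile is stored
// once more so the remaining rows land there. This is a sketch of the intent only; the exact
// bounds handling lives in the page-block navigator and tile-window implementations.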
if(!skip_rotate_q)
{
// optionally apply rotary embedding to Q
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto q_window = make_tile_window(
q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
auto q_tile = [&]() {
auto q = load_tile(q_window);
return tile_elementwise_in(q_element_func, q);
}();
auto rotary_cos_window =
make_tile_window(q_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
auto rotary_sin_window =
make_tile_window(q_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
// We assume that each thread owns contiguous elements along the head dimension, and we
// use the distribution to enable/disable threads so that only part of the
// q_tile content is overwritten
auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
ignore = thread_start;
BlockRotaryEmbedding<RotaryEnum>::apply(
q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end);
store_tile(q_dram_block_window, q_tile);
}
}
}
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window,
VDramBlockWindow& v_dram_block_window,
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window,
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
index_t rotary_dim,
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
return operator()(q_dram_block_window,
identity{},
k_dram_block_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_block_window,
identity{},
v_dram_block_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_block_window,
identity{},
q_rotary_cos_dram_block_window,
q_rotary_sin_dram_block_window,
knew_rotary_cos_dram_block_window,
knew_rotary_sin_dram_block_window,
rotary_dim,
skip_rotate_q,
skip_rotate_append_kv);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// In this pipeline q, k and v are all located in LDS
struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kN1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct!
if constexpr(total_pixels > 4)
return 4;
else
return 2;
}
else
{
return 16 / sizeof(VDataType);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQNumElemsPerRead()
{
using DataType = typename Problem::QDataType;
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
/// NOTICE: we might need to lower this to support smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 16 / sizeof(DataType);
}
}
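// Illustrative arithmetic (example types, not mandated by this policy): both branches currently
// read 16 bytes per thread, so with fp16 Q data (sizeof == 2) GetQNumElemsPerRead() == 8, i.e.
// one 128-bit buffer_load; with fp32 it would be 4 elements per read.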
template <typename Problem>
CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
{
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
static_assert(Problem::kK0 % KPerThread == 0);
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
else
{
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
static_assert(Problem::kK0 % KPerThread == 0);
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
}
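// Worked example (illustrative values only): with kK0 == 64 and fp16 Q data, KPerThread == 8 and
// KThreadPerBlock == 8, so thread ids 0..7 cover head-dim ranges [0,8), [8,16), ..., [56,64) and
// the pattern repeats every 8 threads. The rotary-embedding apply step uses this [start, end)
// range together with rotary_dim to decide which lanes actually rotate their elements.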
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
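// Sketch of the resulting decomposition (hypothetical sizes, assuming a wavefront size of 64):
// with kBlockSize == 256, kM0 == 64, kK0 == 64 and fp16 data, KPerThread == 8,
// KThreadPerBlock == 8, MThreadPerWarp == 8, NumWarps == 4 and MPerThread == 2, i.e. each thread
// holds a 2 x 8 (MPerThread x KPerThread) slice of the M0 x K0 block.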
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKnewNumElemsPerRead()
{
using DataType = typename Problem::KDataType;
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
/// NOTICE: we might need to lower this to support smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 16 / sizeof(DataType);
}
}
template <typename Problem>
CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
{
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
else
{
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVnewDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t kKPerBlock = Problem::kN0;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t NPerThread = 16 / sizeof(VDataType);
constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NThreadPerBlock, NPerThread>,
sequence<KPerThread, NumWarps, KThreadPerWarp>>,
tuple<sequence<2>, sequence<1, 2>>,
tuple<sequence<1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
else
{
constexpr index_t KPerThread = 16 / sizeof(VDataType);
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
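// Note on the two branches (informal): for a row-major V (seqlen x hdim_v in memory) each thread's
// contiguous run of elements lies along the head dimension (NPerThread), while for the
// column-major layout it lies along the sequence dimension (KPerThread). Both keep 16-byte reads,
// e.g. 8 fp16 values per load (example figure only).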
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_HOST_DEVICE static constexpr auto GetRotaryCosSinTileSize()
{
constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return make_tuple(number<height>{}, number<Problem::kK0>{});
}
else
{
return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
}
}
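// Illustrative sizes only: with kK0 == 128 the cos/sin tile is height x 128 for HALF_ROTATED and
// height x 64 otherwise; in the interleaved layout adjacent element pairs share one cos/sin
// value, so half as many columns need to be fetched per row.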
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_HOST_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
{
using DataType = std::conditional_t<IsRotaryCosSinForQ,
typename Problem::QDataType,
typename Problem::KDataType>;
constexpr auto TileSize = GetRotaryCosSinTileSize<Problem, IsRotaryCosSinForQ>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = TileSize[number<0>{}];
constexpr index_t kKPerBlock = TileSize[number<1>{}];
constexpr index_t KPerThread = []() {
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
/// NOTICE: we might need to lower this to support smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 8 / sizeof(DataType);
}
}();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
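// Example figures (not enforced here): for the interleaved case with fp16 cos/sin data,
// KPerThread == 4, so a kK0/2 == 64-wide tile is covered by 16 threads along its last dimension;
// the half-rotated case keeps the full 16-byte (8-element) read per thread instead.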
};
} // namespace ck_tile
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -15,19 +14,18 @@ namespace ck_tile { ...@@ -15,19 +14,18 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineQRKSVS struct BlockFmhaFwdSplitKVPipelineQRKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowLengths,
typename VDramBlockWindowTmp, typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
...@@ -123,13 +122,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -123,13 +122,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& k_element_func, const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func, const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
...@@ -140,20 +140,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -140,20 +140,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr) const
BlockDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> && std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>, std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
...@@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split); q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do. // check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
{ {
if(num_total_loop <= 0) const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
if(original_num_total_loop <= 0)
{ {
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
...@@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
} }
auto k_dram_block_window = // make sure the first tile is completely located in page-block
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
k_dram_block_window_tmp.get_window_lengths(), if constexpr(kIsPagedKV)
{seqlen_k_start, 0}); {
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
}
else
{
return seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window( auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>( auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
randval_dram_block_window_tmp, seqlen_k_start); v_dram_block_window_lengths,
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
auto v_dram_window = Policy::template MakeVDramTileDistribution<Problem>());
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q); auto q_tile = tile_elementwise_in(q_element_func, q);
...@@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{ {
// STAGE 1, QK gemm // STAGE 1, QK gemm
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window,
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load // load
auto k_block_tile = load_tile(k_dram_window); auto k_block_tile = load_tile(k_dram_window);
{ {
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
...@@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
...@@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in last iteration without increasing code size /// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits) if constexpr(kHasUnevenSplits)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(s_acc, set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(), -numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
auto tile_idx) {
const auto col = const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return seqlen_k_end_ <= col; if constexpr(kIsPagedKV)
{
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
}
else
{
return seqlen_k_end_ <= col;
}
}); });
} }
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}), k_origin.at(number<0>{}),
number<kM0>{}, number<kM0>{},
...@@ -501,12 +519,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -501,12 +519,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}); });
}); });
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
}
block_sync_lds(); block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
...@@ -522,7 +534,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -522,7 +534,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(v_lds_window, store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
} }
move_tile_window(v_dram_window, {0, kK1}); i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
const auto p = const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)); cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
...@@ -530,8 +543,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -530,8 +543,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
// STAGE 3, KV gemm // STAGE 3, KV gemm
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
{ {
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { static_for<0, k1_loops - 1, 1>{}([&,
const auto v = load_tile(v_dram_window); // load next v &i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) {
const auto v = load_tile(v_dram_window_); // load next v
block_sync_lds(); block_sync_lds();
gemm_1(o_acc, gemm_1(o_acc,
get_slice_tile( get_slice_tile(
...@@ -552,11 +567,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -552,11 +567,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(v_lds_window, store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v tile_elementwise_in(v_element_func, v)); // store next v
} }
move_tile_window(v_dram_window, {0, kK1}); i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
}); });
} }
// move K tile windows // move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0}); i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
...@@ -618,36 +635,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -618,36 +635,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowLengths,
typename VDramBlockWindowTmp, typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const KPageBlockNavigator& k_page_block_navigator,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits, index_t num_splits,
index_t i_split, index_t i_split,
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr) const
BlockDropout& dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
k_dram_block_window_tmp, k_dram_block_window_lengths,
k_page_block_navigator,
identity{}, identity{},
v_dram_block_window_tmp, v_dram_block_window_lengths,
v_page_block_navigator,
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp,
lse_acc_dram_block_window_tmp, lse_acc_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
...@@ -658,8 +677,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -658,8 +677,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr, smem_ptr);
dropout);
} }
}; };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // whether q_tile loads the whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always supports padding, hdim_q/v support multiples of the vector size (like 8x)
// only seq_k padding needs special care (oob entries of p need to be set to -INF instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
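// In other words (informal): hdim_q / hdim_v only have to be a multiple of the load vector
// (e.g. 8 fp16 values), while out-of-range seq_k columns must be masked to -INF in the attention
// scores before the softmax so they contribute zero probability.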
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view (and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to override this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
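// Hedged note (added for clarity, not part of the original source): with fast exp2 the softmax is
// evaluated in the log2 domain, using exp(x) == exp2(x * log2(e)). R_LOG2E = 1 / log2(e) = ln(2)
// converts log2-domain row maxima back to the natural-log domain when the LSE is stored at the
// end of the pipeline.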
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
{
return 1;
}
}
}();
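// Hedged note: kBlockPerCu is an occupancy hint (workgroups resident per CU). Larger
// kK0BlockLength (i.e. larger hdim) tiles presumably consume more LDS and registers, so fewer
// blocks fit per CU; the numbers above are tuning heuristics rather than hard requirements.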
static constexpr const char* name = "qr_async";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
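// Hedged note on the buffering scheme, as I read it: k_lds_store holds Policy::NumPrefetchK tile
// windows over the same LDS allocation (one per prefetch buffer). LdsSeq maps each pipeline step
// to the buffer it uses -- entries 0..k0_loops-1 for the K sub-tiles of the QK gemm and entries
// k0_loops..k0_loops+k1_loops-1 for the V sub-tiles of the PV gemm, which reuse the same LDS
// space. This forms a small ring buffer so async K copies can overlap with compute.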
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto k_lds_load = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0});
},
number<Policy::NumPrefetchK>{});
#else
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
#endif
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
// TODO: we use async copy for K, which is inline asm;
// a side effect is that we have to use inline asm for Q as well
auto q = decltype(load_tile(q_dram_window)){};
set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
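// Illustrative example (hypothetical numbers, added for clarity): with seqlen_k_start = 0,
// seqlen_k_end = 96 and kN0 = 64, num_total_loop = ceil(96 / 64) = 2; the second, partial tile is
// handled by the kHasUnevenSplits / masking logic further below.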
// early exit if the whole range is masked out and there is no work to do.
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
if(num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
buffer_load_fence_raw(0); // rocm-6.1: if the whole tile is masked out, we need fence(0),
// otherwise compute errors occur (maybe a compiler bug?)
// Note: o_acc is already all-cleared here, just return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure there is a sched_barrier(0) right after this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x: applying q_element_func here causes scratch usage for hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
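// Illustrative example (hypothetical tile sizes, added for clarity): kK0BlockLength = 128 and
// kK0 = 32 give k0_loops = 4 (the QK gemm is split into 4 steps along hdim_q), while kN0 = 64 and
// kK1 = 32 give k1_loops = 2 (the PV gemm is split into 2 steps along the kN0 direction).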
// main loop
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence_raw(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<i_k0>{})>{}]);
#else
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
#endif
});
}
// TODO: this works around a bug when the loop count is smaller than 2;
// otherwise the following fence/barrier would be scheduled inside the 1st loop iteration
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence_raw();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<k0_loops - 1>{})>{}]);
#else
get_slice_tile(
k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
#endif
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only perform this check in the last iteration, without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_dram_block_window.get_window_origin();
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return seqlen_k_end_ <= col;
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // moving this right after load_tile(v_dram_window) would cause scratch usage...
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: the bias might be a materialized mask containing -inf values and needs special
/// consideration; alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
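// For reference (standard flash-attention online-softmax recurrence; a hedged note, not part of
// the original source):
//   m_j = max(m_{j-1}, rowmax(S_j))
//   P_j = exp(S_j - m_j)
//   l_j = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(P_j)
//   O_j = exp(m_{j-1} - m_j) * O_{j-1} + P_j @ V_j
// Here only the exp(m_{j-1} - m_j) correction is applied to o_acc; the final 1 / l normalization
// is deferred until after the main loop (see the "finally, O" section), which is equivalent in
// exact arithmetic.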
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window =
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// store lse acc
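// Hedged note (added for clarity): per row, the stored value is
//   lse_acc = m + log(l) = log( sum_j exp(s_j) )
// for this split; the split-KV combine pipeline later uses these per-split LSE values to merge
// the partial o_acc tiles into the final output.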
if constexpr(kStoreLSE)
{
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc));
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_acc_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
smem_ptr,
dropout);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// Default policy for the qr/ks/vs async pipeline: Q is held in registers (loaded once), while K/V tiles are staged through LDS
using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ true,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
} // namespace ck_tile
...@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem ...@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename QDataType, template <typename QDataType_,
typename KDataType, typename KDataType_,
typename VDataType, typename VDataType_,
typename SaccDataType, typename SaccDataType_,
typename SMPLComputeDataType, typename SMPLComputeDataType_,
typename BiasDataType, typename BiasDataType_,
typename RandValOutputDataType, typename LSEDataType_,
typename LSEDataType, typename PDataType_,
typename PDataType, typename OaccDataType_,
typename OaccDataType, typename ODataType_,
typename ODataType, typename BlockFmhaShape_,
typename BlockFmhaShape, bool kIsGroupMode_,
bool kIsGroupMode, typename FmhaMask_,
typename FmhaMask, typename Traits_>
typename Traits> struct BlockFmhaFwdSplitKVPipelineProblem
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
BiasDataType,
RandValOutputDataType,
LSEDataType,
PDataType,
OaccDataType,
ODataType,
BlockFmhaShape,
kIsGroupMode,
FmhaMask,
Traits>
{ {
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename LSEDataType_, template <typename LSEDataType_,
...@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem ...@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static constexpr index_t kMaxSplits = Traits::kMaxSplits; static constexpr index_t kMaxSplits = Traits::kMaxSplits;
}; };
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
index_t kM0_,
index_t kN0_,
index_t kK0_,
index_t kN1_,
bool kIsVLayoutRowMajor_,
RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_,
typename Traits_>
struct BlockFmhaFwdAppendKVPipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kK0 = kK0_;
static constexpr index_t kN1 = kN1_;
using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{ {
if constexpr(AsyncCopyK) if constexpr(AsyncCopyK)
{ {
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(); return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
} }
else else
{ {
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>()); return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
} }
} }
// this method is only available when Problem::kHasDropout is present
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() CK_TILE_HOST_DEVICE static constexpr std::
enable_if_t<std::is_convertible_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
GetSmemSizeDropout(int)
{ {
if constexpr(Problem::kHasDropout) if constexpr(Problem::kHasDropout)
{ {
...@@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
} }
} }
// fallback version if Problem::kHasDropout does not exist
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout(...)
{
return 0;
}
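// Note (added for clarity, hedged): callers invoke GetSmemSizeDropout<Problem>(0); the int
// argument prefers the enable_if'd (int) overload above when Problem::kHasDropout exists, and
// overload resolution falls back to this (...) variadic version otherwise, so problems without a
// kHasDropout member simply report 0 bytes of dropout smem.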
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{ {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -32,30 +33,31 @@ struct TileFmhaTraits ...@@ -32,30 +33,31 @@ struct TileFmhaTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK /* padding for seqlen_k */, bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ /* paddding for hdim_q */, bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum, BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad, bool kHasBiasGrad_,
bool kStoreLSE, bool kStoreLSE_,
bool kHasDropout, bool kDoFp8StaticQuant_,
bool kDoFp8StaticQuant, bool kIsPagedKV_,
bool kHasUnevenSplits_ = true, bool kHasUnevenSplits_,
index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */> index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ, struct TileFmhaFwdSplitKVTraits
kPadSeqLenK,
kPadHeadDimQ,
kPadHeadDimV,
BiasEnum,
kHasBiasGrad,
kStoreLSE,
kHasDropout,
kDoFp8StaticQuant,
kBlockPerCu>
{ {
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
// determine if some split (length) is not divisible by tile size // determine if some split (length) is not divisible by tile size
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
...@@ -76,6 +78,20 @@ struct TileFmhaFwdSplitKVCombineTraits ...@@ -76,6 +78,20 @@ struct TileFmhaFwdSplitKVCombineTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdAppendKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
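// Illustrative usage (hypothetical parameter values, not from this commit):
//   using AppendKVTraits = ck_tile::TileFmhaFwdAppendKVTraits</*kPadSeqLenQ=*/true,
//                                                             /*kPadSeqLenK=*/true,
//                                                             /*kPadHeadDimQ=*/false,
//                                                             /*kPadHeadDimV=*/false>;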
template <bool kPadSeqLenQ_ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */> index_t kBlockPerCu_ = 2 /* hint to occupancy */>
......
...@@ -184,6 +184,43 @@ using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple< ...@@ -184,6 +184,43 @@ using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple<
// clang-format on // clang-format on
>; >;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename OutElementOp>
using device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_FP8
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F8, F8, F32, F32, Tuple<>, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>
#endif
// clang-format on
>;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -8,9 +8,7 @@ ...@@ -8,9 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck { namespace ck {
...@@ -177,6 +175,88 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -177,6 +175,88 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
}; };
using CombConvScale = ck::tensor_operation::element_wise::ScaleScalePass;
#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F8,
F8,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
CombConvScale,
F8,
F8>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DLayouts,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DDataTypes,
typename OutDataType,
typename AComputeType,
typename BComputeType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
DLayouts,
OutLayout,
InDataType,
WeiDataType,
DDataTypes,
OutDataType,
PassThrough,
PassThrough,
CombConvScale,
AComputeType,
BComputeType>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
DLayouts,
OutLayout,
InDataType,
WeiDataType,
DDataTypes,
OutDataType,
PassThrough,
PassThrough,
CombConvScale,
AComputeType,
BComputeType>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
is_same_v<OutDataType, F32> && is_same_v<AComputeType, f8_t> &&
is_same_v<BComputeType, f8_t>)
{
add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
op_ptrs);
}
#endif
}
return op_ptrs;
}
};
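// Illustrative usage (hedged, not part of this commit): client code typically enumerates the
// registered instances through the factory and profiles them, e.g.
//   auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
//       DeviceOp>::GetInstances();
// where DeviceOp is the DeviceGroupedConvFwdMultipleABD specialization with the CombConvScale
// output element-wise operation shown above.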
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck { namespace ck {
...@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory< ...@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory<
} }
}; };
using CombConvScaleRelu = ck::tensor_operation::element_wise::ScaleScaleRelu;
#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F8,
F8,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
CombConvScaleRelu,
F8,
F8>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DLayouts,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DDataTypes,
typename OutDataType,
typename AComputeType,
typename BComputeType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
DLayouts,
OutLayout,
InDataType,
WeiDataType,
DDataTypes,
OutDataType,
PassThrough,
PassThrough,
CombConvScaleRelu,
AComputeType,
BComputeType>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
DLayouts,
OutLayout,
InDataType,
WeiDataType,
DDataTypes,
OutDataType,
PassThrough,
PassThrough,
CombConvScaleRelu,
AComputeType,
BComputeType>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
is_same_v<OutDataType, F32> && is_same_v<AComputeType, f8_t> &&
is_same_v<BComputeType, f8_t>)
{
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances(
op_ptrs);
}
#endif
}
return op_ptrs;
}
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......