"test/vscode:/vscode.git/clone" did not exist on "05834520e5fea5e42754103b6fb2b675533fcd4b"
Commit 1d784873 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents d25889b1 851c3ed1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <cmath>
#include <vector>
namespace ck_tile {
// Selector for the kind of positional bias added to attention scores:
// NO    - no position encoding (see EmptyPositionEncoding)
// ALIBI - attention with linear biases (see Alibi)
enum struct PositionEncodingEnum
{
NO = 0,
ALIBI = 1,
};
/*
VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
TOP_LEFT:
[0] 1 2 3 4 5
1 [0] 1 2 3 4
2 1 [0] 1 2 3
3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT:
2 1 [0] 1 2 3
3 2 1 [0] 1 2
4 3 2 1 [0] 1
5 4 3 2 1 [0]
*/
// Where the zero-distance diagonal of the alibi bias sits (see the
// diagrams above): VERTICAL biases by column only; FROM_TOP_LEFT and
// FROM_BOTTOM_RIGHT anchor the zero diagonal at the respective corner.
enum struct AlibiMode
{
VERTICAL = 0,
FROM_TOP_LEFT = 1, // keep sync with mask enum
FROM_BOTTOM_RIGHT = 2,
};
// Computes the ALiBi (Attention with Linear Biases) additive bias for one
// attention-score element at a time via update().
//
// RowMajor means the pixels owned by one thread run along the row (true) or
// along the column (false). This only affects the instruction count of
// update(); the numerical result is the same either way. E.g. fwd prefers
// RowMajor=true, some bwd cases prefer RowMajor=false.
template <typename DataType, bool RowMajor = true>
struct Alibi
{
    // slope_   : per-head slope; for non-vertical modes it is negated here so
    //            update() can always do pixel += slope * distance
    // y_total_ : total rows (e.g. seqlen_q)
    // x_total_ : total cols (e.g. seqlen_k)
    CK_TILE_HOST_DEVICE Alibi(DataType slope_,
                              index_t y_total_,
                              index_t x_total_,
                              AlibiMode mode_ = AlibiMode::VERTICAL)
    {
        // BUGFIX: was "-slope", which negated the not-yet-initialized member
        // (undefined behavior) instead of the constructor argument "slope_".
        slope         = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
        shift_left_up = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
        }();
        shift_right_down = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
        }();
        mode = mode_;
    }

    // Adds the alibi bias for logical element (row_idx, col_idx) into pixel.
    CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
    {
        if constexpr(RowMajor)
        {
            // at least 3 instructions per row
            index_t current_zero_point =
                mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
            // for every thread, most of the pixels are along the row, so the
            // sad() (|a - b| + c) below is the main hot spot.
            auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                                       bit_cast<uint32_t>(col_idx + shift_left_up),
                                                       0));
            pixel += slope * position;
        }
        else
        {
            // at least 3 instructions per col
            index_t current_zero_point = mode == AlibiMode::VERTICAL
                                             ? row_idx + col_idx + shift_right_down
                                             : col_idx + shift_right_down;
            // for every thread, most of the pixels are along the col, so the
            // sad() below is the main hot spot.
            auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                                       bit_cast<uint32_t>(row_idx + shift_left_up),
                                                       0));
            pixel += slope * position;
        }
    }

    DataType slope;          // float? sign already folded in by the ctor
    index_t shift_left_up;   // always positive
    index_t shift_right_down; // always positive
    AlibiMode mode;
};
// No-op position encoding: same update() interface as Alibi, but leaves the
// pixel untouched. Used when PositionEncodingEnum::NO is selected so callers
// can invoke update() unconditionally.
template <typename DataType>
struct EmptyPositionEncoding
{
CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
{
}
};
//
// Convert the FA-style left/right window description into our generic alibi
// coordinate mode:
// - left_size < 0 && right_size == 0 -> normal causal mask
// - local attention when left_size >= 0 or right_size >= 0
template <typename DataType, bool RowMajor = true>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                 index_t window_left_size,
                                                 index_t window_right_size,
                                                 index_t y_total,
                                                 index_t x_total,
                                                 GenericAttentionMaskEnum mask_enum)
{
    // mask_enum is assumed to never be NO_MASK: without any mask it is
    // perfectly fine (and cheaper) to pick the alibi mode at compile time.
    const bool is_causal = (window_left_size < 0) && (window_right_size == 0);

    AlibiMode alibi_mode = AlibiMode::VERTICAL;
    if(!is_causal)
    {
        // either top-left or bottom-right; enum values are kept in sync
        alibi_mode = static_cast<AlibiMode>(mask_enum);
    }

    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
}
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
// Host-side helper computing the per-head ALiBi slopes.
// Do we need a device version?
template <typename DataType>
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
{
    // Slopes for a power-of-2 head count: a geometric sequence starting at
    // start = 2^(-2^-(log2(n)-3)) with ratio start.
    // NOTE: use the float overload of std::pow instead of std::powf — the
    // latter is not declared in namespace std by all standard libraries.
    auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
        float start = std::pow(
            2.f, -std::pow(2.f, -static_cast<float>(integer_log2_floor(n) - 3)));
        std::vector<DataType> rtn;
        rtn.reserve(n); // exactly n slopes are produced
        for(ck_tile::index_t i = 0; i < n; i++)
        {
            rtn.push_back(static_cast<DataType>(start * std::pow(start, static_cast<float>(i))));
        }
        return rtn;
    };
    if(is_power_of_two_integer(nheads))
    {
        // power of 2 calculation
        return get_slopes_power_of_2(nheads);
    }
    else
    {
        // Non-power-of-2: all slopes of the closest smaller power of 2, then
        // the first (nheads - closest_power_of_2) even-indexed slopes of the
        // doubled count, matching the fairseq reference implementation.
        ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
        auto v0 = get_slopes_power_of_2(closest_power_of_2);
        auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
        // take vec by const reference (was by value: needless full copy)
        auto v1_sliced = [](const std::vector<DataType>& vec, ck_tile::index_t rem) {
            std::vector<DataType> sliced;
            sliced.reserve((vec.size() + 1) / 2);
            for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
            {
                if(i % 2 == 0)
                    sliced.push_back(vec[i]);
            }
            // keep only the first rem entries (rem <= sliced.size() here)
            std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
            return sliced_2;
        }(v1, nheads - closest_power_of_2);
        v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
        return v0;
    }
}
} // namespace ck_tile
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
...@@ -33,6 +34,7 @@ struct FmhaFwdKernel ...@@ -33,6 +34,7 @@ struct FmhaFwdKernel
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>; using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>; using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>; using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>; using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
...@@ -41,7 +43,7 @@ struct FmhaFwdKernel ...@@ -41,7 +43,7 @@ struct FmhaFwdKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasBias = FmhaPipeline::kHasBias; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
...@@ -81,7 +83,8 @@ struct FmhaFwdKernel ...@@ -81,7 +83,8 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
// clang-format on // clang-format on
...@@ -136,6 +139,13 @@ struct FmhaFwdKernel ...@@ -136,6 +139,13 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_bias = 0; ck_tile::index_t batch_stride_bias = 0;
}; };
struct FmhaFwdAlibiKargs
{
// alibi is batch*nhead*1, no matter in batch/group mode, they are the same
const void* alibi_slope_ptr;
ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
};
struct FmhaFwdMaskKargs struct FmhaFwdMaskKargs
{ {
// ck_tile::index_t window_size_left, window_size_right; // ck_tile::index_t window_size_left, window_size_right;
...@@ -162,7 +172,11 @@ struct FmhaFwdKernel ...@@ -162,7 +172,11 @@ struct FmhaFwdKernel
struct FmhaFwdBatchModeKargs struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs, : FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>, std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdBatchModeBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>, std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>, std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>> std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
...@@ -175,7 +189,11 @@ struct FmhaFwdKernel ...@@ -175,7 +189,11 @@ struct FmhaFwdKernel
struct FmhaFwdGroupModeKargs struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs, : FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>, std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdCommonBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>, std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>, std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>> std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
...@@ -255,13 +273,18 @@ struct FmhaFwdKernel ...@@ -255,13 +273,18 @@ struct FmhaFwdKernel
batch_stride_v, batch_stride_v,
batch_stride_o}; batch_stride_o};
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
kargs.bias_ptr = bias_ptr; kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias; kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias; kargs.nhead_stride_bias = nhead_stride_bias;
kargs.batch_stride_bias = batch_stride_bias; kargs.batch_stride_bias = batch_stride_bias;
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask) if constexpr(kHasMask)
{ {
kargs.window_size_left = window_size_left; kargs.window_size_left = window_size_left;
...@@ -345,12 +368,17 @@ struct FmhaFwdKernel ...@@ -345,12 +368,17 @@ struct FmhaFwdKernel
reinterpret_cast<const int32_t*>(seqstart_k_ptr), reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)}; reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
kargs.bias_ptr = bias_ptr; kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias; kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias; kargs.nhead_stride_bias = nhead_stride_bias;
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask) if constexpr(kHasMask)
{ {
kargs.window_size_left = window_size_left; kargs.window_size_left = window_size_left;
...@@ -421,14 +449,10 @@ struct FmhaFwdKernel ...@@ -421,14 +449,10 @@ struct FmhaFwdKernel
{ {
batch_offset_v = key_start; batch_offset_v = key_start;
} }
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = query_start * kargs.stride_bias + key_start; batch_offset_bias = query_start * kargs.stride_bias + key_start;
} }
else
{
batch_offset_bias = key_start;
}
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = query_start; batch_offset_lse = query_start;
...@@ -461,7 +485,7 @@ struct FmhaFwdKernel ...@@ -461,7 +485,7 @@ struct FmhaFwdKernel
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
} }
...@@ -585,7 +609,7 @@ struct FmhaFwdKernel ...@@ -585,7 +609,7 @@ struct FmhaFwdKernel
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths = constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{}); make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
const BiasDataType* bias_ptr = const BiasDataType* bias_ptr =
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) + reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
...@@ -654,6 +678,39 @@ struct FmhaFwdKernel ...@@ -654,6 +678,39 @@ struct FmhaFwdKernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}(); }();
// WA i_batch capture structure binding before c++20
auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
// data loading, shared by entire wg
// TODO: how to use s_read?
SaccDataType slope =
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
#endif
if constexpr(kHasMask)
{
return make_alibi_from_lr_mask<SaccDataType, true>(slope,
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type);
}
else
{
return Alibi<SaccDataType, true>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
}
}
else
{
return EmptyPositionEncoding<SaccDataType>{};
}
}();
auto o_acc_tile = [&]() { auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -672,6 +729,7 @@ struct FmhaFwdKernel ...@@ -672,6 +729,7 @@ struct FmhaFwdKernel
scales{kargs.scale_p}, // p_compute_element_func scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask, mask,
position_encoding,
kargs.scale_s, kargs.scale_s,
smem_ptr); smem_ptr);
} }
...@@ -683,6 +741,7 @@ struct FmhaFwdKernel ...@@ -683,6 +741,7 @@ struct FmhaFwdKernel
bias_dram_window, bias_dram_window,
lse_dram_window, lse_dram_window,
mask, mask,
position_encoding,
kargs.scale_s, kargs.scale_s,
smem_ptr); smem_ptr);
} }
......
...@@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum ...@@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum
QSKSVS, QSKSVS,
}; };
template <BlockFmhaPipelineEnum>
struct BlockFmhaPipelineEnumToStr;
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
{
static constexpr const char* name = "qr";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
{
static constexpr const char* name = "qr_async";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
{
static constexpr const char* name = "qs";
};
} // namespace ck_tile } // namespace ck_tile
...@@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem ...@@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasBias = Traits::kHasBias; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
...@@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kK0BlockLength <= 128)
{ {
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
...@@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS
typename LSEElementFunction, typename LSEElementFunction,
typename SAccElementFunction, typename SAccElementFunction,
typename PComputeElementFunction, typename PComputeElementFunction,
typename OAccElementFunction> typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
...@@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS
const PComputeElementFunction& p_compute_element_func, const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func, const OAccElementFunction& o_acc_element_func,
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr) const
{ {
...@@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS ...@@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS
k_block_tile = load_tile(k_dram_window); k_block_tile = load_tile(k_dram_window);
} }
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
__builtin_amdgcn_sched_barrier( __builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads 0); // prevent from messing up the order of global loads
} }
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
__builtin_amdgcn_sched_barrier( __builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads 0); // prevent from messing up the order of global loads
...@@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
} }
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
...@@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS ...@@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS
s_acc, s_acc,
bias_tile); bias_tile);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else else
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...@@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS
static const auto get_validated_m = [](SMPLComputeDataType raw_m) { static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need /// NOTICE: bias might be materialized mask including -inf values, need
/// consideration /// consideration
if constexpr(kHasBias || FmhaMask::IsMasking) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{ {
return raw_m == -numeric<SMPLComputeDataType>::infinity() return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f) ? type_convert<SMPLComputeDataType>(0.f)
...@@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
} }
...@@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() { const auto tmp = [&]() {
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
} }
...@@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
} }
...@@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp> typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...@@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr) const
{ {
...@@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS ...@@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS
identity{}, identity{},
identity{}, identity{},
mask, mask,
position_encoding,
scale_s, scale_s,
smem_ptr); smem_ptr);
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
...@@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasBias = Problem::kHasBias; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
if constexpr(kK0BlockLength <= 32) if constexpr(kK0BlockLength <= 32)
{ {
if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking) if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1; return 1;
else else
return 2; return 2;
} }
else if constexpr(kK0BlockLength <= 64) else if constexpr(kK0BlockLength <= 64)
{ {
if constexpr(kPadSeqLenK && kHasBias) if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2; return 2;
else else
return 3; return 3;
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kK0BlockLength <= 128)
{ {
if constexpr(kPadSeqLenK && kHasBias) if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
...@@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync
typename LSEElementFunction, typename LSEElementFunction,
typename SAccElementFunction, typename SAccElementFunction,
typename PComputeElementFunction, typename PComputeElementFunction,
typename OAccElementFunction> typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
...@@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const PComputeElementFunction& p_compute_element_func, const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func, const OAccElementFunction& o_acc_element_func,
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr) const
{ {
...@@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do. // check early exit
if constexpr(FmhaMask::IsMasking) if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{ {
if(num_total_loop <= 0) if(num_total_loop <= 0)
{ {
...@@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(1); __builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
...@@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc, s_acc,
bias_tile); bias_tile);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else else
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...@@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync
static const auto get_validated_m = [](SMPLComputeDataType raw_m) { static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need /// NOTICE: bias might be materialized mask including -inf values, need
/// consideration /// consideration. alibi does not have this problem
if constexpr(kHasBias || FmhaMask::IsMasking) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{ {
return raw_m == -numeric<SMPLComputeDataType>::infinity() return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f) ? type_convert<SMPLComputeDataType>(0.f)
...@@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
} }
...@@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() { const auto tmp = [&]() {
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
} }
...@@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
} }
...@@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp> typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...@@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr) const
{ {
...@@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync
identity{}, identity{},
identity{}, identity{},
mask, mask,
position_encoding,
scale_s, scale_s,
smem_ptr); smem_ptr);
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
...@@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kK0BlockLength <= 128)
{ {
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
...@@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp> typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...@@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
FmhaMask mask, FmhaMask mask,
PositionEncoding /*position_encoding*/,
float scale_s, float scale_s,
float descale_qk, float descale_qk,
float descale_sv, float descale_sv,
...@@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
k_block_tile = load_tile(k_dram_window); k_block_tile = load_tile(k_dram_window);
} }
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
__builtin_amdgcn_sched_barrier( __builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads 0); // prevent from messing up the order of global loads
} }
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
__builtin_amdgcn_sched_barrier( __builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads 0); // prevent from messing up the order of global loads
...@@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
} }
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
tile_elementwise_inout( tile_elementwise_inout(
[&](auto& x, const auto& y) { [&](auto& x, const auto& y) {
...@@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static const auto get_validated_m = [](SMPLComputeDataType raw_m) { static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need /// NOTICE: bias might be materialized mask including -inf values, need
/// consideration /// consideration
if constexpr(kHasBias || FmhaMask::IsMasking) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{ {
return raw_m == -numeric<SMPLComputeDataType>::infinity() return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f) ? type_convert<SMPLComputeDataType>(0.f)
...@@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
} }
...@@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 ...@@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() { const auto tmp = [&]() {
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr index_t kBlockPerCu = []() { static constexpr index_t kBlockPerCu = []() {
...@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS
} }
else if constexpr(kK0BlockLength <= 128) else if constexpr(kK0BlockLength <= 128)
{ {
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1; return 1;
else else
return 2; return 2;
...@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS
typename LSEElementFunction, typename LSEElementFunction,
typename SAccElementFunction, typename SAccElementFunction,
typename PComputeElementFunction, typename PComputeElementFunction,
typename OAccElementFunction> typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
...@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS
const PComputeElementFunction& p_compute_element_func, const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func, const OAccElementFunction& o_acc_element_func,
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr) const
{ {
...@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS ...@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS
k_block_tile = load_tile(k_dram_window); k_block_tile = load_tile(k_dram_window);
} }
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
__builtin_amdgcn_sched_barrier( __builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads 0); // prevent from messing up the order of global loads
} }
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
__builtin_amdgcn_sched_barrier( __builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads 0); // prevent from messing up the order of global loads
...@@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
} }
// STAGE 2, scale_s, add bias, mask, softmax // STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
...@@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS ...@@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS
s_acc, s_acc,
bias_tile); bias_tile);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else else
{ {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...@@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS
static const auto get_validated_m = [](SMPLComputeDataType raw_m) { static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need /// NOTICE: bias might be materialized mask including -inf values, need
/// consideration /// consideration
if constexpr(kHasBias || FmhaMask::IsMasking) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{ {
return raw_m == -numeric<SMPLComputeDataType>::infinity() return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f) ? type_convert<SMPLComputeDataType>(0.f)
...@@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
} }
...@@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() { const auto tmp = [&]() {
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
} }
...@@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
} }
...@@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS ...@@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp> typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...@@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr) const void* smem_ptr) const
{ {
...@@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS ...@@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS
identity{}, identity{},
identity{}, identity{},
mask, mask,
position_encoding,
scale_s, scale_s,
smem_ptr); smem_ptr);
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -11,7 +12,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */, ...@@ -11,7 +12,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */, bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */, bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
bool kHasBias_, BlockAttentionBiasEnum BiasEnum_,
bool kStoreLSE_, bool kStoreLSE_,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
...@@ -21,7 +22,7 @@ struct TileFmhaTraits ...@@ -21,7 +22,7 @@ struct TileFmhaTraits
static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_; static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kHasBias = kHasBias_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLSE = kStoreLSE_; static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
......
...@@ -26,6 +26,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ ...@@ -26,6 +26,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_
BF8, BF8,
F8>>>& instances) F8>>>& instances)
{ {
#if CK_BUILD_DEPRECATED
#pragma message "These instances are getting deprecated"
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
...@@ -44,6 +46,10 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ ...@@ -44,6 +46,10 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_
Empty_Tuple, Empty_Tuple,
NDHWGC, NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{}); ConvBwdDataFilter1x1Stride1Pad0>{});
#else
#pragma message "These instances were deprecated"
std::ignore = instances;
#endif
} }
} // namespace instance } // namespace instance
......
...@@ -23,6 +23,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ ...@@ -23,6 +23,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
BF8, BF8,
F8>>>& instances) F8>>>& instances)
{ {
#if CK_BUILD_DEPRECATED
#pragma message "These instances are getting deprecated"
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
...@@ -41,6 +43,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ ...@@ -41,6 +43,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
GKZYXC, GKZYXC,
NDHWGK, NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{}); ConvBwdWeightFilter1x1Stride1Pad0>{});
#else
#pragma message "These instances were deprecated"
std::ignore = instances;
#endif
} }
} // namespace instance } // namespace instance
......
...@@ -24,6 +24,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance ...@@ -24,6 +24,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
PassThrough, PassThrough,
F8>>>& instances) F8>>>& instances)
{ {
#if CK_BUILD_DEPRECATED
#pragma message "These instances are getting deprecated"
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3,
...@@ -48,6 +50,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance ...@@ -48,6 +50,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
ConvFwd1x1S1P0>{}); ConvFwd1x1S1P0>{});
#else
#pragma message "These instances were deprecated"
std::ignore = instances;
#endif
} }
} // namespace instance } // namespace instance
......
...@@ -181,3 +181,4 @@ add_subdirectory(wrapper) ...@@ -181,3 +181,4 @@ add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11") if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op) add_subdirectory(wmma_op)
endif() endif()
add_subdirectory(position_embedding)
add_test_executable(test_position_embedding position_embedding.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#ifndef TEST_ALIBI_VERBOSE
#define TEST_ALIBI_VERBOSE 0
#endif
template <typename DataType>
struct attention_score
{
ck_tile::index_t rows, cols;
std::vector<DataType> pixels;
attention_score(ck_tile::index_t rows_,
ck_tile::index_t cols_,
DataType init_v_ = static_cast<DataType>(0))
: rows(rows_), cols(cols_), pixels(rows_ * cols_, init_v_)
{
}
auto& operator()(ck_tile::index_t i_row, ck_tile::index_t i_col)
{
return pixels[i_row * cols + i_col];
}
void print()
{
for(auto i_row = 0; i_row < rows; i_row++)
{
for(auto i_col = 0; i_col < cols; i_col++)
{
std::cout << pixels[i_row * cols + i_col] << " ";
}
std::cout << std::endl;
}
}
};
// Apply the alibi position bias with a single slope to every element of
// `score`, visiting the tile the same way the fmha kernels do.
// RowMajor selects which Alibi specialization is exercised; the numeric
// result is the same either way.
template <bool RowMajor, typename DataType>
void alibi_traverse_with_slope(attention_score<DataType>& score,
                               DataType slope,
                               ck_tile::AlibiMode mode = ck_tile::AlibiMode::VERTICAL)
{
    ck_tile::Alibi<DataType, RowMajor> position_bias{slope, score.rows, score.cols, mode};
    for(ck_tile::index_t r = 0; r < score.rows; ++r)
    {
        for(ck_tile::index_t c = 0; c < score.cols; ++c)
        {
            position_bias.update(score(r, c), r, c);
        }
    }
}
// Short human-readable tag for an AlibiMode, used in verbose test output.
// Returns an empty string for an unrecognized mode value.
std::string alibi_mode_to_str(ck_tile::AlibiMode mode)
{
    switch(mode)
    {
    case ck_tile::AlibiMode::VERTICAL: return std::string("alibi_verti");
    case ck_tile::AlibiMode::FROM_TOP_LEFT: return std::string("alibi_top-l");
    case ck_tile::AlibiMode::FROM_BOTTOM_RIGHT: return std::string("alibi_bot-r");
    default: return "";
    }
}
// Run the alibi traversal over a rows x cols score matrix (starting from all
// zeros) and compare the result element-wise against the reference values.
// Returns true when every pixel matches the expected vector.
template <bool RowMajor, typename DataType>
bool test_alibi_traverse_with_slope(ck_tile::index_t rows,
                                    ck_tile::index_t cols,
                                    DataType slope,
                                    ck_tile::AlibiMode mode,
                                    const std::vector<DataType>& expected)
{
    attention_score<DataType> score{rows, cols};
    alibi_traverse_with_slope<RowMajor, DataType>(score, slope, mode);
    // Four-iterator overload: also fails (instead of reading past the end of
    // `expected`) when the caller passes a reference vector of the wrong size.
    bool is_match = std::equal(
        score.pixels.begin(), score.pixels.end(), expected.begin(), expected.end());
#if TEST_ALIBI_VERBOSE
    std::cout << "---------" << alibi_mode_to_str(mode) << ", " << rows << "x" << cols << "("
              << (RowMajor ? "row_major" : "col_major") << ")"
              << (is_match ? ", valid:y" : ", valid:n") << std::endl; // was "valie:y" (typo)
    score.print();
#endif
    return is_match;
}
// Validate ck_tile::get_alibi_slopes() for a given head count against
// reference slopes, using a relative-tolerance float comparison.
// Returns true when all generated slopes match within rtol.
template <typename DataType>
bool test_alibi_slope_generation(ck_tile::index_t nheads, const std::vector<DataType>& expected)
{
    auto slopes   = ck_tile::get_alibi_slopes<DataType>(nheads);
    bool is_match = std::equal(slopes.begin(),
                               slopes.end(),
                               expected.begin(),
                               expected.end(),
                               [](const DataType& lhs, const DataType& rhs) {
                                   constexpr float rtol = 1e-6f;
                                   auto error           = std::abs(lhs - rhs);
                                   // <= (not <) so an exact match passes even
                                   // when rhs == 0 makes the tolerance zero.
                                   return error <= rtol * std::abs(rhs);
                               });
#if TEST_ALIBI_VERBOSE
    std::cout << "-------------------- slopes " << nheads << ", " << (is_match ? "y" : "n")
              << std::endl;
    for(ck_tile::index_t i = 0; i < nheads; i++)
    {
        std::cout << slopes[i] << " ";
    }
    std::cout << std::endl;
#endif
    return is_match;
}
// Test driver: exercises the alibi bias traversal for every AlibiMode, in
// both RowMajor=true and RowMajor=false flavors, plus the slope generation
// helper. Returns 0 on success, -1 if any check fails.
int main()
{
    // Integer slope of 1 makes the expected bias tables exact distance maps.
    using dtype = int32_t;
    dtype slope = static_cast<dtype>(1);
    bool rtn    = true;
    // clang-format off
    // --- RowMajor = true traversals ---
    // VERTICAL: bias depends on column only.
    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5,
                                                                                                   0, 1, 2, 3, 4, 5,
                                                                                                   0, 1, 2, 3, 4, 5,
                                                                                                   0, 1, 2, 3, 4, 5});
    // FROM_TOP_LEFT: |row - col| distance anchored at the top-left corner.
    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
                                                                                                        1, 0, 1, 2, 3, 4,
                                                                                                        2, 1, 0, 1, 2, 3,
                                                                                                        3, 2, 1, 0, 1, 2});
    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
                                                                                                        1, 0, 1, 2,
                                                                                                        2, 1, 0, 1,
                                                                                                        3, 2, 1, 0,
                                                                                                        4, 3, 2, 1,
                                                                                                        5, 4, 3, 2});
    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
                                                                                                        1, 0, 1,
                                                                                                        2, 1, 0});
    // FROM_BOTTOM_RIGHT: zero diagonal anchored at the bottom-right corner;
    // for square tiles this coincides with FROM_TOP_LEFT.
    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
                                                                                                            3, 2, 1, 0, 1, 2,
                                                                                                            4, 3, 2, 1, 0, 1,
                                                                                                            5, 4, 3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
                                                                                                            1, 2, 3, 4,
                                                                                                            0, 1, 2, 3,
                                                                                                            1, 0, 1, 2,
                                                                                                            2, 1, 0, 1,
                                                                                                            3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
                                                                                                            1, 0, 1,
                                                                                                            2, 1, 0});
    // --- RowMajor = false traversals: must produce identical results ---
    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5,
                                                                                                    0, 1, 2, 3, 4, 5,
                                                                                                    0, 1, 2, 3, 4, 5,
                                                                                                    0, 1, 2, 3, 4, 5});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
                                                                                                         1, 0, 1, 2, 3, 4,
                                                                                                         2, 1, 0, 1, 2, 3,
                                                                                                         3, 2, 1, 0, 1, 2});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
                                                                                                         1, 0, 1, 2,
                                                                                                         2, 1, 0, 1,
                                                                                                         3, 2, 1, 0,
                                                                                                         4, 3, 2, 1,
                                                                                                         5, 4, 3, 2});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
                                                                                                         1, 0, 1,
                                                                                                         2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
                                                                                                             3, 2, 1, 0, 1, 2,
                                                                                                             4, 3, 2, 1, 0, 1,
                                                                                                             5, 4, 3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
                                                                                                             1, 2, 3, 4,
                                                                                                             0, 1, 2, 3,
                                                                                                             1, 0, 1, 2,
                                                                                                             2, 1, 0, 1,
                                                                                                             3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
                                                                                                             1, 0, 1,
                                                                                                             2, 1, 0});
    // --- slope generation: power-of-two head counts give a simple geometric
    // sequence; non-power-of-two counts interleave an extra sequence ---
    rtn &= test_alibi_slope_generation<float>(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625});
    rtn &= test_alibi_slope_generation<float>(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692,
                                                   0.12500000000000006, 0.08838834764831849, 0.06250000000000004, 0.044194173824159244,
                                                   0.03125000000000002, 0.022097086912079626, 0.01562500000000001, 0.011048543456039816,
                                                   0.007812500000000007, 0.005524271728019908, 0.003906250000000004});
    rtn &= test_alibi_slope_generation<float>(1, {0.00390625});
    rtn &= test_alibi_slope_generation<float>(5, {0.25, 0.0625, 0.015625, 0.00390625, 0.5});
    rtn &= test_alibi_slope_generation<float>(6, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125});
    rtn &= test_alibi_slope_generation<float>(7, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125, 0.03125});
    rtn &= test_alibi_slope_generation<float>(9, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625, 0.7071067811865476});
    // clang-format on
    return rtn ? 0 : -1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment