Commit 036c5234 authored by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 22995e9a 7843a8a7
......@@ -3,7 +3,9 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This enum is used for codegen pattern matching
enum class BlockAttentionBiasEnum
{
NO_BIAS = 0,
ELEMENTWISE_BIAS = 1, // attention bias; each element is added to the result of Q*K (after scale)
ALIBI = 2, // bias computed from position encoding, applied after scale
};
template <BlockAttentionBiasEnum>
struct BlockAttentionBiasEnumToStr;
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
{
static constexpr const char* name = "";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
{
static constexpr const char* name = "bias";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
{
static constexpr const char* name = "alibi";
};
} // namespace ck_tile
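The *ToStr traits above exist so generated kernel names can embed the bias kind; NO_BIAS maps to an empty string and leaves no trace in the name. A minimal sketch of that use (the bias_suffix helper is illustrative, not part of this header):

#include <string>
template <ck_tile::BlockAttentionBiasEnum E>
std::string bias_suffix()
{
    // empty name ("" for NO_BIAS) contributes no "_..." suffix
    constexpr const char* n = ck_tile::BlockAttentionBiasEnumToStr<E>::name;
    return n[0] == '\0' ? std::string{} : std::string("_") + n;
}
// bias_suffix<ck_tile::BlockAttentionBiasEnum::ALIBI>() == "_alibi"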
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <cmath>
#include <vector>
namespace ck_tile {
enum struct PositionEncodingEnum
{
NO = 0,
ALIBI = 1,
};
/*
VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
FROM_TOP_LEFT:
[0] 1 2 3 4 5
1 [0] 1 2 3 4
2 1 [0] 1 2 3
3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT:
2 1 [0] 1 2 3
3 2 1 [0] 1 2
4 3 2 1 [0] 1
5 4 3 2 1 [0]
*/
enum struct AlibiMode
{
VERTICAL = 0,
FROM_TOP_LEFT = 1, // keep in sync with the mask enum
FROM_BOTTOM_RIGHT = 2,
};
template <typename DataType, bool RowMajor = true>
struct Alibi
{
// RowMajor here means whether the pixels within the same thread run along the row or the column.
// This may impact the performance of update(), while the results are the same:
// e.g. fwd prefers RowMajor=true, while some bwd cases prefer RowMajor=false.
CK_TILE_HOST_DEVICE Alibi(DataType slope_,
index_t y_total_,
index_t x_total_,
AlibiMode mode_ = AlibiMode::VERTICAL)
{
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
shift_left_up = [&]() {
if(RowMajor)
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
}
else
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
}
}();
shift_right_down = [&]() {
if(RowMajor)
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
}
else
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
}
}();
mode = mode_;
}
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
{
if constexpr(RowMajor)
{
// at least 3 instructions per row
index_t current_zero_point =
mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
// for every thread, most of the pixels are along the row, so the operation below
// should be the main hotspot.
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
bit_cast<uint32_t>(col_idx + shift_left_up),
0));
pixel += slope * position;
}
else
{
// at least 3 instructions per col;
index_t current_zero_point = mode == AlibiMode::VERTICAL
? row_idx + col_idx + shift_right_down
: col_idx + shift_right_down;
// for every thread, most of the pixels are along the col, so the operation below
// should be the main hotspot.
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
bit_cast<uint32_t>(row_idx + shift_left_up),
0));
pixel += slope * position;
}
}
DataType slope; // float?
index_t shift_left_up; // always positive
index_t shift_right_down; // always positive
AlibiMode mode;
};
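As a sanity check of the modes diagrammed above, a host-side sketch (illustrative, not part of the diff; assumes ck_tile's host code paths for sad()/type_convert are available):

#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include <cstdio>

int main()
{
    // 4x6 grid, FROM_BOTTOM_RIGHT: the zero line starts at column 2 of row 0,
    // matching the diagram above; note non-vertical modes negate the slope.
    ck_tile::Alibi<float, true> alibi{0.5f, /*y_total=*/4, /*x_total=*/6,
                                      ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
    for(ck_tile::index_t r = 0; r < 4; ++r)
    {
        for(ck_tile::index_t c = 0; c < 6; ++c)
        {
            float pixel = 0.f;
            alibi.update(pixel, r, c); // pixel += slope * |distance to zero line|
            std::printf("%5.2f ", pixel);
        }
        std::printf("\n");
    }
    return 0;
}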
template <typename DataType>
struct EmptyPositionEncoding
{
CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
{
}
};
//
// convert from the FA-style left/right window sizes to our generic coordinates:
// if left_size < 0 && right_size == 0, it is the normal causal mask;
// the mask is local when left_size >= 0 or right_size >= 0
template <typename DataType, bool RowMajor = true>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
index_t window_left_size,
index_t window_right_size,
index_t y_total,
index_t x_total,
GenericAttentionMaskEnum mask_enum)
{
// assume mask_enum will never be NO_MASK, since without a mask it is
// totally OK to use the constexpr construction instead
bool is_causal = window_left_size < 0 && window_right_size == 0;
AlibiMode alibi_mode =
is_causal ? AlibiMode::VERTICAL
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
}
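For example, under the flash-attention window conventions (a hedged sketch; the GenericAttentionMaskEnum value names are taken from block_masking.hpp):

// causal: left = -1, right = 0 -> AlibiMode::VERTICAL, slope kept positive
ck_tile::index_t seqlen_q = 512, seqlen_k = 512;
auto alibi_causal = ck_tile::make_alibi_from_lr_mask<float>(
    0.5f, /*window_left_size=*/-1, /*window_right_size=*/0,
    seqlen_q, seqlen_k, ck_tile::GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
// a local window (left >= 0 or right > 0) instead casts the mask enum
// straight to FROM_TOP_LEFT / FROM_BOTTOM_RIGHT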
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
// Do we need a device version?
template <typename DataType>
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
{
auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
float start = std::powf(
static_cast<float>(2),
-std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
std::vector<DataType> rtn;
for(auto i = 0; i < n; i++)
{
rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
}
return rtn;
};
if(is_power_of_two_integer(nheads))
{
// power of 2 calculation
return get_slopes_power_of_2(nheads);
}
else
{
ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
auto v0 = get_slopes_power_of_2(closest_power_of_2);
auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
std::vector<DataType> sliced;
for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
{
if(i % 2 == 0)
sliced.push_back(vec[i]);
}
std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
return sliced_2;
}(v1, nheads - closest_power_of_2);
v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
return v0;
}
}
} // namespace ck_tile
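get_alibi_slopes() follows the fairseq recipe linked above: for a power-of-two nheads the slopes form a geometric sequence with ratio 2^(-8/nheads); otherwise the closest power of two is extended with every other slope of the doubled head count. A quick host-side check, plus the slope handoff to the kernel args (illustrative, not part of the diff):

// nheads = 8: start = 2^(-8/8) = 0.5 -> slopes 0.5, 0.25, ..., 2^-8
auto slopes = ck_tile::get_alibi_slopes<float>(8);
// nheads = 6: slopes(4) followed by every other entry of slopes(8), i.e.
// {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125}
// In the fmha kernel below, a device copy of this buffer is what
// alibi_slope_ptr points at, with alibi_slope_stride = 0 when all batches
// share it (or nheads for per-batch slopes).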
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
......@@ -33,6 +34,7 @@ struct FmhaFwdKernel
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
......@@ -41,7 +43,7 @@ struct FmhaFwdKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasBias = FmhaPipeline::kHasBias;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
......@@ -81,7 +83,8 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
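// e.g. a generated name now ends in "_bias" or "_alibi" via the
// BlockAttentionBiasEnumToStr table, instead of the old boolean "_bias" flag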
......@@ -136,6 +139,13 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_bias = 0;
};
struct FmhaFwdAlibiKargs
{
// alibi slopes are batch*nhead*1; the layout is the same in batch and group mode
const void* alibi_slope_ptr;
ck_tile::index_t alibi_slope_stride; // batch stride, or 0 if all batches share the same slopes
};
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
......@@ -162,7 +172,11 @@ struct FmhaFwdKernel
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdBatchModeBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
......@@ -175,7 +189,11 @@ struct FmhaFwdKernel
struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdCommonBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
......@@ -255,13 +273,18 @@ struct FmhaFwdKernel
batch_stride_v,
batch_stride_o};
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
kargs.batch_stride_bias = batch_stride_bias;
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
......@@ -345,12 +368,17 @@ struct FmhaFwdKernel
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
......@@ -421,14 +449,10 @@ struct FmhaFwdKernel
{
batch_offset_v = key_start;
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias + key_start;
}
else
{
batch_offset_bias = key_start;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
......@@ -461,7 +485,7 @@ struct FmhaFwdKernel
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
......@@ -585,7 +609,7 @@ struct FmhaFwdKernel
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const BiasDataType* bias_ptr =
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
......@@ -654,6 +678,39 @@ struct FmhaFwdKernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
// WA: lambdas cannot capture structured bindings before C++20, hence the i_batch_/i_nhead_ copies
auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
// data loading, shared by the entire workgroup
// TODO: how to use s_read?
SaccDataType slope =
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
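// pre-scale the slope by log2(e) so that the exp2-based softmax path
// below still reproduces exp(slope * position)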
#endif
if constexpr(kHasMask)
{
return make_alibi_from_lr_mask<SaccDataType, true>(slope,
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type);
}
else
{
return Alibi<SaccDataType, true>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
}
}
else
{
return EmptyPositionEncoding<SaccDataType>{};
}
}();
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
......@@ -672,6 +729,7 @@ struct FmhaFwdKernel
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
}
......@@ -683,6 +741,7 @@ struct FmhaFwdKernel
bias_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
}
......
......@@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum
QSKSVS,
};
template <BlockFmhaPipelineEnum>
struct BlockFmhaPipelineEnumToStr;
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
{
static constexpr const char* name = "qr";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
{
static constexpr const char* name = "qr_async";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
{
static constexpr const char* name = "qs";
};
} // namespace ck_tile
......@@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasBias = Traits::kHasBias;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
......@@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
......@@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS
k_block_tile = load_tile(k_dram_window);
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent messing up the order of global loads
......@@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
......@@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
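// walk every element of the s_acc tile, translate its distributed index
// back to a global (row, col), then add slope * position on top of the
// scaled logit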
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
......@@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask including -inf values; this needs
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......@@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
......@@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr);
}
......
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
......@@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support any multiple of the vector size (like 8x)
static constexpr bool kPadHeadDimV = true; // support any multiple of the vector size (like 8x)
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync
{
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking)
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
{
if constexpr(kPadSeqLenK && kHasBias)
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kPadSeqLenK && kHasBias)
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
......@@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
// check early exit
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
......@@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
......@@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
......@@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask including -inf values; this needs
/// consideration. alibi does not have this problem
if constexpr(kHasBias || FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......@@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
......@@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr);
}
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
......@@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
FmhaMask mask,
PositionEncoding /*position_encoding*/,
float scale_s,
float descale_qk,
float descale_sv,
......@@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
k_block_tile = load_tile(k_dram_window);
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent messing up the order of global loads
......@@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
tile_elementwise_inout(
[&](auto& x, const auto& y) {
......@@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask including -inf values; this needs
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile {
......@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr index_t kBlockPerCu = []() {
......@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
......@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS
k_block_tile = load_tile(k_dram_window);
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent messing up the order of global loads
......@@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
......@@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
......@@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask including -inf values; this needs
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......@@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
......@@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr);
}
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
namespace ck_tile {
......@@ -11,7 +12,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
bool kHasBias_,
BlockAttentionBiasEnum BiasEnum_,
bool kStoreLSE_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
......@@ -21,7 +22,7 @@ struct TileFmhaTraits
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kHasBias = kHasBias_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
......
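With the trait switched from a boolean kHasBias to BiasEnum, an alibi-enabled configuration reads as follows (illustrative parameter values):

using FmhaTraitsAlibi = ck_tile::TileFmhaTraits</*kPadSeqLenQ=*/true,
                                                /*kPadSeqLenK=*/true,
                                                /*kPadHeadDimQ=*/true,
                                                /*kPadHeadDimV=*/true,
                                                ck_tile::BlockAttentionBiasEnum::ALIBI,
                                                /*kStoreLSE=*/false,
                                                /*kDoFp8StaticQuant=*/false>;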
......@@ -27,14 +27,16 @@ using Empty_Tuple = ck::Tuple<>;
using BF16_Tuple = ck::Tuple<BF16>;
using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>;
using BF16_BF16_Tuple = ck::Tuple<BF16, BF16>;
using F64_Tuple = ck::Tuple<F64>;
using F32_Tuple = ck::Tuple<F32>;
using F32_F32_Tuple = ck::Tuple<F32, F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
using I8_Tuple = ck::Tuple<I8>;
......@@ -91,23 +93,26 @@ using GK_Tuple = ck::Tuple<G_K>;
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
using Add = ck::tensor_operation::element_wise::Add;
using Multiply = ck::tensor_operation::element_wise::Multiply;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
......@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
......@@ -33,7 +33,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
......@@ -46,7 +46,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
......@@ -59,7 +59,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
......@@ -72,7 +72,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// RCR
......@@ -86,7 +86,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
......@@ -99,7 +99,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
......@@ -112,7 +112,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
......@@ -125,7 +125,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// CRR
......@@ -139,7 +139,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
......@@ -152,7 +152,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
......@@ -165,7 +165,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
......@@ -178,8 +178,62 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// Multiply
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances);
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
#endif
// GEMM + Add + Gelu
......@@ -201,7 +255,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
......@@ -213,7 +267,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>;
static auto GetInstances()
......@@ -271,7 +325,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
......@@ -283,7 +337,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>;
static auto GetInstances()
......@@ -341,7 +395,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
......@@ -353,7 +407,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>;
static auto GetInstances()
......@@ -411,7 +465,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
......@@ -423,7 +477,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>;
static auto GetInstances()
......@@ -462,6 +516,234 @@ struct DeviceOperationInstanceFactory<
}
};
// Multiply
// GEMM + Add + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16, BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row, Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM + Add
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16, BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row, Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyFastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
Multiply>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
Multiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
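The factory specializations above are consumed the same way as the existing ones; e.g. to enumerate the new multiply (bf16 x i8) bias+gelu instances (an illustrative lookup using the types declared in this header):

using Factory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmMultipleABD<ck::Tuple<Row>,
                                                        ck::Tuple<Row>,
                                                        ck::Tuple<Row, Row>,
                                                        Row,
                                                        ck::Tuple<BF16>,
                                                        ck::Tuple<I8>,
                                                        ck::Tuple<BF16, BF16>,
                                                        BF16,
                                                        PassThrough,
                                                        PassThrough,
                                                        MultiplyAddFastGelu>>;
auto op_ptrs = Factory::GetInstances(); // empty unless CK_ENABLE_INT8 is defined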
......
......@@ -315,6 +315,107 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instanc
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
template <typename ADataType,
typename BDataType,
......@@ -494,6 +595,64 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
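// A minimal client-side usage sketch for the factory above, assuming valid device
// buffers and problem sizes (p_a, p_b, p_c, M, N, K, the strides, and kbatch are
// placeholders; the exact MakeArgumentPointer signature follows the DeviceGemmV2
// interface and may differ):
//
// auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
// for(auto& op : op_ptrs)
// {
//     auto argument = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
//                                             stride_a, stride_b, stride_c, kbatch,
//                                             PassThrough{}, PassThrough{}, PassThrough{});
//     if(op->IsSupportedArgument(argument.get()))
//     {
//         auto invoker = op->MakeInvokerPointer();
//         invoker->Run(argument.get(), StreamConfig{});
//         break; // run the first instance that supports this problem shape
//     }
// }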
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using F16 = ck::half_t;
using F32 = float;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1>
// clang-format on
>;
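// A tuple like the one above is normally consumed by an add_* registration function
// via add_device_operation_instances (from the header included at the top of this
// file). A minimal sketch, assuming the 2D NHWGC/GKYXC/NHWGK layouts and the default
// specialization; the actual definition may live elsewhere in the commit:
//
// void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
//     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK,
//                                                            F16, F16, F16,
//                                                            PassThrough, PassThrough,
//                                                            PassThrough>>>& instances)
// {
//     add_device_operation_instances(
//         instances,
//         device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
//             2, NHWGC, GKYXC, NHWGK, ConvBwdWeightDefault>{});
// }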
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -86,6 +86,7 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_bilinear_instances = std:
//#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
// generic instance
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>,
// instance for small conv.K
// for fp16, conv.K and conv.C must be divisible by 2
......
......@@ -352,6 +352,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
......@@ -419,6 +421,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
......
......@@ -113,6 +113,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
......@@ -192,6 +204,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
......
......@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
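// The hunks below swap the B-side elementwise operation from Scales to Multiply
// throughout the bf16 x i8 grouped-GEMM declarations. A minimal sketch of the shape
// of such a binary elementwise functor (illustrative; the real definition lives in
// element_wise_operation.hpp and may use per-type specializations):
//
// struct Multiply
// {
//     template <typename Y, typename X0, typename X1>
//     __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
//     {
//         // e.g. dequantize an i8 B element by multiplying with its scale
//         y = type_convert<Y>(x0) * type_convert<Y>(x1);
//     }
// };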
......@@ -32,7 +32,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_g
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
......@@ -45,7 +45,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_i
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
......@@ -58,7 +58,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_i
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
......@@ -71,7 +71,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instan
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// RCR
......@@ -85,7 +85,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_g
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
......@@ -98,7 +98,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_i
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
......@@ -111,7 +111,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_i
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
......@@ -124,7 +124,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instan
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// CRR
......@@ -138,7 +138,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_g
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
......@@ -151,7 +151,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_i
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
......@@ -164,7 +164,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_i
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
......@@ -177,7 +177,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// GEMM + Add + Gelu
......@@ -199,7 +199,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
......@@ -211,7 +211,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>;
static auto GetInstances()
......@@ -270,7 +270,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
......@@ -282,7 +282,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>;
static auto GetInstances()
......@@ -341,7 +341,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
......@@ -353,7 +353,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>;
static auto GetInstances()
......@@ -412,7 +412,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
......@@ -424,7 +424,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>;
static auto GetInstances()
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
// fp16_output
void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>>
{
using DeviceOp = DeviceGroupedGemmTileLoop<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
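// A minimal usage sketch for the tile-loop factory above (illustrative; argument
// construction is omitted because DeviceGroupedGemmTileLoop takes per-group GEMM
// descriptors whose exact form depends on the interface):
//
// using DeviceOp = ck::tensor_operation::device::
//     DeviceGroupedGemmTileLoop<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16,
//                               PassThrough, PassThrough, PassThrough>;
// auto op_ptrs = ck::tensor_operation::device::instance::
//     DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
// for(const auto& op : op_ptrs)
//     std::cout << op->GetTypeString() << '\n'; // enumerate the available kernel instances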