Unverified commit d71189ff authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents f84e2020 73b67f29
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaFwdAppendKVPipelineDefaultPolicy>
struct BlockFmhaFwdAppendKVPipeline
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = typename Problem::QDataType;
using KDataType = typename Problem::KDataType;
using VDataType = typename Problem::VDataType;
using VLayout = typename Problem::VLayout;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kK0 = Problem::kK0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr auto RotaryEnum = Problem::RotaryEnum;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
// ... together with the tensor distribution. the tensor distribution should be able to overwrite this
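// (when the corresponding dimension is padded, the alignment below conservatively falls back to 1,
// i.e. element-wise access, presumably so that a vectorized load never straddles the padded region)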
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kK0 <= 32)
{
return 2;
}
else if constexpr(kK0 <= 64)
{
return 3;
}
else if constexpr(kK0 <= 128)
{
return 2;
}
else if constexpr(kK0 <= 256)
{
return 1;
}
}
}();
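// (kBlockPerCu is an occupancy hint chosen from the kK0 tile size; presumably consumed by the
// kernel's launch bounds)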
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QElementFunction,
typename KnewElementFunction,
typename VnewElementFunction,
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
const KnewElementFunction& knew_element_func,
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
const VnewElementFunction& vnew_element_func,
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
index_t rotary_dim,
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
if(!skip_rotate_append_kv)
{
// append Knew to K
auto knew_window = make_tile_window(
knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
auto knew_tile = [&]() {
auto knew = load_tile(knew_window);
return tile_elementwise_in(knew_element_func, knew);
}();
// optionally apply rotary embedding to Knew
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window =
make_tile_window(knew_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
auto rotary_sin_window =
make_tile_window(knew_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
// We assume that each thread owns contiguous elements along the head dimension, and we
// use the distribution to enable/disable threads in order to overwrite only part of the
// knew_tile content
auto [thread_start, thread_end] =
Policy::template GetKnewThreadRangeAlongK<Problem>();
ignore = thread_start;
BlockRotaryEmbedding<RotaryEnum>::apply(knew_tile,
knew_window,
rotary_cos_window,
rotary_sin_window,
rotary_dim,
thread_end);
}
store_tile(k_dram_block_window, knew_tile);
// write the tile to another block if necessary
if constexpr(kIsPagedKV)
{
if(k_page_block_navigator.is_cross_block(i_page_block_k, k_dram_block_window))
{
k_page_block_navigator.move_to_block(
i_page_block_k, k_dram_block_window, i_page_block_k + 1);
store_tile(k_dram_block_window, knew_tile);
}
}
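// (when the just-written tile straddles a page boundary, the same knew_tile is stored a second
// time into the next physical block; presumably the windows' out-of-bounds handling keeps each
// row in the block that actually owns it. the same pattern is used for Vnew below.)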
// append Vnew to V
auto vnew_window = make_tile_window(
vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
auto vnew_tile = [&]() {
auto vnew = load_tile(vnew_window);
return tile_elementwise_in(vnew_element_func, vnew);
}();
store_tile(v_dram_block_window, vnew_tile);
// write the tile to another block if necessary
if constexpr(kIsPagedKV)
{
if(v_page_block_navigator.is_cross_block(i_page_block_v, v_dram_block_window))
{
v_page_block_navigator.move_to_block(
i_page_block_v, v_dram_block_window, i_page_block_v + 1);
store_tile(v_dram_block_window, vnew_tile);
}
}
}
if(!skip_rotate_q)
{
// optionally apply rotary embedding to Q
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto q_window = make_tile_window(
q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
auto q_tile = [&]() {
auto q = load_tile(q_window);
return tile_elementwise_in(q_element_func, q);
}();
auto rotary_cos_window =
make_tile_window(q_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
auto rotary_sin_window =
make_tile_window(q_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
// We assume that each thread owns contiguous elements along the head dimension, and we
// use the distribution to enable/disable threads in order to overwrite only part of the
// q_tile content
auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
ignore = thread_start;
BlockRotaryEmbedding<RotaryEnum>::apply(
q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end);
store_tile(q_dram_block_window, q_tile);
}
}
}
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window,
VDramBlockWindow& v_dram_block_window,
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window,
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
index_t rotary_dim,
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
return operator()(q_dram_block_window,
identity{},
k_dram_block_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_block_window,
identity{},
v_dram_block_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_block_window,
identity{},
q_rotary_cos_dram_block_window,
q_rotary_sin_dram_block_window,
knew_rotary_cos_dram_block_window,
knew_rotary_sin_dram_block_window,
rotary_dim,
skip_rotate_q,
skip_rotate_append_kv);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// This pipeline keeps q/k/v all located in LDS
struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kN1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct!
if constexpr(total_pixels > 4)
return 4;
else
return 2;
}
else
{
return 16 / sizeof(VDataType);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQNumElemsPerRead()
{
using DataType = typename Problem::QDataType;
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
/// NOTICE: we might need to lower this to support a smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 16 / sizeof(DataType);
}
}
template <typename Problem>
CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
{
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
static_assert(Problem::kK0 % KPerThread == 0);
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
else
{
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
static_assert(Problem::kK0 % KPerThread == 0);
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
}
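// (both constexpr branches above are currently identical; presumably they are kept separate so
// the INTERLEAVED case can later use a narrower read, per the NOTICE in GetQNumElemsPerRead)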
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKnewNumElemsPerRead()
{
using DataType = typename Problem::KDataType;
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
/// NOTICE: we might need to lower this to support a smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 16 / sizeof(DataType);
}
}
template <typename Problem>
CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
{
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
else
{
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_pos, start_pos + KPerThread);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVnewDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t kKPerBlock = Problem::kN0;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t NPerThread = 16 / sizeof(VDataType);
constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NThreadPerBlock, NPerThread>,
sequence<KPerThread, NumWarps, KThreadPerWarp>>,
tuple<sequence<2>, sequence<1, 2>>,
tuple<sequence<1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
else
{
constexpr index_t KPerThread = 16 / sizeof(VDataType);
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_HOST_DEVICE static constexpr auto GetRotaryCosSinTileSize()
{
constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return make_tuple(number<height>{}, number<Problem::kK0>{});
}
else
{
return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
}
}
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_HOST_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
{
using DataType = std::conditional_t<IsRotaryCosSinForQ,
typename Problem::QDataType,
typename Problem::KDataType>;
constexpr auto TileSize = GetRotaryCosSinTileSize<Problem, IsRotaryCosSinForQ>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = TileSize[number<0>{}];
constexpr index_t kKPerBlock = TileSize[number<1>{}];
constexpr index_t KPerThread = []() {
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
/// NOTICE: we might need to lower this to support a smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 8 / sizeof(DataType);
}
}();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
};
} // namespace ck_tile
@@ -6,7 +6,6 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
-#include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"
 namespace ck_tile {
@@ -15,19 +14,18 @@ namespace ck_tile {
 template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
 struct BlockFmhaFwdSplitKVPipelineQRKSVS
 {
 using Problem = remove_cvref_t<Problem_>;
 using Policy = remove_cvref_t<Policy_>;
 using QDataType = remove_cvref_t<typename Problem::QDataType>;
 using KDataType = remove_cvref_t<typename Problem::KDataType>;
 using VDataType = remove_cvref_t<typename Problem::VDataType>;
 using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
 using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
 using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
-using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
 using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
 using PDataType = remove_cvref_t<typename Problem::PDataType>;
 using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
 using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
 using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
 using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
 static constexpr auto BiasEnum = Problem::BiasEnum;
 static constexpr bool kStoreLSE = true; // always store LSE (acc)
-static constexpr bool kHasDropout = false; // ignore this flag
+static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
 static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
 // last dimension vector length used to create tensor view(and decide buffer_load vector length)
@@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 }
 template <typename QDramBlockWindowTmp,
-typename KDramBlockWindowTmp,
-typename VDramBlockWindowTmp,
+typename KDramBlockWindowLengths,
+typename KPageBlockNavigator,
+typename VDramBlockWindowLengths,
+typename VPageBlockNavigator,
 typename BiasDramBlockWindowTmp,
-typename RandValDramBlockWindowTmp,
 typename LSEaccDramBlockWindowTmp,
 typename QElementFunction,
 typename KElementFunction,
@@ -123,13 +122,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 CK_TILE_HOST_DEVICE auto
 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
 const QElementFunction& q_element_func,
-const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
+const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
+const KPageBlockNavigator& k_page_block_navigator,
 const KElementFunction& k_element_func,
-const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
+const VPageBlockNavigator& v_page_block_navigator,
 const VElementFunction& v_element_func,
 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
 const BiasElementFunction& bias_element_func,
-RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
 LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
 const LSEaccElementFunction& lse_acc_element_func,
 const SAccElementFunction& s_acc_element_func,
@@ -140,20 +140,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 FmhaMask mask,
 PositionEncoding position_encoding,
 float scale_s,
-void* smem_ptr,
-BlockDropout& dropout) const
+void* smem_ptr) const
 {
 static_assert(
 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
-std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
-std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
+std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
+std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
 "wrong!");
 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
-kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
-kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
+kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
+kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
+kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
+kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
 "wrong!");
@@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
 q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
-const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
 // check early exit if masked and no work to do.
 if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
 {
-if(num_total_loop <= 0)
+const index_t original_num_total_loop =
+integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
+if(original_num_total_loop <= 0)
 {
 if constexpr(kStoreLSE)
 {
@@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 }
 }
-auto k_dram_block_window =
-make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
-k_dram_block_window_tmp.get_window_lengths(),
-{seqlen_k_start, 0});
+// make sure the first tile is completely located in page-block
+const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
+if constexpr(kIsPagedKV)
+{
+return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
+}
+else
+{
+return seqlen_k_start_;
+}
+}();
+const index_t num_total_loop =
+integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
+auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
+k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
 auto bias_dram_window = make_tile_window(
 bias_dram_block_window_tmp.get_bottom_tensor_view(),
 bias_dram_block_window_tmp.get_window_lengths(),
-{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
 Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
-auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
-randval_dram_block_window_tmp, seqlen_k_start);
-auto v_dram_window =
-make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
-v_dram_block_window_tmp.get_window_lengths(),
-{0, seqlen_k_start}, // TODO: hdim split?
-Policy::template MakeVDramTileDistribution<Problem>());
+auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
+v_dram_block_window_lengths,
+{0, adjusted_seqlen_k_start}, // TODO: hdim split?
+Policy::template MakeVDramTileDistribution<Problem>());
 auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 {
 // STAGE 1, QK gemm
 auto k_dram_window = make_tile_window(
-k_dram_block_window.get_bottom_tensor_view(),
-k_dram_block_window.get_window_lengths(),
-k_dram_block_window.get_window_origin(),
+k_dram_block_window,
 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
 // load
 auto k_block_tile = load_tile(k_dram_window);
 {
+// moving k_dram_window is an in-page-block operation, so there is
+// no need to invoke k_page_block_navigator.move_tile_window() here.
 move_tile_window(k_dram_window, {0, kK0});
 clear_tile(s_acc); // initialize C
 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
@@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 }
 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
 {
-const auto k_origin = k_dram_block_window.get_window_origin();
+const auto k_origin = k_page_block_navigator.to_global_window_origin(
+i_page_block_k, k_dram_block_window.get_window_origin());
 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
@@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 }
 move_tile_window(bias_dram_window, {0, kN0});
-/// TODO: only check in last iteration without increasing code size
+/// TODO: only check in first/last iteration without increasing code size
 if constexpr(kHasUnevenSplits)
 {
-const auto k_origin = k_dram_block_window.get_window_origin();
+const auto k_origin = k_page_block_navigator.to_global_window_origin(
+i_page_block_k, k_dram_block_window.get_window_origin());
 set_tile_if(s_acc,
 -numeric<SMPLComputeDataType>::infinity(),
-[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
+[&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
+auto tile_idx) {
 const auto col =
 k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
-return seqlen_k_end_ <= col;
+if constexpr(kIsPagedKV)
+{
+return col < seqlen_k_start_ || seqlen_k_end_ <= col;
+}
+else
+{
+return seqlen_k_end_ <= col;
+}
 });
 }
 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
 {
-const auto k_origin = k_dram_block_window.get_window_origin();
+const auto k_origin = k_page_block_navigator.to_global_window_origin(
+i_page_block_k, k_dram_block_window.get_window_origin());
 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
 k_origin.at(number<0>{}),
 number<kM0>{},
@@ -501,12 +519,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 });
 });
-if constexpr(kHasDropout)
-{
-dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
-smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
-}
 block_sync_lds();
 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
 {
@@ -522,7 +534,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 store_tile(v_lds_window,
 tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
 }
-move_tile_window(v_dram_window, {0, kK1});
+i_page_block_v =
+v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
 const auto p =
 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
@@ -530,8 +543,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 // STAGE 3, KV gemm
 if constexpr(k1_loops > 1)
 {
-static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
-const auto v = load_tile(v_dram_window); // load next v
+static_for<0, k1_loops - 1, 1>{}([&,
+&i_page_block_v_ = i_page_block_v,
+&v_dram_window_ = v_dram_window](auto i_k1) {
+const auto v = load_tile(v_dram_window_); // load next v
 block_sync_lds();
 gemm_1(o_acc,
 get_slice_tile(
@@ -552,11 +567,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 store_tile(v_lds_window,
 tile_elementwise_in(v_element_func, v)); // store next v
 }
-move_tile_window(v_dram_window, {0, kK1});
+i_page_block_v_ = v_page_block_navigator.move_tile_window(
+i_page_block_v_, v_dram_window_, {0, kK1});
 });
 }
 // move K tile windows
-move_tile_window(k_dram_block_window, {kN0, 0});
+i_page_block_k = k_page_block_navigator.move_tile_window(
+i_page_block_k, k_dram_block_window, {kN0, 0});
 // tail
 {
 block_sync_lds();
@@ -618,36 +635,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 }
 template <typename QDramBlockWindowTmp,
-typename KDramBlockWindowTmp,
-typename VDramBlockWindowTmp,
+typename KDramBlockWindowLengths,
+typename KPageBlockNavigator,
+typename VDramBlockWindowLengths,
+typename VPageBlockNavigator,
 typename BiasDramBlockWindowTmp,
-typename RandValDramBlockWindowTmp,
 typename LSEaccDramBlockWindowTmp,
 typename PositionEncoding>
 CK_TILE_HOST_DEVICE auto
 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
-const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
-const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
+const KPageBlockNavigator& k_page_block_navigator,
+const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
+const VPageBlockNavigator& v_page_block_navigator,
 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
-RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
 LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
 index_t num_splits,
 index_t i_split,
 FmhaMask mask,
 PositionEncoding position_encoding,
 float scale_s,
-void* smem_ptr,
-BlockDropout& dropout) const
+void* smem_ptr) const
 {
 return operator()(q_dram_block_window_tmp,
 identity{},
-k_dram_block_window_tmp,
+k_dram_block_window_lengths,
+k_page_block_navigator,
 identity{},
-v_dram_block_window_tmp,
+v_dram_block_window_lengths,
+v_page_block_navigator,
 identity{},
 bias_dram_block_window_tmp,
 identity{},
-randval_dram_block_window_tmp,
 lse_acc_dram_block_window_tmp,
 identity{},
 identity{},
@@ -658,8 +677,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
 mask,
 position_encoding,
 scale_s,
-smem_ptr,
-dropout);
+smem_ptr);
 }
 };
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // whether q_tile loads the whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always supports padding; hdim_q/v support multiples of the vector size (e.g. 8x)
// only seq_k padding needs special care (OOB entries of p need to be set to -INF instead of zero)
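// (a zero-filled OOB logit would still contribute exp(0 - m) > 0 to the softmax row sum, which is
// why OOB positions have to be forced to -INF before the softmax)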
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // supports multiples of the vector size (e.g. 8x)
static constexpr bool kPadHeadDimV = true; // supports multiples of the vector size (e.g. 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
// ... together with the tensor distribution. the tensor distribution should be able to overwrite this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
{
return 1;
}
}
}();
static constexpr const char* name = "qr_async";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto k_lds_load = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0});
},
number<Policy::NumPrefetchK>{});
#else
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
#endif
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
// TODO: we use async copy for K, which is inline asm;
// a side effect is that we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
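// (sched_barrier(0) prevents the compiler's instruction scheduler from moving any instruction
// across this point)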
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
if(num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
buffer_load_fence(0); // rocm-6.1: if the whole tile is masked out, we need to fence(0),
// otherwise there will be compute errors (maybe a compiler bug?)
// Note: o_acc is already all cleared here, just return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
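// (the fence waits until at most k_dram_window.get_num_access() buffer loads remain outstanding,
// so q's thread buffer is resident before it is consumed; presumably an s_waitcnt under the hood)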
(void)q_element_func; // ??? rocm-6.x: using the q element func causes scratch usage on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<i_k0>{})>{}]);
#else
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
#endif
});
}
// TODO: this is to fix a bug when the loop count is smaller than 2;
// otherwise the following fence/barrier will be scheduled inside the 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<k0_loops - 1>{})>{}]);
#else
get_slice_tile(
k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
#endif
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in last iteration without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_dram_block_window.get_window_origin();
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return seqlen_k_end_ <= col;
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // moving this right after load_tile(v_dram) would cause scratch usage...
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask including -inf values, which needs
/// special consideration. alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
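// (clamping a fully masked row's max from -inf to 0 avoids the NaN that exp(-inf - (-inf)) would
// produce; such rows then get p = exp(-inf - 0) = 0 everywhere and a zero row sum)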
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
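// The update above is the standard online-softmax recurrence:
//     l{j}    = exp(m{j-1} - m{j}) * l{j-1} + rowsum(P{j})
//     Oacc{j} = exp(m{j-1} - m{j}) * Oacc{j-1} + P{j} * V{j}
// Some FlashAttention write-ups fold a 1/l factor into every iteration instead; since the
// final normalization by 1/l is applied exactly once after the loop (see "finally, O"
// below), deferring it as done here yields the same result.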
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window =
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// store lse acc
if constexpr(kStoreLSE)
{
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc));
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
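// With masking, a fully-masked row accumulates no probability mass (l == 0), so the
// reciprocal below is guarded to keep such rows at 0 instead of producing NaN/Inf.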
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_acc_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
smem_ptr,
dropout);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// Default policy for this pipeline: Q is loaded once, K is copied into LDS asynchronously,
// V synchronously; three K and V tiles are prefetched.
using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ true,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
} // namespace ck_tile
...@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem ...@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename QDataType, template <typename QDataType_,
typename KDataType, typename KDataType_,
typename VDataType, typename VDataType_,
typename SaccDataType, typename SaccDataType_,
typename SMPLComputeDataType, typename SMPLComputeDataType_,
typename BiasDataType, typename BiasDataType_,
typename RandValOutputDataType, typename LSEDataType_,
typename LSEDataType, typename PDataType_,
typename PDataType, typename OaccDataType_,
typename OaccDataType, typename ODataType_,
typename ODataType, typename BlockFmhaShape_,
typename BlockFmhaShape, bool kIsGroupMode_,
bool kIsGroupMode, typename FmhaMask_,
typename FmhaMask, typename Traits_>
typename Traits> struct BlockFmhaFwdSplitKVPipelineProblem
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
BiasDataType,
RandValOutputDataType,
LSEDataType,
PDataType,
OaccDataType,
ODataType,
BlockFmhaShape,
kIsGroupMode,
FmhaMask,
Traits>
{ {
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
using LSEDataType = remove_cvref_t<LSEDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename LSEDataType_, template <typename LSEDataType_,
...@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem ...@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static constexpr index_t kMaxSplits = Traits::kMaxSplits; static constexpr index_t kMaxSplits = Traits::kMaxSplits;
}; };
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
index_t kM0_,
index_t kN0_,
index_t kK0_,
index_t kN1_,
bool kIsVLayoutRowMajor_,
RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_,
typename Traits_>
struct BlockFmhaFwdAppendKVPipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kK0 = kK0_;
static constexpr index_t kN1 = kN1_;
using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
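// Illustrative instantiation only (the tile sizes and the fp16_t / RotaryEmbeddingEnum::NONE
// spellings are assumptions, not taken from this change):
//   using AppendKVProblem = BlockFmhaFwdAppendKVPipelineProblem<
//       fp16_t, fp16_t, fp16_t,
//       /*kM0=*/64, /*kN0=*/64, /*kK0=*/32, /*kN1=*/64,
//       /*kIsVLayoutRowMajor=*/true,
//       RotaryEmbeddingEnum::NONE,
//       /*kIsPagedKV=*/false,
//       TileFmhaFwdAppendKVTraits<true, true, false, false>>;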
} // namespace ck_tile } // namespace ck_tile
...@@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{ {
if constexpr(AsyncCopyK) if constexpr(AsyncCopyK)
{ {
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(); return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
} }
else else
{ {
return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>()); return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
} }
} }
// this method is only available when Problem::kHasDropout is present
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() CK_TILE_HOST_DEVICE static constexpr std::
enable_if_t<std::is_convertible_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
GetSmemSizeDropout(int)
{ {
if constexpr(Problem::kHasDropout) if constexpr(Problem::kHasDropout)
{ {
...@@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
} }
} }
// fallback version used when Problem::kHasDropout does not exist
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout(...)
{
return 0;
}
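// The (int)/(...) pair above is a tag-dispatch trick: callers pass a literal 0, so the
// (int) overload is preferred, but it is only viable when decltype(Problem::kHasDropout)
// names a member; otherwise substitution fails silently and overload resolution falls
// back to the variadic (...) version, which reports zero dropout scratch.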
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{ {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -32,30 +33,31 @@ struct TileFmhaTraits ...@@ -32,30 +33,31 @@ struct TileFmhaTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK /* padding for seqlen_k */, bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ /* paddding for hdim_q */, bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum, BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad, bool kHasBiasGrad_,
bool kStoreLSE, bool kStoreLSE_,
bool kHasDropout, bool kDoFp8StaticQuant_,
bool kDoFp8StaticQuant, bool kIsPagedKV_,
bool kHasUnevenSplits_ = true, bool kHasUnevenSplits_,
index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */> index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ, struct TileFmhaFwdSplitKVTraits
kPadSeqLenK,
kPadHeadDimQ,
kPadHeadDimV,
BiasEnum,
kHasBiasGrad,
kStoreLSE,
kHasDropout,
kDoFp8StaticQuant,
kBlockPerCu>
{ {
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
// determine if some split (length) is not divisible by tile size // determine if some split (length) is not divisible by tile size
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
...@@ -76,6 +78,20 @@ struct TileFmhaFwdSplitKVCombineTraits ...@@ -76,6 +78,20 @@ struct TileFmhaFwdSplitKVCombineTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdAppendKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */> index_t kBlockPerCu_ = 2 /* hint to occupancy */>
......
...@@ -74,6 +74,10 @@ using GNWK = ck::tensor_layout::convolution::GNWK; ...@@ -74,6 +74,10 @@ using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK; using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
using NGKW = ck::tensor_layout::convolution::NGKW;
using NGKHW = ck::tensor_layout::convolution::NGKHW;
using NGKDHW = ck::tensor_layout::convolution::NGKDHW;
// //
using NWGC = ck::tensor_layout::convolution::NWGC; using NWGC = ck::tensor_layout::convolution::NWGC;
using NHWGC = ck::tensor_layout::convolution::NHWGC; using NHWGC = ck::tensor_layout::convolution::NHWGC;
...@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK; ...@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
using NHWGK = ck::tensor_layout::convolution::NHWGK; using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
using NGCW = ck::tensor_layout::convolution::NGCW;
using NGCHW = ck::tensor_layout::convolution::NGCHW;
using NGCDHW = ck::tensor_layout::convolution::NGCDHW;
// //
using G_K = ck::tensor_layout::convolution::G_K; using G_K = ck::tensor_layout::convolution::G_K;
using GK_Tuple = ck::Tuple<G_K>; using GK_Tuple = ck::Tuple<G_K>;
......
...@@ -56,6 +56,46 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std ...@@ -56,6 +56,46 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std
// clang-format on // clang-format on
>; >;
// NGCHW layouts require a transpose, so these instances use dedicated vector load/store parameters
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 1, 2>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8 ,1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>
// clang-format on
>;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -367,6 +367,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -367,6 +367,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs); op_ptrs);
} }
#endif
}
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NGKHW>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
op_ptrs);
}
#endif #endif
} }
} }
...@@ -447,6 +462,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -447,6 +462,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
op_ptrs); op_ptrs);
} }
#endif
}
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NGKDHW>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
op_ptrs);
}
#endif #endif
} }
} }
......
...@@ -137,6 +137,29 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi ...@@ -137,6 +137,29 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
...@@ -240,6 +263,29 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16 ...@@ -240,6 +263,29 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#ifdef CK_USE_XDL #ifdef CK_USE_XDL
#include "grouped_convolution_forward_xdl.inc" #include "grouped_convolution_forward_xdl.inc"
#include "grouped_convolution_forward_xdl_large_tensor.inc" #include "grouped_convolution_forward_xdl_large_tensor.inc"
#include "grouped_convolution_forward_xdl_merged_groups.inc"
#include "grouped_convolution_forward_comp_xdl.inc" #include "grouped_convolution_forward_comp_xdl.inc"
#include "grouped_convolution_forward_mem_inter_xdl.inc" #include "grouped_convolution_forward_mem_inter_xdl.inc"
#include "grouped_convolution_forward_mem_intra_xdl.inc" #include "grouped_convolution_forward_mem_intra_xdl.inc"
...@@ -202,6 +203,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -202,6 +203,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs); op_ptrs);
...@@ -217,6 +220,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -217,6 +220,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
op_ptrs); op_ptrs);
...@@ -234,6 +239,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -234,6 +239,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
op_ptrs); op_ptrs);
...@@ -293,6 +300,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -293,6 +300,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs); op_ptrs);
...@@ -349,6 +358,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -349,6 +358,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
op_ptrs); op_ptrs);
...@@ -366,6 +377,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -366,6 +377,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
op_ptrs); op_ptrs);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old() ...@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{ {
return {0, 1, 2, 3}; return {0, 1, 2, 3};
} }
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKW>)
{
return {1, 0, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKHW>)
{
return {1, 0, 2, 3, 4};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCDHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKDHW>)
{
return {1, 0, 2, 3, 4, 5};
}
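// The permutations above only swap the first two dimensions, i.e. the N, G ordering of
// these layouts versus the G, N ordering of the GNCHW-style reference descriptors; the
// channel and spatial dimensions keep their positions.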
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> || else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> || ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>) ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
...@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa ...@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
param.input_spatial_lengths_.begin() + param.num_dim_spatial_); param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
// separate from legacy code above // separate from legacy code above
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> || else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> || ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>) ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
...@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP ...@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP
param.output_spatial_lengths_.begin(), param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_); param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
} }
// separate from legacy code above
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> || else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> || ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>) ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
......
...@@ -8,6 +8,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT ...@@ -8,6 +8,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
) )
if(DL_KERNELS) if(DL_KERNELS)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, y, x, c] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
2,
NGCHW,
GKYXC,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, y, x, c] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
2,
NGCHW,
GKYXC,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -14,6 +14,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance ...@@ -14,6 +14,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
# merged groups
# NHWGC, GKYXC, NHWGK
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
#mem #mem
# NHWGC, GKYXC, NHWGK # NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
......