Unverified Commit d71189ff authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-1815

parents f84e2020 73b67f29
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// Default policy for the split-KV forward FMHA pipeline (QR/KS/VS, async variant).
// It is BlockFmhaPipelineQXKSVSCustomPolicy instantiated with the fixed settings below.
// NOTE(review): the previous comment here said "qkv all located in LDS"; that cannot be
// confirmed from this header alone - verify against the custom policy implementation.
using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy =
    BlockFmhaPipelineQXKSVSCustomPolicy<true,  // QLoadOnce
                                        true,  // AsyncCopyK
                                        false, // AsyncCopyV
                                        3,     // NumPrefetchK
                                        3>;    // NumPrefetchV
} // namespace ck_tile
......@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename BiasDataType,
typename RandValOutputDataType,
typename LSEDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
typename BlockFmhaShape,
bool kIsGroupMode,
typename FmhaMask,
typename Traits>
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
BiasDataType,
RandValOutputDataType,
LSEDataType,
PDataType,
OaccDataType,
ODataType,
BlockFmhaShape,
kIsGroupMode,
FmhaMask,
Traits>
/// Compile-time problem description for the split-KV forward FMHA pipeline.
///
/// Re-exports every type parameter (after stripping cv/ref qualifiers with
/// `remove_cvref_t`) and every switch read from `Traits_` as nested names, so that
/// code templated on this struct can query them uniformly.
template <typename QDataType_,
          typename KDataType_,
          typename VDataType_,
          typename SaccDataType_,
          typename SMPLComputeDataType_,
          typename BiasDataType_,
          typename LSEDataType_,
          typename PDataType_,
          typename OaccDataType_,
          typename ODataType_,
          typename BlockFmhaShape_,
          bool kIsGroupMode_,
          typename FmhaMask_,
          typename Traits_>
struct BlockFmhaFwdSplitKVPipelineProblem
{
    // ---- element types ---------------------------------------------------------------
    using QDataType           = remove_cvref_t<QDataType_>;
    using KDataType           = remove_cvref_t<KDataType_>;
    using VDataType           = remove_cvref_t<VDataType_>;
    using SaccDataType        = remove_cvref_t<SaccDataType_>;
    using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
    using BiasDataType        = remove_cvref_t<BiasDataType_>;
    using LSEDataType         = remove_cvref_t<LSEDataType_>;
    using PDataType           = remove_cvref_t<PDataType_>;
    using OaccDataType        = remove_cvref_t<OaccDataType_>;
    using ODataType           = remove_cvref_t<ODataType_>;

    // ---- shape, mask and traits carriers ---------------------------------------------
    using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
    using FmhaMask       = remove_cvref_t<FmhaMask_>;
    using Traits         = remove_cvref_t<Traits_>;

    // ---- values fixed by the template arguments --------------------------------------
    static constexpr bool kIsGroupMode = kIsGroupMode_;

    // Number of warps declared by the block shape multiplied by the warp size.
    static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();

    // ---- values forwarded from Traits ------------------------------------------------
    static constexpr bool kPadSeqLenQ       = Traits::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK       = Traits::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ      = Traits::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV      = Traits::kPadHeadDimV;
    static constexpr auto BiasEnum          = Traits::BiasEnum;
    static constexpr bool kStoreLSE         = Traits::kStoreLSE;
    static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
    static constexpr bool kIsPagedKV        = Traits::kIsPagedKV;

    // Always true in group mode; otherwise whatever Traits requests.
    static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;

    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename LSEDataType_,
......@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
};
/// Compile-time problem description for the append-KV forward FMHA pipeline.
///
/// Exposes the Q/K/V element types, the four tile extents (kM0, kN0, kK0, kN1), the
/// layout tag selected for V, the rotary-embedding mode, the paged-KV switch and the
/// padding / occupancy settings read from `Traits_`.
template <typename QDataType_,
          typename KDataType_,
          typename VDataType_,
          index_t kM0_,
          index_t kN0_,
          index_t kK0_,
          index_t kN1_,
          bool kIsVLayoutRowMajor_,
          RotaryEmbeddingEnum RotaryEnum_,
          bool kIsPagedKV_,
          typename Traits_>
struct BlockFmhaFwdAppendKVPipelineProblem
{
    // ---- types (cv/ref qualifiers stripped) ------------------------------------------
    using QDataType = remove_cvref_t<QDataType_>;
    using KDataType = remove_cvref_t<KDataType_>;
    using VDataType = remove_cvref_t<VDataType_>;
    using Traits    = remove_cvref_t<Traits_>;

    // V uses the RowMajor tag when kIsVLayoutRowMajor_ is true, ColumnMajor otherwise.
    using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
                                       ck_tile::tensor_layout::gemm::RowMajor,
                                       ck_tile::tensor_layout::gemm::ColumnMajor>;

    // ---- values fixed by the template arguments --------------------------------------
    // Block size is a fixed constant here rather than being derived from a shape type.
    static constexpr index_t kBlockSize = 256;

    static constexpr index_t kM0 = kM0_;
    static constexpr index_t kN0 = kN0_;
    static constexpr index_t kK0 = kK0_;
    static constexpr index_t kN1 = kN1_;

    static constexpr auto RotaryEnum = RotaryEnum_;
    static constexpr bool kIsPagedKV = kIsPagedKV_;

    // ---- values forwarded from Traits ------------------------------------------------
    static constexpr bool kPadSeqLenQ    = Traits::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK    = Traits::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ   = Traits::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV   = Traits::kPadHeadDimV;
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile
......@@ -74,6 +74,10 @@ using GNWK = ck::tensor_layout::convolution::GNWK;
// Unqualified shorthands for types declared in ck::tensor_layout::convolution, so the
// rest of this file can name them without the namespace prefix.
// NOTE(review): the identifiers presumably spell out the dimension order of each
// layout - confirm against the definitions in ck::tensor_layout::convolution.
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
// NG-prefixed K-suffixed family: W, HW and DHW variants.
using NGKW = ck::tensor_layout::convolution::NGKW;
using NGKHW = ck::tensor_layout::convolution::NGKHW;
using NGKDHW = ck::tensor_layout::convolution::NGKDHW;
//
// N-prefixed GC-suffixed family.
using NWGC = ck::tensor_layout::convolution::NWGC;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
......@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
// Unqualified shorthands for types declared in ck::tensor_layout::convolution.
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
// NG-prefixed C-suffixed family: W, HW and DHW variants.
using NGCW = ck::tensor_layout::convolution::NGCW;
using NGCHW = ck::tensor_layout::convolution::NGCHW;
using NGCDHW = ck::tensor_layout::convolution::NGCDHW;
//
using G_K = ck::tensor_layout::convolution::G_K;
// One-element ck::Tuple whose only member is G_K.
using GK_Tuple = ck::Tuple<G_K>;
......
......@@ -8,6 +8,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
)
if(DL_KERNELS)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment