Unverified commit 35c1777d, authored by Thomas Ning and committed by GitHub
Browse files

decouple the calling from gemm_pipeline (#1571)

* decouple the calling from gemm_pipeline

* clang format
parent 29d384d0
...@@ -5,9 +5,8 @@ ...@@ -5,9 +5,8 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType, typename Problem::QDataType,
...@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim, Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>, typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimV,
Problem::kPadHeadDimV,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::OGradDataType, BlockGemmProblem<typename Problem::OGradDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>, Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps, typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>, typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType, typename Problem::OGradDataType,
...@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::QDataType, typename Problem::QDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>, Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps, typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>, typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>, Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps, typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>, typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...@@ -77,20 +77,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -77,20 +77,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::PDataType, BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::OaccDataType, typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1, Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>, typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
TileGemmTraits<Problem::kPadSeqLenQ,
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> && if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment