Unverified Commit b7abe77a authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

fix compile error for fmha_fwd example (#21)

parent 7ccf0bb5
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tile_program/tile/tile_elementwise.hpp" #include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp" #include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp"
...@@ -229,13 +230,13 @@ struct BlockFmhaPipelineQKVSDefaultPolicy ...@@ -229,13 +230,13 @@ struct BlockFmhaPipelineQKVSDefaultPolicy
__host__ __device__ static constexpr auto GetQKBlockGemm() __host__ __device__ static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem =
BlockGemmASmemBSmemCRegV1Problem<typename Problem::QDataType, BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
Problem::kBlockSize, Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0, TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>; Problem::BlockFmhaShape::kK0>>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
...@@ -245,13 +246,13 @@ struct BlockFmhaPipelineQKVSDefaultPolicy ...@@ -245,13 +246,13 @@ struct BlockFmhaPipelineQKVSDefaultPolicy
__host__ __device__ static constexpr auto GetKVBlockGemm() __host__ __device__ static constexpr auto GetKVBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem =
BlockGemmARegBSmemCRegV1Problem<typename Problem::PDataType, BlockGemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::OaccDataType, typename Problem::OaccDataType,
Problem::kBlockSize, Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0, TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1, Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>; Problem::BlockFmhaShape::kK1>>;
using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy; using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy;
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment