Unverified Commit b7abe77a authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

fix compile error for fmha_fwd example (#21)

parent 7ccf0bb5
......@@ -12,6 +12,7 @@
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp"
......@@ -229,13 +230,13 @@ struct BlockFmhaPipelineQKVSDefaultPolicy
__host__ __device__ static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmASmemBSmemCRegV1Problem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
......@@ -245,13 +246,13 @@ struct BlockFmhaPipelineQKVSDefaultPolicy
__host__ __device__ static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem =
BlockGemmARegBSmemCRegV1Problem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
BlockGemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy;
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment