Unverified Commit b7abe77a authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

fix compile error for fmha_fwd example (#21)

parent 7ccf0bb5
......@@ -12,6 +12,7 @@
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp"
......@@ -229,7 +230,7 @@ struct BlockFmhaPipelineQKVSDefaultPolicy
__host__ __device__ static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmASmemBSmemCRegV1Problem<typename Problem::QDataType,
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
......@@ -245,7 +246,7 @@ struct BlockFmhaPipelineQKVSDefaultPolicy
__host__ __device__ static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem =
BlockGemmARegBSmemCRegV1Problem<typename Problem::PDataType,
BlockGemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment