Commit 21541511 authored by ThomasNing's avatar ThomasNing
Browse files

Solve FMHA error

parent 20034872
...@@ -26,7 +26,6 @@ struct BlockGemmARegBRegCRegV1 ...@@ -26,7 +26,6 @@ struct BlockGemmARegBRegCRegV1
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
...@@ -45,9 +44,13 @@ struct BlockGemmARegBRegCRegV1 ...@@ -45,9 +44,13 @@ struct BlockGemmARegBRegCRegV1
}; };
public: public:
using Traits = GemmTraits_<Problem_, Policy_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm; using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
using ADataType = remove_cvref_t<typename Traits::ADataType>; using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>; using BDataType = remove_cvref_t<typename Traits::BDataType>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment