Commit a60cf0d0 authored by Adam Osewski
Browse files

Use AccDataType for Output of MFMA instruction.

parent b045fad5
......@@ -18,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
......
......@@ -91,6 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CDataType,
GemmShape,
ALayout,
......
......@@ -18,7 +18,7 @@ struct BlockGemmAsBsCr
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
......@@ -31,7 +31,7 @@ struct BlockGemmAsBsCr
{
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
std::is_same_v<AccDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
......@@ -195,7 +195,7 @@ struct BlockGemmAsBsCr
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
auto c_block_tensor = make_static_distributed_tensor<AccDataType>(c_block_dstr);
return c_block_tensor;
}
......
......@@ -17,7 +17,7 @@ struct BlockGemmAsBsCrDefaultPolicy
{
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
std::is_same_v<typename Problem::AccDataType, float>)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
......@@ -45,7 +45,7 @@ struct BlockGemmAsBsCrDefaultPolicy
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
std::is_same_v<typename Problem::AccDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
......
......@@ -42,6 +42,7 @@ struct BlockGemmPipelineProblem
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename ALayout_,
......@@ -57,6 +58,7 @@ struct UniversalGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment