"...composable_kernel.git" did not exist on "986182fc63e4352dbdc75a2cb2488cefa4a2e976"
Commit a60cf0d0 authored by Adam Osewski's avatar Adam Osewski
Browse files

Use AccDataType for Output of MFMA instruction.

parent b045fad5
...@@ -18,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t> ...@@ -18,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
using ADataType = ck_tile::half_t; using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t; using BDataType = ck_tile::half_t;
using AccDataType = float; using AccDataType = float;
using CDataType = float; using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM. // ToDo: Add more bias config to support different categories of GEMM.
}; };
......
...@@ -91,6 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -91,6 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType,
CDataType, CDataType,
GemmShape, GemmShape,
ALayout, ALayout,
......
...@@ -18,7 +18,7 @@ struct BlockGemmAsBsCr ...@@ -18,7 +18,7 @@ struct BlockGemmAsBsCr
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
...@@ -31,7 +31,7 @@ struct BlockGemmAsBsCr ...@@ -31,7 +31,7 @@ struct BlockGemmAsBsCr
{ {
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> && static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> && std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>, std::is_same_v<AccDataType, typename CBlockTensor::DataType>,
"wrong!"); "wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
...@@ -195,7 +195,7 @@ struct BlockGemmAsBsCr ...@@ -195,7 +195,7 @@ struct BlockGemmAsBsCr
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr); auto c_block_tensor = make_static_distributed_tensor<AccDataType>(c_block_dstr);
return c_block_tensor; return c_block_tensor;
} }
......
...@@ -17,7 +17,7 @@ struct BlockGemmAsBsCrDefaultPolicy ...@@ -17,7 +17,7 @@ struct BlockGemmAsBsCrDefaultPolicy
{ {
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> && if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> && std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>) std::is_same_v<typename Problem::AccDataType, float>)
{ {
#if 0 #if 0
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
...@@ -45,7 +45,7 @@ struct BlockGemmAsBsCrDefaultPolicy ...@@ -45,7 +45,7 @@ struct BlockGemmAsBsCrDefaultPolicy
} }
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> && else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> && std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>) std::is_same_v<typename Problem::AccDataType, float>)
{ {
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
} }
......
...@@ -42,6 +42,7 @@ struct BlockGemmPipelineProblem ...@@ -42,6 +42,7 @@ struct BlockGemmPipelineProblem
template <typename ADataType_, template <typename ADataType_,
typename BDataType_, typename BDataType_,
typename AccDataType_,
typename CDataType_, typename CDataType_,
typename BlockGemmShape_, typename BlockGemmShape_,
typename ALayout_, typename ALayout_,
...@@ -57,6 +58,7 @@ struct UniversalGemmPipelineProblem ...@@ -57,6 +58,7 @@ struct UniversalGemmPipelineProblem
{ {
using ADataType = remove_cvref_t<ADataType_>; using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>; using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using CDataType = remove_cvref_t<CDataType_>; using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment