"...composable_kernel_rocm.git" did not exist on "5d37d7bff4e631c3b94112c31a52f209ca39dfe2"
Commit bf445c31 authored by Bartlomiej Wroblewski

Review: Change names from FloatX to XDataType

parent 0ff1d1f8
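
For context: the commit renames element-type template parameters so their names describe which tensor the type belongs to (A/B inputs, accumulator, C output) instead of implying a floating-point type. The sketch below is purely illustrative; the struct names are hypothetical and do not appear in the diff.

// Old convention: parameter names suggest the element types are floats.
template <typename FloatAB, typename FloatAcc, typename FloatC>
struct GemmKernelOldNaming; // hypothetical, for illustration only

// New convention: names state which tensor the type describes, with no
// assumption that the element type is a floating-point type.
template <typename ABDataType, typename AccDataType, typename CDataType>
struct GemmKernelNewNaming; // hypothetical, for illustration only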
@@ -20,8 +20,8 @@ namespace ck {
  * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
  */
 template <index_t BlockSize,
-          typename FloatAB,
-          typename FloatAcc,
+          typename ABDataType,
+          typename AccDataType,
           typename AK0MK1BlockDesc,
           typename BK0NK1BlockDesc,
           index_t MPerDpp,
@@ -50,7 +50,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
-    static constexpr auto dpp_gemm = DppGemm<FloatAB, MPerDpp, NPerDpp, KPack>{};
+    static constexpr auto dpp_gemm = DppGemm<ABDataType, MPerDpp, NPerDpp, KPack>{};
     static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
@@ -58,7 +58,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
     StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                              FloatAcc,
+                              AccDataType,
                               MRepeat * NRepeat,
                               dpp_gemm.GetRegSizePerDpp(),
                               true>
@@ -260,9 +260,9 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
                  const BBlockBuffer& b_block_buf,
                  CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
             b_thread_desc_.GetElementSpaceSize());
         static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -284,17 +284,18 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
                     b_thread_buf);
                 static_for<0, KPerThread, KPack>{}([&](auto k) {
-                    vector_type<FloatAB, KPack> a_thread_vec;
-                    vector_type<FloatAB, KPack> b_thread_vec;
+                    vector_type<ABDataType, KPack> a_thread_vec;
+                    vector_type<ABDataType, KPack> b_thread_vec;
                     static_for<0, KPack, 1>{}([&](auto i) {
-                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
+                        a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
                             [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
-                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
+                        b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
                             [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                     });
-                    using dpp_input_type = typename vector_type<FloatAB, dpp_gemm.K1PerDpp>::type;
+                    using dpp_input_type =
+                        typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -320,8 +321,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));
-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
-                                                         FloatAB,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
+                                                         ABDataType,
                                                          decltype(a_block_desc_m0_m1_m2_k),
                                                          decltype(a_thread_desc_),
                                                          Sequence<1, 1, 1, KPerThread>,
@@ -330,8 +331,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
                                                          A_K1,
                                                          A_K1>;
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
-                                                         FloatAB,
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
+                                                         ABDataType,
                                                          decltype(b_block_desc_n0_n1_n2_k),
                                                          decltype(b_thread_desc_),
                                                          Sequence<1, 1, 1, KPerThread>,
......
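
As a side note on the first hunk above: the wave-count expression in the doc comment can be checked with a concrete tile configuration. The numbers below are hypothetical, chosen only to make the arithmetic visible; they are not taken from the commit.

// Illustrative only: plug hypothetical tile sizes into the formula
// MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)
// from the doc comment above.
constexpr int MPerBlock = 128, NPerBlock = 64;
constexpr int MRepeat   = 2,   NRepeat   = 2;
constexpr int MPerDpp   = 32,  NPerDpp   = 8;

constexpr int MWaves        = MPerBlock / (MRepeat * MPerDpp); // 128 / 64 = 2
constexpr int NWaves        = NPerBlock / (NRepeat * NPerDpp); // 64 / 16  = 4
constexpr int WavesPerBlock = MWaves * NWaves;                 // 2 * 4    = 8 waves
static_assert(WavesPerBlock == 8, "example arithmetic");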
@@ -51,9 +51,9 @@ __global__ void
 }
 template <index_t BlockSize,
-          typename FloatAB,
-          typename FloatAcc,
-          typename FloatC,
+          typename ABDataType,
+          typename AccDataType,
+          typename CDataType,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename ALayout,
           typename BLayout,
@@ -172,9 +172,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
     // Argument
     struct Argument : public Problem, public tensor_operation::device::BaseArgument
     {
-        __host__ Argument(const FloatAB* p_a_grid_,
-                          const FloatAB* p_b_grid_,
-                          FloatC* p_c_grid_,
+        __host__ Argument(const ABDataType* p_a_grid_,
+                          const ABDataType* p_b_grid_,
+                          CDataType* p_c_grid_,
                           index_t M_,
                           index_t N_,
                           index_t K_,
@@ -188,9 +188,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
         {
         }
-        const FloatAB* p_a_grid;
-        const FloatAB* p_b_grid;
-        FloatC* p_c_grid;
+        const ABDataType* p_a_grid;
+        const ABDataType* p_b_grid;
+        CDataType* p_c_grid;
     };
     using GridwiseGemmPipe = remove_cvref_t<
@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
         constexpr auto b_block_space_size_aligned =
             math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
-        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
+        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
     }
     __host__ static constexpr bool CheckValidity(const Problem& problem)
@@ -347,8 +347,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
         using BlockwiseGemm =
             BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2<BlockSize,
-                                                      FloatAB,
-                                                      FloatAcc,
+                                                      ABDataType,
+                                                      AccDataType,
                                                       decltype(a_block_desc_k0_m_k1),
                                                       decltype(b_block_desc_k0_n_k1),
                                                       MPerDpp,
@@ -430,9 +430,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
              typename AGridDesc_K0_M_K1,
              typename BGridDesc_K0_N_K1,
              typename CGridDesc_M_N>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
-                               const FloatAB* __restrict__ p_b_grid,
-                               FloatC* __restrict__ p_c_grid,
+    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
+                               const ABDataType* __restrict__ p_b_grid,
+                               CDataType* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
                                const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
@@ -488,8 +488,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
                                                 Sequence<K0PerBlock, MPerBlock, K1>,
                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
                                                 ABlockTransferThreadClusterArrangeOrder,
-                                                FloatAB,
-                                                FloatAB,
+                                                ABDataType,
+                                                ABDataType,
                                                 decltype(a_grid_desc_k0_m_k1),
                                                 decltype(a_block_desc_k0_m_k1),
                                                 ABlockTransferSrcAccessOrder,
@@ -518,8 +518,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
                                                 Sequence<K0PerBlock, NPerBlock, K1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 BBlockTransferThreadClusterArrangeOrder,
-                                                FloatAB,
-                                                FloatAB,
+                                                ABDataType,
+                                                ABDataType,
                                                 decltype(b_grid_desc_k0_n_k1),
                                                 decltype(b_block_desc_k0_n_k1),
                                                 BBlockTransferSrcAccessOrder,
@@ -548,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
         // register
         auto blockwise_gemm =
             BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2<BlockSize,
-                                                      FloatAB,
-                                                      FloatAcc,
+                                                      ABDataType,
+                                                      AccDataType,
                                                       decltype(a_block_desc_k0_m_k1),
                                                       decltype(b_block_desc_k0_n_k1),
                                                       MPerDpp,
@@ -565,10 +565,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
             math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+            static_cast<ABDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_k0_n_k1.GetElementSpaceSize());
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
@@ -642,8 +642,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
                 make_multi_index(n_thread_data_on_grid));
         auto c_thread_copy =
-            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
-                                               FloatC,
+            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+                                               CDataType,
                                                decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
                                                decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
                                                CElementwiseOperation,
......
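
The LDS handling in the hunks above (GetSharedMemoryNumberOfByte and the two make_dynamic_buffer calls) places the A and B block tiles back to back in a single shared-memory allocation. The stand-alone sketch below illustrates that layout with made-up sizes and a stand-in element type; none of these values come from the commit.

#include <cstddef>
#include <cstdint>

// Hypothetical element counts, already aligned to max_lds_align as in the kernel.
constexpr std::size_t a_block_space_size_aligned = 1024; // A tile elements
constexpr std::size_t b_block_space_size_aligned = 512;  // B tile elements
using ABDataType = std::uint16_t;                        // stand-in for half_t (2 bytes)

// Mirrors GetSharedMemoryNumberOfByte(): one allocation covers both tiles.
constexpr std::size_t lds_bytes =
    (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);

// Inside Run(), p_shared is split the same way: the A tile starts at p_shared,
// and the B tile starts a_block_space_size_aligned elements after it.
static_assert(lds_bytes == (1024 + 512) * 2, "example arithmetic");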
@@ -54,18 +54,18 @@ struct dpp_type<DppInstr::dpp8_f16_32x8x2>
     static constexpr index_t n_per_thread = 1;
     static constexpr index_t k_per_dpp = 2;
     static constexpr bool share_a = true;
-    using base_type = half_t;
-    template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    using BaseType = half_t;
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
     {
         dpp8::DppInstrRunner<m_per_thread,
                              n_per_thread,
                              k_per_dpp,
-                             base_type,
-                             FloatA,
-                             FloatB,
-                             FloatC,
+                             BaseType,
+                             ADataType,
+                             BDataType,
+                             CDataType,
                              share_a>{}
             .Run(a, b, reg_c);
     }
@@ -84,18 +84,18 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
     static constexpr index_t n_per_thread = 1;
     static constexpr index_t k_per_dpp = 2;
     static constexpr bool share_a = true;
-    using base_type = half_t;
-    template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    using BaseType = half_t;
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
     {
         dpp8::DppInstrRunner<m_per_thread,
                              n_per_thread,
                              k_per_dpp,
-                             base_type,
-                             FloatA,
-                             FloatB,
-                             FloatC,
+                             BaseType,
+                             ADataType,
+                             BDataType,
+                             CDataType,
                              share_a>{}
             .Run(a, b, reg_c);
     }
@@ -114,27 +114,27 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
     static constexpr index_t n_per_thread = 1;
     static constexpr index_t k_per_dpp = 2;
     static constexpr bool share_a = true;
-    using base_type = half_t;
-    template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    using BaseType = half_t;
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
     {
         dpp8::DppInstrRunner<m_per_thread,
                              n_per_thread,
                              k_per_dpp,
-                             base_type,
-                             FloatA,
-                             FloatB,
-                             FloatC,
+                             BaseType,
+                             ADataType,
+                             BDataType,
+                             CDataType,
                              share_a>{}
             .Run(a, b, reg_c);
     }
 };
-template <typename base_type, index_t MPerDpp, index_t NPerDpp>
+template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
 struct DppSelector
 {
-    template <typename base_type_, index_t MPerDpp_, index_t NPerDpp_>
+    template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
     static constexpr auto GetDpp();
     template <>
@@ -155,7 +155,7 @@ struct DppSelector
         return DppInstr::dpp8_f16_32x8x2;
     }
-    static constexpr auto selected_dpp = dpp_type<GetDpp<base_type, MPerDpp, NPerDpp>()>{};
+    static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
     __host__ __device__ constexpr DppSelector()
     {
@@ -200,7 +200,7 @@ struct DppSelector
     static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
 };
-template <typename base_type, index_t MPerDpp, index_t NPerDpp, index_t KPack>
+template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
 struct DppGemm
 {
     static constexpr auto I0 = Number<0>{};
@@ -228,13 +228,14 @@ struct DppGemm
         return MPerDpp * NPerDpp / dpp_instr.wave_size;
     }
-    template <class FloatA, class FloatB, class FloatC>
-    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
+    template <class ADataType, class BDataType, class CDataType>
+    __device__ void
+    Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
     {
-        static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
-                          is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                          is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
-                      "base base_type must be double, float, half, bfloat16, and int8_t!");
+        static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
+                          is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
+                          is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
+                      "base BaseType must be double, float, half, bfloat16, and int8_t!");
         static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
             dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
@@ -305,7 +306,7 @@ struct DppGemm
         return CIndex{m_offset, n_offset};
     }
-    static constexpr auto dpp = DppSelector<base_type, MPerDpp, NPerDpp>{};
+    static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};
     static constexpr auto dpp_instr = dpp.selected_dpp;
......