Commit bf445c31 authored by Bartlomiej Wroblewski

Review: Change names from FloatX to XDataType

parent 0ff1d1f8
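
The change is a mechanical rename of the element-type template parameters used by the DPP GEMM code: FloatAB/FloatAcc/FloatC become ABDataType/AccDataType/CDataType, FloatA/FloatB become ADataType/BDataType, and base_type becomes BaseType. A minimal, self-contained C++ sketch of the naming pattern (the real CK templates take many more parameters; GemmOldNames/GemmNewNames are illustration-only names, not CK declarations):

    // Old convention (before this commit): element types named Float*.
    template <typename FloatAB, typename FloatAcc, typename FloatC>
    struct GemmOldNames
    {
    };

    // New convention (after this commit): element types named *DataType.
    template <typename ABDataType, typename AccDataType, typename CDataType>
    struct GemmNewNames
    {
    };

Only the parameter names change; their roles (A/B input type, accumulator type, C output type) stay the same.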
@@ -20,8 +20,8 @@ namespace ck {
* `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
*/
template <index_t BlockSize,
- typename FloatAB,
- typename FloatAcc,
+ typename ABDataType,
+ typename AccDataType,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerDpp,
@@ -50,7 +50,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
- static constexpr auto dpp_gemm = DppGemm<FloatAB, MPerDpp, NPerDpp, KPack>{};
+ static constexpr auto dpp_gemm = DppGemm<ABDataType, MPerDpp, NPerDpp, KPack>{};
static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
@@ -58,7 +58,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
- FloatAcc,
+ AccDataType,
MRepeat * NRepeat,
dpp_gemm.GetRegSizePerDpp(),
true>
@@ -260,9 +260,9 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
- auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+ auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
a_thread_desc_.GetElementSpaceSize());
- auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+ auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -284,17 +284,18 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
- vector_type<FloatAB, KPack> a_thread_vec;
- vector_type<FloatAB, KPack> b_thread_vec;
+ vector_type<ABDataType, KPack> a_thread_vec;
+ vector_type<ABDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
- a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
+ a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
- b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
+ b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
- using dpp_input_type = typename vector_type<FloatAB, dpp_gemm.K1PerDpp>::type;
+ using dpp_input_type =
+ typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -320,8 +321,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));
- using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
- FloatAB,
+ using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
+ ABDataType,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
@@ -330,8 +331,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
A_K1,
A_K1>;
- using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
- FloatAB,
+ using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
+ ABDataType,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
@@ -51,9 +51,9 @@ __global__ void
}
template <index_t BlockSize,
- typename FloatAB,
- typename FloatAcc,
- typename FloatC,
+ typename ABDataType,
+ typename AccDataType,
+ typename CDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename ALayout,
typename BLayout,
@@ -172,9 +172,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
// Argument
struct Argument : public Problem, public tensor_operation::device::BaseArgument
{
- __host__ Argument(const FloatAB* p_a_grid_,
- const FloatAB* p_b_grid_,
- FloatC* p_c_grid_,
+ __host__ Argument(const ABDataType* p_a_grid_,
+ const ABDataType* p_b_grid_,
+ CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
@@ -188,9 +188,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
{
}
- const FloatAB* p_a_grid;
- const FloatAB* p_b_grid;
- FloatC* p_c_grid;
+ const ABDataType* p_a_grid;
+ const ABDataType* p_b_grid;
+ CDataType* p_c_grid;
};
using GridwiseGemmPipe = remove_cvref_t<
@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
- return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
+ return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
}
__host__ static constexpr bool CheckValidity(const Problem& problem)
@@ -347,8 +347,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
using BlockwiseGemm =
BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2<BlockSize,
- FloatAB,
- FloatAcc,
+ ABDataType,
+ AccDataType,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerDpp,
@@ -430,9 +430,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N>
- __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
- const FloatAB* __restrict__ p_b_grid,
- FloatC* __restrict__ p_c_grid,
+ __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
+ const ABDataType* __restrict__ p_b_grid,
+ CDataType* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
@@ -488,8 +488,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
- FloatAB,
- FloatAB,
+ ABDataType,
+ ABDataType,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
@@ -518,8 +518,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
- FloatAB,
- FloatAB,
+ ABDataType,
+ ABDataType,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
@@ -548,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
// register
auto blockwise_gemm =
BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2<BlockSize,
- FloatAB,
- FloatAcc,
+ ABDataType,
+ AccDataType,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerDpp,
@@ -565,10 +565,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+ static_cast<ABDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+ static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
@@ -642,8 +642,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
make_multi_index(n_thread_data_on_grid));
auto c_thread_copy =
- ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
- FloatC,
+ ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+ CDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
CElementwiseOperation,
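
The gridwise hunks above all route the renamed ABDataType through the same shared-memory plumbing: the LDS size is the aligned A-tile and B-tile element counts times sizeof(ABDataType), and Run() carves p_shared into the A block buffer followed by the B block buffer. A small self-contained sketch of that layout arithmetic; the element counts and the uint16_t stand-in for half_t are assumptions for illustration, not values taken from this diff:

    #include <cstddef>
    #include <cstdint>

    using ABDataType = std::uint16_t; // stand-in for a 2-byte element type such as half_t

    // Hypothetical aligned element counts; in CK they come from the block descriptors.
    constexpr std::size_t a_block_space_size_aligned = 2048;
    constexpr std::size_t b_block_space_size_aligned = 2048;

    // Mirrors the shared-memory size computation above: both tiles are stored as ABDataType.
    constexpr std::size_t shared_mem_bytes =
        (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);

    // Mirrors the carving in Run(): the A tile sits at the start of LDS, the B tile right after.
    inline void carve_lds(void* p_shared, ABDataType*& a_block, ABDataType*& b_block)
    {
        a_block = static_cast<ABDataType*>(p_shared);
        b_block = static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned;
    }
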
@@ -54,18 +54,18 @@ struct dpp_type<DppInstr::dpp8_f16_32x8x2>
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
- using base_type = half_t;
+ using BaseType = half_t;
- template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
- __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+ __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppInstrRunner<m_per_thread,
n_per_thread,
k_per_dpp,
- base_type,
- FloatA,
- FloatB,
- FloatC,
+ BaseType,
+ ADataType,
+ BDataType,
+ CDataType,
share_a>{}
.Run(a, b, reg_c);
}
@@ -84,18 +84,18 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
- using base_type = half_t;
+ using BaseType = half_t;
- template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
- __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+ __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppInstrRunner<m_per_thread,
n_per_thread,
k_per_dpp,
- base_type,
- FloatA,
- FloatB,
- FloatC,
+ BaseType,
+ ADataType,
+ BDataType,
+ CDataType,
share_a>{}
.Run(a, b, reg_c);
}
@@ -114,27 +114,27 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
static constexpr index_t n_per_thread = 1;
static constexpr index_t k_per_dpp = 2;
static constexpr bool share_a = true;
- using base_type = half_t;
+ using BaseType = half_t;
- template <index_t MPerDpp, index_t NPerDpp, class FloatA, class FloatB, class FloatC>
- __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+ template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+ __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
{
dpp8::DppInstrRunner<m_per_thread,
n_per_thread,
k_per_dpp,
- base_type,
- FloatA,
- FloatB,
- FloatC,
+ BaseType,
+ ADataType,
+ BDataType,
+ CDataType,
share_a>{}
.Run(a, b, reg_c);
}
};
- template <typename base_type, index_t MPerDpp, index_t NPerDpp>
+ template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
- template <typename base_type_, index_t MPerDpp_, index_t NPerDpp_>
+ template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
static constexpr auto GetDpp();
template <>
@@ -155,7 +155,7 @@ struct DppSelector
return DppInstr::dpp8_f16_32x8x2;
}
- static constexpr auto selected_dpp = dpp_type<GetDpp<base_type, MPerDpp, NPerDpp>()>{};
+ static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};
__host__ __device__ constexpr DppSelector()
{
@@ -200,7 +200,7 @@ struct DppSelector
static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
};
- template <typename base_type, index_t MPerDpp, index_t NPerDpp, index_t KPack>
+ template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
struct DppGemm
{
static constexpr auto I0 = Number<0>{};
@@ -228,13 +228,14 @@ struct DppGemm
return MPerDpp * NPerDpp / dpp_instr.wave_size;
}
- template <class FloatA, class FloatB, class FloatC>
- __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
+ template <class ADataType, class BDataType, class CDataType>
+ __device__ void
+ Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
{
- static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
- is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
- is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
- "base base_type must be double, float, half, bfloat16, and int8_t!");
+ static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
+ is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
+ is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
+ "BaseType must be double, float, half, bfloat16, int8_t, or f8_t!");
static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
@@ -305,7 +306,7 @@ struct DppGemm
return CIndex{m_offset, n_offset};
}
- static constexpr auto dpp = DppSelector<base_type, MPerDpp, NPerDpp>{};
+ static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};
static constexpr auto dpp_instr = dpp.selected_dpp;
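
Finally, a hedged example of how the renamed BaseType parameter of DppGemm is supplied at a call site. The concrete arguments (half_t, a 16x16 tile, KPack = 2) are assumptions chosen to line up with the dpp8_f16_16x16x2 specialization above, and the snippet presumes the DPP GEMM header is already included; it is not part of this diff:

    // Hypothetical instantiation: DppGemm<BaseType, MPerDpp, NPerDpp, KPack>,
    // where the first parameter was named base_type before this commit.
    using DppGemmF16 = ck::DppGemm<ck::half_t,
                                   /*MPerDpp=*/16,
                                   /*NPerDpp=*/16,
                                   /*KPack=*/2>;

    // Run() loops KPack / k_per_dpp times, so KPack is chosen here as a multiple of the
    // selected instruction's k_per_dpp (2 for the f16 DPP instructions shown above).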