Commit aaf3d81d authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent 1d6022b1
......@@ -55,34 +55,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
// Forwards the output-layout descriptor of the underlying XdlopsGemm member.
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }
// Returns a zero-initialized output accumulator vector for this blockwise GEMM.
// Under the SWDEV-241664 workaround the accumulator type is picked by explicit
// specialization on (MRepeats, NRepeats); otherwise it is taken from the xdlops
// output layout. NOTE(review): the exact compiler issue worked around here is
// only identified by the ticket number — not visible in this file.
#if CK_WORKAROUND_SWDEV_241664
// Primary template: declared only; every supported (MRepeats, NRepeats) pair
// has an explicit specialization below.
template <index_t MRepeats_ = MRepeats, index_t NRepeats_ = NRepeats>
__device__ constexpr auto CreateOutputVecZero() const;
// (MRepeats, NRepeats) == (2, 1)
template <>
__device__ constexpr auto CreateOutputVecZero<2, 1>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
// (MRepeats, NRepeats) == (1, 2)
template <>
__device__ constexpr auto CreateOutputVecZero<1, 2>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
// (MRepeats, NRepeats) == (1, 1): defer to the xdlops output layout's own type.
template <>
__device__ constexpr auto CreateOutputVecZero<1, 1>() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#else
// No workaround needed: the accumulator type comes from the output layout.
__device__ constexpr auto CreateOutputVecZero() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#endif
__device__ constexpr auto GetNumBlks() const
{
#if CK_WORKAROUND_SWDEV_241664
......
......@@ -210,7 +210,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
// get zero-initialized output register of vector type
// auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
auto c_thread_vec = float_vec128_t{};
constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
// preload data into LDS
{
......@@ -325,7 +326,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2,
n_thread_data_on_global))
.Run(c_thread_vec.At(Number<16>{})[Number<blk_id>{}], p_c_global);
.Store(c_thread_vec.At(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
});
}
}
......
......@@ -84,9 +84,9 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}
// Loads one DstData-sized value from p_src at element offset src_offset,
// reinterpreting the SrcData storage as DstData (e.g. reading a float2_t /
// float4_t vector out of a float buffer). Callers select DstData explicitly:
// load_data<float4_t>(p, off).
// NOTE(review): assumes &p_src[src_offset] meets DstData's alignment
// requirement — confirm at call sites.
// Fix: the scraped block fused two versions of this function (an out-parameter
// `void load_data(DstData&, ...)` and the auto-return form) into one invalid
// definition; resolved to the auto-return form that the visible call sites
// (`load_data<float>(...)` etc.) use.
template <typename DstData, typename SrcData>
__device__ static auto load_data(const SrcData* p_src, index_t src_offset)
{
    return *reinterpret_cast<const DstData*>(&p_src[src_offset]);
}
template <typename DstData, typename SrcData>
......@@ -104,9 +104,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// Scalar (1-float) load: fetch a single float at the coordinate's offset.
// Fix: removed the stale 3-line out-parameter body (`float r; load_data(r, ...);
// return r;`) left fused with the current single-return form by the diff.
template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
    return load_data<float>(p_src, src_coord_begin.GetOffset());
}
};
......@@ -116,9 +114,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// 2-wide vector load: fetch a float2_t at the coordinate's offset.
// Fix: removed the stale out-parameter body fused with the current
// single-return form by the diff.
template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
    return load_data<float2_t>(p_src, src_coord_begin.GetOffset());
}
};
......@@ -128,9 +124,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
// 4-wide vector load: fetch a float4_t at the coordinate's offset.
// Fix: removed the stale out-parameter body fused with the current
// single-return form by the diff.
template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{
    return load_data<float4_t>(p_src, src_coord_begin.GetOffset());
}
};
......@@ -237,7 +231,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
}
template <typename SrcData, typename DstData>
__device__ void Run(SrcData src, DstData* p_dst)
__device__ void Store(SrcData src, DstData* p_dst)
{
constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
......
......@@ -535,8 +535,7 @@ template <mfma_instr instr,
index_t MPerXdlops_,
index_t NPerXdlops_,
index_t MRepeats_,
index_t NRepeats_,
class OutputVecType_>
index_t NRepeats_>
struct xdlops_info
{
static constexpr auto mfma_type = mfma_info<instr>{};
......@@ -552,8 +551,6 @@ struct xdlops_info
{
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
}
static constexpr auto OutputVecType = OutputVecType_{};
};
template <class data_type,
......@@ -635,10 +632,9 @@ struct XdlopsGemm_t
}
}
}
}).Else([&](auto) {
})
.Else([&](auto) {
static_if<IsABroadcast>{}([&](auto) {
for(index_t m_i = 0; m_i < MRepeats; ++m_i)
{
for(index_t n_i = 0; n_i < NRepeats; ++n_i)
......@@ -651,31 +647,33 @@ struct XdlopsGemm_t
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
index_t b_off =
k * N + n * mfma_type.num_threads_blk + NPerXdlops * n_i;
index_t c_off = n * mfma_type.num_regs_blk +
index_t b_off = k * N + n * mfma_type.num_threads_blk +
NPerXdlops * n_i;
index_t c_off =
n * mfma_type.num_regs_blk +
b * mfma_type.num_regs_xdlops +
(NRepeats * m_i + n_i) * GetRegSizePerXdlops();
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size +
index_t aindex = m % mfma_type.group_size +
blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
(mfma_type.group_size *
mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread.n[m + c_off] +=
inner_product_with_conversion<float>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
p_a_wave[aindex + a_off],
p_b_wave[bindex + b_off]);
}
}
}
}
}
}
}).Else([&](auto) {
})
.Else([&](auto) {
// BBroadcast
for(index_t k = 0; k < K; ++k)
{
......@@ -691,11 +689,13 @@ struct XdlopsGemm_t
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size + blk_id * mfma_type.group_size +
m % mfma_type.group_size +
blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
p_c_thread.n[m + c_off] +=
inner_product_with_conversion<float>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
......@@ -745,7 +745,6 @@ struct XdlopsGemm_t
constexpr index_t BStride = K * KRepeats;
static_if<!IsKReduction>{}([&](auto) {
for(index_t m_i = 0; m_i < MRepeats; ++m_i)
for(index_t k_i = 0; k_i < K; ++k_i)
a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];
......@@ -765,9 +764,8 @@ struct XdlopsGemm_t
BStride>(
&pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread);
}
}).Else([&](auto) {
})
.Else([&](auto) {
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
......@@ -784,12 +782,12 @@ struct XdlopsGemm_t
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
{
for(index_t i = 0; i < KRepeats; ++i)
p_c_thread = mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
p_c_thread =
mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
&pa[(k_i * KRepeats + i) * mfma_type.k_base],
&pb[(k_i * KRepeats + i) * mfma_type.k_base],
p_c_thread);
}
});
#endif
......@@ -837,199 +835,199 @@ struct XdlopsGemm_t
// --- GetXdlopsInfo<float, MPerWave, NPerWave> specializations ----------------
// Each maps a per-wave output tile to its mfma instruction plus
// (MPerXdlops, NPerXdlops, MRepeats, NRepeats).
// Fix: every body carried two consecutive return statements — a stale one
// passing the removed 6th template argument (OutputVecType, e.g. c_vec32_4_t)
// to the now 5-parameter xdlops_info, followed by the current 5-argument form.
// The ill-formed stale returns are dropped; the kept arguments match the
// current form verbatim.
template <>
static constexpr auto GetXdlopsInfo<float, 128, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 128>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 32>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 16>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1>{};
}
// --- GetXdlopsInfo<half_t, MPerWave, NPerWave> specializations ---------------
// Same tile -> mfma mapping as the float group, using the f16 mfma variants.
// Fix: dropped the stale duplicate returns that still passed the removed
// OutputVecType template argument (see the float group above).
template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1>{};
}
// --- GetXdlopsInfo<ushort, MPerWave, NPerWave> specializations ---------------
// bfloat16 (carried as ushort) tile -> mfma mapping, using the bf16 variants.
// Fix: dropped the stale duplicate returns that still passed the removed
// OutputVecType template argument (see the float group above).
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
{
    return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
{
    return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
{
    return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1>{};
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
......@@ -1055,11 +1053,6 @@ struct XdlopsGemm_t
{
return GetNumBlksPerXdlops() * MRepeats * NRepeats;
}
// Returns a zero-filled output accumulator whose vector type is the
// OutputVecType recorded in the selected xdlops_info.
// NOTE(review): the commit shown here removes OutputVecType from xdlops_info —
// confirm this member still compiles against the current xdlops_info.
__device__ static constexpr auto CreateOutputVecZero()
{
return GetXdlopsInfo().OutputVecType.CreateVecZero();
}
};
// Returns a default-constructed OutputLayout descriptor for this configuration.
__device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
......
......@@ -95,10 +95,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{
// Issues two mfma_f32_32x32x1f32 operations, accumulating into the two
// 32-float halves of reg_c addressed through its v32 union view
// (cbsz=1 with abid 0 and 1 selecting the broadcast block).
// Fix: the scraped block contained four accumulate statements — a stale pair
// using the removed `.At(Number<32>{})` accessor plus the current `.v32` pair —
// which would have issued each mfma twice. Only the `.v32` pair is kept.
__device__ static float_vec64_t run(const float* reg_a, const float* reg_b, float_vec64_t reg_c)
{
    reg_c.v32(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
        reg_a[0], reg_b[0], reg_c.v32[Number<0>{}], 1, 0, 0);
    reg_c.v32(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
        reg_a[0], reg_b[0], reg_c.v32[Number<1>{}], 1, 1, 0);
    return reg_c;
}
};
......
......@@ -186,7 +186,8 @@ union float_vec32_t
union float_vec64_t
{
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float32_t, 2> s32;
StaticallyIndexedArray<float_vec32_t, 2> s32;
StaticallyIndexedArray<float32_t, 2> v32;
StaticallyIndexedArray<float64_t, 1> s64;
__host__ __device__ constexpr float_vec64_t() {}
......@@ -210,7 +211,7 @@ union float_vec128_t
{
StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float_vec16_t, 8> s16;
StaticallyIndexedArray<float32_t, 4> s32;
StaticallyIndexedArray<float_vec32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128;
__host__ __device__ constexpr float_vec128_t() {}
......@@ -264,6 +265,18 @@ constexpr auto GetRegBuffer<float, 16>()
return float_vec16_t{};
}
// Register-buffer factory: returns a value-initialized 64-element float
// vector union to hold per-thread accumulator registers.
template <>
constexpr auto GetRegBuffer<float, 64>()
{
return float_vec64_t{};
}
// Register-buffer factory: 128-element float variant.
template <>
constexpr auto GetRegBuffer<float, 128>()
{
return float_vec128_t{};
}
struct c_vec32_4_t
{
union VecType
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment