Commit aaf3d81d authored by Jing Zhang
Browse files

clean code

parent 1d6022b1
...@@ -55,34 +55,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops ...@@ -55,34 +55,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); } __device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }
#if CK_WORKAROUND_SWDEV_241664
template <index_t MRepeats_ = MRepeats, index_t NRepeats_ = NRepeats>
__device__ constexpr auto CreateOutputVecZero() const;
template <>
__device__ constexpr auto CreateOutputVecZero<2, 1>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
template <>
__device__ constexpr auto CreateOutputVecZero<1, 2>() const
{
return c_vec32_2_2_t::CreateVecZero();
}
template <>
__device__ constexpr auto CreateOutputVecZero<1, 1>() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#else
__device__ constexpr auto CreateOutputVecZero() const
{
return XdlopsGemm.GetOutputLayout().CreateOutputVecZero();
}
#endif
__device__ constexpr auto GetNumBlks() const __device__ constexpr auto GetNumBlks() const
{ {
#if CK_WORKAROUND_SWDEV_241664 #if CK_WORKAROUND_SWDEV_241664
......
...@@ -210,7 +210,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -210,7 +210,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
// get zero-initialized output register of vector type // get zero-initialized output register of vector type
// auto c_thread_vec = blockwise_gemm.CreateOutputVecZero(); // auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();
auto c_thread_vec = float_vec128_t{}; constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
// preload data into LDS // preload data into LDS
{ {
...@@ -325,7 +326,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -325,7 +326,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
m_thread_data_on_global % (M2 * M1) / M2, m_thread_data_on_global % (M2 * M1) / M2,
m_thread_data_on_global % M2, m_thread_data_on_global % M2,
n_thread_data_on_global)) n_thread_data_on_global))
.Run(c_thread_vec.At(Number<16>{})[Number<blk_id>{}], p_c_global); .Store(c_thread_vec.At(Number<M0 * M2>{})[Number<blk_id>{}], p_c_global);
}); });
} }
} }
......
...@@ -84,9 +84,9 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -84,9 +84,9 @@ struct ThreadwiseGenericTensorSliceCopy_v5
} }
template <typename DstData, typename SrcData> template <typename DstData, typename SrcData>
__device__ static void load_data(DstData& dst, const SrcData* p_src, index_t src_offset) __device__ static auto load_data(const SrcData* p_src, index_t src_offset)
{ {
dst = *reinterpret_cast<const DstData*>(&p_src[src_offset]); return *reinterpret_cast<const DstData*>(&p_src[src_offset]);
} }
template <typename DstData, typename SrcData> template <typename DstData, typename SrcData>
...@@ -104,9 +104,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -104,9 +104,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
template <typename SrcCoord> template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin) __device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{ {
float r; return load_data<float>(p_src, src_coord_begin.GetOffset());
load_data(r, p_src, src_coord_begin.GetOffset());
return r;
} }
}; };
...@@ -116,9 +114,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -116,9 +114,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
template <typename SrcCoord> template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin) __device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{ {
float2_t r; return load_data<float2_t>(p_src, src_coord_begin.GetOffset());
load_data(r, p_src, src_coord_begin.GetOffset());
return r;
} }
}; };
...@@ -128,9 +124,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -128,9 +124,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
template <typename SrcCoord> template <typename SrcCoord>
__device__ static auto run(const float* p_src, const SrcCoord src_coord_begin) __device__ static auto run(const float* p_src, const SrcCoord src_coord_begin)
{ {
float4_t r; return load_data<float4_t>(p_src, src_coord_begin.GetOffset());
load_data(r, p_src, src_coord_begin.GetOffset());
return r;
} }
}; };
...@@ -237,7 +231,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5 ...@@ -237,7 +231,7 @@ struct ThreadwiseGenericTensorSliceCopy_v5
} }
template <typename SrcData, typename DstData> template <typename SrcData, typename DstData>
__device__ void Run(SrcData src, DstData* p_dst) __device__ void Store(SrcData src, DstData* p_dst)
{ {
constexpr auto vector_access_dim = Number<DstVectorWriteDim>{}; constexpr auto vector_access_dim = Number<DstVectorWriteDim>{};
......
...@@ -15,13 +15,13 @@ enum struct mfma_instr ...@@ -15,13 +15,13 @@ enum struct mfma_instr
mfma_f32_4x4x1xf32, mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction
// fp16 // fp16
mfma_f32_32x32x4f16, mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16, mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16, mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction mfma_f32_16x16x16f16, // k reduction
// bfp16 // bfp16
mfma_f32_32x32x2bf16, mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16, mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16, mfma_f32_4x4x2bf16,
...@@ -535,8 +535,7 @@ template <mfma_instr instr, ...@@ -535,8 +535,7 @@ template <mfma_instr instr,
index_t MPerXdlops_, index_t MPerXdlops_,
index_t NPerXdlops_, index_t NPerXdlops_,
index_t MRepeats_, index_t MRepeats_,
index_t NRepeats_, index_t NRepeats_>
class OutputVecType_>
struct xdlops_info struct xdlops_info
{ {
static constexpr auto mfma_type = mfma_info<instr>{}; static constexpr auto mfma_type = mfma_info<instr>{};
...@@ -552,8 +551,6 @@ struct xdlops_info ...@@ -552,8 +551,6 @@ struct xdlops_info
{ {
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
} }
static constexpr auto OutputVecType = OutputVecType_{};
}; };
template <class data_type, template <class data_type,
...@@ -635,27 +632,59 @@ struct XdlopsGemm_t ...@@ -635,27 +632,59 @@ struct XdlopsGemm_t
} }
} }
} }
})
}).Else([&](auto) { .Else([&](auto) {
static_if<IsABroadcast>{}([&](auto) { static_if<IsABroadcast>{}([&](auto) {
for(index_t m_i = 0; m_i < MRepeats; ++m_i)
for(index_t m_i = 0; m_i < MRepeats; ++m_i)
{
for(index_t n_i = 0; n_i < NRepeats; ++n_i)
{ {
// ABroadcast for(index_t n_i = 0; n_i < NRepeats; ++n_i)
{
// ABroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
index_t b_off = k * N + n * mfma_type.num_threads_blk +
NPerXdlops * n_i;
index_t c_off =
n * mfma_type.num_regs_blk +
b * mfma_type.num_regs_xdlops +
(NRepeats * m_i + n_i) * GetRegSizePerXdlops();
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex = m % mfma_type.group_size +
blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size *
mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread.n[m + c_off] +=
inner_product_with_conversion<float>{}(
p_a_wave[aindex + a_off],
p_b_wave[bindex + b_off]);
}
}
}
}
}
}
})
.Else([&](auto) {
// BBroadcast
for(index_t k = 0; k < K; ++k) for(index_t k = 0; k < K; ++k)
{ {
for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b) for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
{ {
for(index_t n = 0; n < mfma_type.num_input_blks; ++n) for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{ {
index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i; index_t a_off = k * M + n * mfma_type.m;
index_t b_off = index_t b_off = k * N + b * mfma_type.n;
k * N + n * mfma_type.num_threads_blk + NPerXdlops * n_i; index_t c_off =
index_t c_off = n * mfma_type.num_regs_blk + n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
b * mfma_type.num_regs_xdlops +
(NRepeats * m_i + n_i) * GetRegSizePerXdlops();
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m) for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{ {
...@@ -672,37 +701,8 @@ struct XdlopsGemm_t ...@@ -672,37 +701,8 @@ struct XdlopsGemm_t
} }
} }
} }
} });
}
}).Else([&](auto) {
// BBroadcast
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
{
for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
{
index_t a_off = k * M + n * mfma_type.m;
index_t b_off = k * N + b * mfma_type.n;
index_t c_off =
n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
{
index_t aindex =
m % mfma_type.group_size + blk_id * mfma_type.group_size +
m / mfma_type.group_size *
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
}
}); });
});
return p_c_thread; return p_c_thread;
} }
...@@ -745,13 +745,12 @@ struct XdlopsGemm_t ...@@ -745,13 +745,12 @@ struct XdlopsGemm_t
constexpr index_t BStride = K * KRepeats; constexpr index_t BStride = K * KRepeats;
static_if<!IsKReduction>{}([&](auto) { static_if<!IsKReduction>{}([&](auto) {
for(index_t m_i = 0; m_i < MRepeats; ++m_i) for(index_t m_i = 0; m_i < MRepeats; ++m_i)
for(index_t k_i = 0; k_i < K; ++k_i) for(index_t k_i = 0; k_i < K; ++k_i)
a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i]; a[k_i + m_i * K] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];
for(index_t n_i = 0; n_i < NRepeats; ++n_i) for(index_t n_i = 0; n_i < NRepeats; ++n_i)
for(index_t k_i = 0; k_i < K; ++k_i) for(index_t k_i = 0; k_i < K; ++k_i)
b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i]; b[k_i + n_i * K] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];
#if CK_WORKAROUND_SWDEV_229564 #if CK_WORKAROUND_SWDEV_229564
...@@ -765,32 +764,31 @@ struct XdlopsGemm_t ...@@ -765,32 +764,31 @@ struct XdlopsGemm_t
BStride>( BStride>(
&pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread); &pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread);
} }
})
.Else([&](auto) {
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
}).Else([&](auto) { // load into registers
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
const index_t blk_id = laneId / mfma_type.num_threads_blk; {
const index_t blk_td = laneId % mfma_type.num_threads_blk; a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
// load into registers }
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
{
a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
}
#if CK_WORKAROUND_SWDEV_229564 #if CK_WORKAROUND_SWDEV_229564
#pragma unroll #pragma unroll
#endif #endif
for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks) for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
{ {
for(index_t i = 0; i < KRepeats; ++i) for(index_t i = 0; i < KRepeats; ++i)
p_c_thread = mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>( p_c_thread =
&pa[(k_i * KRepeats + i) * mfma_type.k_base], mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
&pb[(k_i * KRepeats + i) * mfma_type.k_base], &pa[(k_i * KRepeats + i) * mfma_type.k_base],
p_c_thread); &pb[(k_i * KRepeats + i) * mfma_type.k_base],
} p_c_thread);
}
}); });
#endif #endif
return p_c_thread; return p_c_thread;
...@@ -837,199 +835,199 @@ struct XdlopsGemm_t ...@@ -837,199 +835,199 @@ struct XdlopsGemm_t
template <> template <>
static constexpr auto GetXdlopsInfo<float, 128, 64>() static constexpr auto GetXdlopsInfo<float, 128, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1, c_vec32_4_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 128>() static constexpr auto GetXdlopsInfo<float, 64, 128>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2, c_vec32_4_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>() static constexpr auto GetXdlopsInfo<float, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, c_vec32_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 32>() static constexpr auto GetXdlopsInfo<float, 64, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>() static constexpr auto GetXdlopsInfo<float, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 16>() static constexpr auto GetXdlopsInfo<float, 64, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>() static constexpr auto GetXdlopsInfo<float, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>() static constexpr auto GetXdlopsInfo<float, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, c_vec4_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>() static constexpr auto GetXdlopsInfo<float, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>() static constexpr auto GetXdlopsInfo<float, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>() static constexpr auto GetXdlopsInfo<float, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>() static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1, c_vec32_4_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 128>() static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2, c_vec32_4_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>() static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1, c_vec32_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>() static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>() static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 16>() static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>() static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>() static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1, c_vec4_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>() static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>() static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>() static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>() static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>() static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>() static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>() static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>() static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>() static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>() static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>() static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 4, 64>() static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 32>() static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 16>() static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1>{};
} }
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
...@@ -1055,11 +1053,6 @@ struct XdlopsGemm_t ...@@ -1055,11 +1053,6 @@ struct XdlopsGemm_t
{ {
return GetNumBlksPerXdlops() * MRepeats * NRepeats; return GetNumBlksPerXdlops() * MRepeats * NRepeats;
} }
__device__ static constexpr auto CreateOutputVecZero()
{
return GetXdlopsInfo().OutputVecType.CreateVecZero();
}
}; };
__device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; } __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
......
...@@ -95,10 +95,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride> ...@@ -95,10 +95,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
{ {
__device__ static float_vec64_t run(const float* reg_a, const float* reg_b, float_vec64_t reg_c) __device__ static float_vec64_t run(const float* reg_a, const float* reg_b, float_vec64_t reg_c)
{ {
reg_c.At(Number<32>{})(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c.v32(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<0>{}], 1, 0, 0); reg_a[0], reg_b[0], reg_c.v32[Number<0>{}], 1, 0, 0);
reg_c.At(Number<32>{})(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c.v32(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a[0], reg_b[0], reg_c.At(Number<32>{})[Number<1>{}], 1, 1, 0); reg_a[0], reg_b[0], reg_c.v32[Number<1>{}], 1, 1, 0);
return reg_c; return reg_c;
} }
}; };
......
...@@ -186,7 +186,8 @@ union float_vec32_t ...@@ -186,7 +186,8 @@ union float_vec32_t
union float_vec64_t union float_vec64_t
{ {
StaticallyIndexedArray<float, 64> s1; StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float32_t, 2> s32; StaticallyIndexedArray<float_vec32_t, 2> s32;
StaticallyIndexedArray<float32_t, 2> v32;
StaticallyIndexedArray<float64_t, 1> s64; StaticallyIndexedArray<float64_t, 1> s64;
__host__ __device__ constexpr float_vec64_t() {} __host__ __device__ constexpr float_vec64_t() {}
...@@ -210,7 +211,7 @@ union float_vec128_t ...@@ -210,7 +211,7 @@ union float_vec128_t
{ {
StaticallyIndexedArray<float, 64> s1; StaticallyIndexedArray<float, 64> s1;
StaticallyIndexedArray<float_vec16_t, 8> s16; StaticallyIndexedArray<float_vec16_t, 8> s16;
StaticallyIndexedArray<float32_t, 4> s32; StaticallyIndexedArray<float_vec32_t, 4> s32;
StaticallyIndexedArray<float_vec64_t, 2> s64; StaticallyIndexedArray<float_vec64_t, 2> s64;
StaticallyIndexedArray<float128_t, 1> s128; StaticallyIndexedArray<float128_t, 1> s128;
__host__ __device__ constexpr float_vec128_t() {} __host__ __device__ constexpr float_vec128_t() {}
...@@ -264,6 +265,18 @@ constexpr auto GetRegBuffer<float, 16>() ...@@ -264,6 +265,18 @@ constexpr auto GetRegBuffer<float, 16>()
return float_vec16_t{}; return float_vec16_t{};
} }
template <>
constexpr auto GetRegBuffer<float, 64>()
{
return float_vec64_t{};
}
template <>
constexpr auto GetRegBuffer<float, 128>()
{
return float_vec128_t{};
}
struct c_vec32_4_t struct c_vec32_4_t
{ {
union VecType union VecType
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment