Commit 90ec6a19 authored by Jing Zhang's avatar Jing Zhang
Browse files

added 128x128 wavegemm

parent 1d48b521
...@@ -160,7 +160,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -160,7 +160,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
m0, m0,
...@@ -372,14 +372,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -372,14 +372,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
a_thread_buf); a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 // C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
0, 0,
0>(a_thread_buf, b_thread_buf, c_thread_buf); 0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1 // C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
0, 0,
...@@ -395,7 +395,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -395,7 +395,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
a_thread_buf); a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
1, 1,
...@@ -410,7 +410,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -410,7 +410,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
b_thread_buf); b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1 // C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
1, 1,
...@@ -433,14 +433,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -433,14 +433,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
a_thread_buf); a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 // C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
0, 0,
0>(a_thread_buf, b_thread_buf, c_thread_buf); 0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1 // C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
0, 0,
...@@ -448,14 +448,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -448,14 +448,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
}); });
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
1, 1,
0>(a_thread_buf, b_thread_buf, c_thread_buf); 0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1 // C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run2<decltype(a_thread_desc_), xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_), decltype(b_thread_desc_),
decltype(c_thread_desc_), decltype(c_thread_desc_),
1, 1,
......
...@@ -518,7 +518,7 @@ template <mfma_instr instr, ...@@ -518,7 +518,7 @@ template <mfma_instr instr,
index_t NPerXdlops_, index_t NPerXdlops_,
index_t MRepeats_, index_t MRepeats_,
index_t NRepeats_, index_t NRepeats_,
class OutputVecType_> class CType_>
struct xdlops_info struct xdlops_info
{ {
static constexpr auto mfma_type = mfma_info<instr>{}; static constexpr auto mfma_type = mfma_info<instr>{};
...@@ -540,196 +540,74 @@ struct xdlops_info ...@@ -540,196 +540,74 @@ struct xdlops_info
return mfma_type.k_base * (IsKReduction() ? mfma_type.num_input_blks : 1); return mfma_type.k_base * (IsKReduction() ? mfma_type.num_input_blks : 1);
} }
static constexpr auto OutputVecType = OutputVecType_{}; static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
static constexpr auto GetCType() { return CType_{}; }
}; };
template <class data_type, index_t MPerWave, index_t NPerWave, index_t KPerWave> template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPerWave>
struct XdlopsGemm struct XdlopsGemm
{ {
// (row, col) coordinate of the start of a thread's output sub-block within
// the wave-level C tile (produced by GetBeginOfThreadBlk below).
struct MatrixIndex
{
index_t row;
index_t col;
};
// Number of mfma m x n output blocks that tile one MPerXdlops x NPerXdlops
// xdlops output tile.
__device__ static constexpr index_t GetNumBlksPerXdlops()
{
    constexpr index_t tile_elems = MPerXdlops * NPerXdlops;
    constexpr index_t blk_elems  = mfma_type.m * mfma_type.n;
    return tile_elems / blk_elems;
}
// Compile-time sanity checks tying the chosen MPerXdlops/NPerXdlops tile
// sizes to the geometry of the selected mfma instruction (mfma_type).
__host__ __device__ constexpr XdlopsGemm()
{
// Tile sizes must match one of the supported xdlops shapes.
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
NPerXdlops == 64,
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
// Internal consistency of the mfma instruction descriptor itself.
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
// K must step in whole k_base units.
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
}
// Per-lane C accumulator register count for one xdlops tile: the tile's
// MPerXdlops * NPerXdlops elements split evenly across the wave's lanes.
__device__ static constexpr index_t GetRegSizePerXdlops()
{
    constexpr index_t tile_elems = MPerXdlops * NPerXdlops;
    return tile_elems / mfma_type.wave_size;
}
// Accumulate C += A * B over the whole KPerWave range, issuing one mfma
// per KPerXdlops step. A/B fragments are indexed directly by the k step.
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
// Only float, half_t and ushort element types are supported.
static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
is_same<data_type, ushort>::value,
"base data_type must be float, half, ushort!");
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
// One mfma issue per K step, accumulating into p_c_thread.
static_for<0, KPerWave, KPerXdlops>{}([&](auto k_i) {
mfma_type.template run<MPerXdlops, NPerXdlops>(
p_a_wave[Number<k_i>{}], p_b_wave[Number<k_i>{}], p_c_thread);
});
}
// Accumulate one (m0, n0) repeat sub-tile of C. ADesc/BDesc/CDesc are
// compile-time tensor descriptors used to turn (k, m0/n0) coordinates into
// buffer offsets; m0/n0 select the M/N repeat indices.
template <class ADesc,
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
__device__ void Run2(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
// Only float, half_t and ushort element types are supported.
static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
is_same<data_type, ushort>::value,
"base data_type must be float, half, ushort!");
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
// Offsets of this iteration's A/B fragments, resolved at compile time.
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_multi_index(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
// NOTE(review): c_offset is computed but never used below — the whole C
// buffer is passed to the intrinsic. Verify whether indexing by c_offset
// was intended here.
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
mfma_type.template run<MPerXdlops, NPerXdlops>(p_a_wave[Number<a_offset>{}],
p_b_wave[Number<b_offset>{}],
p_c_thread.template AsType<float32_t>());
});
}
// Map flat output-block index i to the (row, col) at which this lane's first
// element of that block lives within the wave's C tile.
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
{
// Split i into which xdlops repeat (xdlops_i) and which block inside it (j).
const index_t xdlops_i = i / GetNumBlksPerXdlops();
const index_t j = i % GetNumBlksPerXdlops();
// Repeat coordinates along M and N.
const index_t m_i = xdlops_i / NRepeats;
const index_t n_i = xdlops_i % NRepeats;
// This lane's position within its wave.
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t col_blk = j % mfma_type.num_output_blks;
index_t row_blk = j / mfma_type.num_output_blks;
// When A is not broadcast, j decomposes in the opposite order
// (row-major vs column-major block layout).
static_if<!IsABroadcast>{}([&](auto) {
col_blk = j / mfma_type.num_output_blks;
row_blk = j % mfma_type.num_output_blks;
});
index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;
return MatrixIndex{row, col};
}
// Intentionally empty — NOTE(review): presumably a no-op hook kept for
// interface compatibility with other gemm variants; confirm against callers.
__device__ void SetZeroXdlopsRegs() const {}
// Intentionally empty — NOTE(review): presumably the no-op counterpart of an
// accumulator-register read-back path in other gemm variants; confirm.
template <class FloatC>
__device__ void ReadXdlopsRegs(FloatC* const __restrict__) const
{
}
template <class data_type_ = data_type,
index_t MPerWave_ = MPerWave, index_t MPerWave_ = MPerWave,
index_t NPerWave_ = NPerWave> index_t NPerWave_ = NPerWave>
static constexpr auto GetXdlopsInfo(); static constexpr auto GetXdlopsInfo();
// fp32, 128x64 per wave: a 64x64 xdlops tile repeated twice along M
// (MRepeats = 2, NRepeats = 1), output held in c_vec32_4_t.
template <>
static constexpr auto GetXdlopsInfo<float, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 2, 1, c_vec32_4_t>{};
}
// fp32, 64x128 per wave: a 64x64 xdlops tile repeated twice along N
// (MRepeats = 1, NRepeats = 2), output held in c_vec32_4_t.
template <>
static constexpr auto GetXdlopsInfo<float, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>() static constexpr auto GetXdlopsInfo<float, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, c_vec32_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, float64_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 32>() static constexpr auto GetXdlopsInfo<float, 64, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, float32_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>() static constexpr auto GetXdlopsInfo<float, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, float32_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 16>() static constexpr auto GetXdlopsInfo<float, 64, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, float16_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>() static constexpr auto GetXdlopsInfo<float, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, float16_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>() static constexpr auto GetXdlopsInfo<float, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, c_vec4_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, float8_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>() static constexpr auto GetXdlopsInfo<float, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, float4_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>() static constexpr auto GetXdlopsInfo<float, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, float16_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>() static constexpr auto GetXdlopsInfo<float, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, float4_t>{};
} }
#if 0
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>() static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{ {
...@@ -861,6 +739,107 @@ struct XdlopsGemm ...@@ -861,6 +739,107 @@ struct XdlopsGemm
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
} }
#endif
// (row, col) coordinate of the start of a thread's output sub-block within
// the wave-level C tile (produced by GetBeginOfThreadBlk below).
struct MatrixIndex
{
index_t row;
index_t col;
};
// How many of the mfma instruction's m x n output blocks make up one
// MPerXdlops x NPerXdlops output tile.
__device__ static constexpr index_t GetNumBlksPerXdlops()
{
    const index_t blk_size = mfma_type.m * mfma_type.n;
    return (MPerXdlops * NPerXdlops) / blk_size;
}
// Compile-time sanity checks tying the chosen MPerXdlops/NPerXdlops tile
// sizes to the geometry of the selected mfma instruction (mfma_type).
__host__ __device__ constexpr XdlopsGemm()
{
// Tile sizes must match one of the supported xdlops shapes.
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
NPerXdlops == 64,
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
// Internal consistency of the mfma instruction descriptor itself.
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
// K must step in whole k_base units.
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
}
// C accumulator registers each lane owns for a single xdlops tile.
__device__ static constexpr index_t GetRegSizePerXdlops()
{
    const index_t lanes = mfma_type.wave_size;
    return MPerXdlops * NPerXdlops / lanes;
}
// Accumulate one (m0, n0) repeat sub-tile of C over the KPerWave range,
// issuing one mfma per KPerXdlops step (callers: C_sub += transpose(A_sub) * B_sub).
//
// ADesc/BDesc/CDesc are compile-time tensor descriptors used to turn
// (k, m0/n0) coordinates into buffer offsets; m0/n0 select the repeat indices.
template <class ADesc,
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
// Only float, half_t and ushort element types are supported.
// (Fixed garbled assert message: was "base base_type must be ...".)
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
is_same<base_type, ushort>::value,
"base_type must be float, half, ushort!");
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
// Offsets of this iteration's A/B fragments and of the C sub-tile,
// all resolved at compile time.
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_multi_index(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
// Stage this sub-tile's accumulators in a local vector, run the mfma
// on it, then copy the updated accumulators back into p_c_thread.
vector_type<base_type, GetXdlopsInfo().GetNumCRegs()> t;
using c_type = decltype(GetXdlopsInfo().GetCType());
t.template AsType<c_type>()(Number<0>{}) =
p_c_thread.template AsType<c_type>()[Number<c_offset>{}];
mfma_type.template run<MPerXdlops, NPerXdlops>(
p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], t);
p_c_thread.template AsType<c_type>()(Number<c_offset>{}) =
t.template AsType<c_type>()[Number<0>{}];
});
}
// Map flat output-block index i to the (row, col) at which this lane's first
// element of that block lives within the wave's C tile.
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
{
// Split i into which xdlops repeat (xdlops_i) and which block inside it (j).
const index_t xdlops_i = i / GetNumBlksPerXdlops();
const index_t j = i % GetNumBlksPerXdlops();
// Repeat coordinates along M and N.
const index_t m_i = xdlops_i / NRepeats;
const index_t n_i = xdlops_i % NRepeats;
// This lane's position within its wave.
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t col_blk = j % mfma_type.num_output_blks;
index_t row_blk = j / mfma_type.num_output_blks;
// When A is not broadcast, j decomposes in the opposite order
// (row-major vs column-major block layout).
static_if<!IsABroadcast>{}([&](auto) {
col_blk = j / mfma_type.num_output_blks;
row_blk = j % mfma_type.num_output_blks;
});
index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;
return MatrixIndex{row, col};
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
...@@ -896,11 +875,6 @@ struct XdlopsGemm ...@@ -896,11 +875,6 @@ struct XdlopsGemm
{ {
return GetNumBlksPerXdlops() * MRepeats * NRepeats; return GetNumBlksPerXdlops() * MRepeats * NRepeats;
} }
__device__ static constexpr auto CreateOutputVecZero()
{
return GetXdlopsInfo().OutputVecType.CreateVecZero();
}
}; };
__host__ __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; } __host__ __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
......
...@@ -243,10 +243,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64> ...@@ -243,10 +243,10 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{ {
reg_c(Number<0>{}) = reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_c(Number<1>{}) = reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[Number<1>{}], 1, 1, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
} }
}; };
...@@ -278,9 +278,11 @@ struct intrin_mfma_f32_32x32x2f32; ...@@ -278,9 +278,11 @@ struct intrin_mfma_f32_32x32x2f32;
template <> template <>
struct intrin_mfma_f32_32x32x2f32<32, 32> struct intrin_mfma_f32_32x32x2f32<32, 32>
{ {
__device__ static void Run(const float& reg_a, const float& reg_b, float16_t& reg_c) template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{ {
reg_c = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, reg_c, 0, 0, 0); reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
......
...@@ -618,6 +618,252 @@ struct vector_type<T, 64> ...@@ -618,6 +618,252 @@ struct vector_type<T, 64>
} }
}; };
// Vector of 128 T elements. A union lets the same storage be viewed as a
// StaticallyIndexedArray of progressively wider ext_vector_type slices
// (1, 2, 4, ..., 128 elements each); AsType<X>() selects the view whose
// slice type is X.
// NOTE(review): reading a union member other than the one last written is
// formally UB in standard C++; this relies on compiler-tolerated type punning.
template <typename T>
struct vector_type<T, 128>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
// Widest view; used for whole-vector construction.
using type = d128_t;
// All members alias the same 128 elements.
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d2_t, 64> d2x64_;
StaticallyIndexedArray<d4_t, 32> d4x32_;
StaticallyIndexedArray<d8_t, 16> d8x16_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d32_t, 4> d32x4_;
StaticallyIndexedArray<d64_t, 2> d64x2_;
StaticallyIndexedArray<d128_t, 1> d128x1_;
} data_;
// Zero-initializes through the d128_ member.
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
// Read-only view of the storage as an array of X-typed slices.
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
// X must be one of the declared slice types.
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
// Mutable view of the storage as an array of X-typed slices.
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
// X must be one of the declared slice types.
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
};
// Vector of 256 T elements. Same union-of-views scheme as vector_type<T, 128>:
// the storage can be read as a StaticallyIndexedArray of any declared
// ext_vector_type slice width via AsType<X>().
// NOTE(review): reading a union member other than the one last written is
// formally UB in standard C++; this relies on compiler-tolerated type punning.
template <typename T>
struct vector_type<T, 256>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d256_t __attribute__((ext_vector_type(256)));
// Widest view; used for whole-vector construction.
using type = d256_t;
// All members alias the same 256 elements.
union
{
d256_t d256_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d2_t, 128> d2x128_;
StaticallyIndexedArray<d4_t, 64> d4x64_;
StaticallyIndexedArray<d8_t, 32> d8x32_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d32_t, 8> d32x8_;
StaticallyIndexedArray<d64_t, 4> d64x4_;
StaticallyIndexedArray<d128_t, 2> d128x2_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
} data_;
// Zero-initializes through the d256_ member.
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
// Read-only view of the storage as an array of X-typed slices.
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
// X must be one of the declared slice types.
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
// Mutable view of the storage as an array of X-typed slices.
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
// X must be one of the declared slice types.
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
};
// fp32 // fp32
using float2_t = typename vector_type<float, 2>::type; using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type; using float4_t = typename vector_type<float, 4>::type;
......
...@@ -102,30 +102,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -102,30 +102,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else #else
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 64; constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 1; constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 1; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment