Commit 1d48b521 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent c0ffe379
...@@ -24,6 +24,8 @@ template <index_t BlockSize, ...@@ -24,6 +24,8 @@ template <index_t BlockSize,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M, typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -99,6 +101,8 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global, ...@@ -99,6 +101,8 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
MPerWave, MPerWave,
NPerWave, NPerWave,
KPerWave, KPerWave,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
......
...@@ -125,8 +125,203 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -125,8 +125,203 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent"); "wrong! K dimension not consistent");
// static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M"); static_assert(BlockSize == MWaves * NWaves * WaveSize,
// static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N"); "BlockSize != MWaves * NWaves * WaveSize\n");
}
// Runs one blockwise GEMM pass over the block tile: walks the K dimension in
// KPerWave-sized steps, stages this thread's A and B slices from the block
// buffers into VGPR-resident thread buffers, then issues one xdlops (MFMA)
// call per (m-repeat, n-repeat) pair. Run2 presumably accumulates into
// c_thread_buf at the (m0, n0) offset -- confirm in XdlopsGemm::Run2.
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
// thread-private staging buffers in VGPRs, sized by the thread descriptors
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
// K extent of the block tile is dimension 0 of the A block descriptor
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
// read A: copy the [k, k + KPerWave) slice of this thread's M fragment
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// read B: copy the [k, k + KPerWave) slice of this thread's N fragment
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// one xdlops GEMM per (m0, n0) repeat of this wave's output tile
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run2<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
m0,
n0>(a_thread_buf, b_thread_buf, c_thread_buf);
});
});
});
}
private:
// A[K, M]: per-thread staging descriptor, packed shape [KPerWave, MRepeat, 1]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<MRepeat>{}, I1));
// B[K, N]: per-thread staging descriptor, packed shape [KPerWave, NRepeat, 1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerWave>{}, Number<NRepeat>{}, I1));
// C accumulator descriptor: one slot per (m-repeat, n-repeat) pair
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
// Block->thread copy of the A slice; slice lengths [KPerWave, MRepeat, 1].
// NOTE(review): the trailing args (2, 1, 1) look like vector-dim / vector-width
// tuning -- confirm against ThreadwiseDynamicTensorSliceTransfer_v4's parameters.
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerWave, MRepeat, 1>,
Sequence<0, 1, 2>,
2,
1,
1>;
// Block->thread copy of the B slice; slice lengths [KPerWave, NRepeat, 1].
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerWave, NRepeat, 1>,
Sequence<0, 1, 2>,
2,
1,
1>;
// Copy engines, constructed with each thread's origin inside the block tile.
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
// Blockwise GEMM built on AMD xdlops (MFMA), 2x2-pipeline variant.
// A/B block descriptors are 3-d; dims 1 and 2 factor M (resp. N) into a
// repeat count and a per-block span.
template <index_t BlockSize,
typename FloatA,
typename FloatB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
// 2-d (row, col) coordinate into the C tile
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPerWave>{};
// wavefront size; 64 on the AMD GCN/CDNA targets this MFMA code assumes
static constexpr index_t WaveSize = 64;
// M factored as M0 (repeat) x M1 (per-block span); likewise N as N0 x N1
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// number of waves tiling the block in each direction
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
// Output layout descriptor of the underlying xdlops primitive.
__device__ constexpr auto GetOutputLayout() const
{
    return xdlops_gemm.GetOutputLayout();
}
// Number of output blocks in the xdlops output layout.
__device__ constexpr auto GetNumBlks() const
{
    const auto output_layout = xdlops_gemm.GetOutputLayout();
    return output_layout.GetNumBlks();
}
// Size of one output block in the xdlops output layout.
__device__ constexpr auto GetBlkSize() const
{
    const auto output_layout = xdlops_gemm.GetOutputLayout();
    return output_layout.GetBlkSize();
}
// Per-thread origin (k, 0, m) of the A fragment inside the block tile.
// With the K-reduction layout, lanes are additionally split across K via
// GetBlkId/GetBlkTd; otherwise each lane owns one M row starting at k = 0.
// Fix: dropped the unused local `waveId_n` (computed but never read here,
// triggering -Wunused-variable); only the wave's M coordinate matters for A.
__device__ static auto CalculateAThreadOriginDataIndex()
{
    const index_t thread_id = get_thread_local_1d_id();
    const index_t waveId    = thread_id / WaveSize;
    const index_t laneId    = thread_id % WaveSize;
    // row of this wave in the MWaves x NWaves wave grid
    const index_t waveId_m = waveId / NWaves;
    if constexpr(xdlops_gemm.IsKReduction)
    {
        const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
        const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
        return make_tuple(k_offset, 0, m_offset);
    }
    else
    {
        const index_t m_offset = waveId_m * MPerWave + laneId;
        const index_t k_offset = 0;
        return make_tuple(k_offset, 0, m_offset);
    }
}
// Per-thread origin (k, 0, n) of the B fragment inside the block tile.
// With the K-reduction layout, lanes are additionally split across K via
// GetBlkId/GetBlkTd; otherwise each lane owns one N column starting at k = 0.
// Fix: dropped the unused local `waveId_m` (computed but never read here,
// triggering -Wunused-variable); only the wave's N coordinate matters for B.
__device__ static auto CalculateBThreadOriginDataIndex()
{
    const index_t thread_id = get_thread_local_1d_id();
    const index_t waveId    = thread_id / WaveSize;
    const index_t laneId    = thread_id % WaveSize;
    // column of this wave in the MWaves x NWaves wave grid
    const index_t waveId_n = waveId % NWaves;
    if constexpr(xdlops_gemm.IsKReduction)
    {
        const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
        const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
        return make_tuple(k_offset, 0, n_offset);
    }
    else
    {
        const index_t n_offset = waveId_n * NPerWave + laneId;
        const index_t k_offset = 0;
        return make_tuple(k_offset, 0, n_offset);
    }
}
// (row, col) origin of output block `blk_i` for repeat indices (m0, n0):
// repeat offset + wave offset within the block + this lane's offset
// within the xdlops output block.
__device__ static CIndex
CalculateCThreadOriginDataIndex(const index_t m0, const index_t n0, const index_t blk_i)
{
    const index_t wave_id  = get_thread_local_1d_id() / WaveSize;
    const index_t wave_row = wave_id / NWaves;
    const index_t wave_col = wave_id % NWaves;
    const auto blk_origin = xdlops_gemm.GetBeginOfThreadBlk(blk_i);
    return CIndex{m0 * M1 + wave_row * MPerWave + blk_origin.row,
                  n0 * N1 + wave_col * NPerWave + blk_origin.col};
}
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "BlockSize != MWaves * NWaves * WaveSize\n");
......
...@@ -111,6 +111,8 @@ template <index_t BlockSize, ...@@ -111,6 +111,8 @@ template <index_t BlockSize,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M, typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -278,8 +280,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -278,8 +280,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0, NPerBlock % (NPerWave * NRepeat) == 0,
......
...@@ -620,10 +620,9 @@ struct XdlopsGemm ...@@ -620,10 +620,9 @@ struct XdlopsGemm
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0)); constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0)); constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
mfma_type.template run<MPerXdlops, NPerXdlops>( mfma_type.template run<MPerXdlops, NPerXdlops>(p_a_wave[Number<a_offset>{}],
p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}],
p_b_wave[Number<b_offset>{}], p_c_thread.template AsType<float32_t>());
p_c_thread.template AsType<float16_t>()(Number<c_offset>{}));
}); });
} }
......
...@@ -240,13 +240,13 @@ struct intrin_mfma_f32_32x32x1f32; ...@@ -240,13 +240,13 @@ struct intrin_mfma_f32_32x32x1f32;
template <> template <>
struct intrin_mfma_f32_32x32x1f32<64, 64> struct intrin_mfma_f32_32x32x1f32<64, 64>
{ {
__device__ static void template <class FloatA, class FloatB, class FloatC>
Run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c) __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c(Number<0>{}) =
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[Number<0>{}], 1, 0, 0);
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c(Number<1>{}) =
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0); llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[Number<1>{}], 1, 1, 0);
} }
}; };
......
...@@ -108,9 +108,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -108,9 +108,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 2; constexpr index_t GemmKPerWave = 1;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
...@@ -159,6 +162,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -159,6 +162,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmMPerWave, GemmMPerWave,
GemmNPerWave, GemmNPerWave,
GemmKPerWave, GemmKPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>, Sequence<1, 0>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment