Commit 3fb903fa authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into gemm_activation

parents a8b539da 2cbb8976
......@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, 16>,
vector_type<FloatAcc, xdlops_gemm.GetRegSizePerXdlops()>,
MRepeat * NRepeat,
true>
c_thread_buf_;
......@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
......
......@@ -552,6 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// main body
index_t k0_block_data_begin = 0;
c_thread_buf.Clear();
if constexpr(HasMainKBlockLoop)
{
do
......@@ -591,6 +593,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// output: register to global memory
{
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
......@@ -603,10 +608,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
make_naive_tensor_descriptor_packed(make_tuple(
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
......
......@@ -507,7 +507,7 @@ struct MfmaSelector
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__host__ __device__ static constexpr void mfma_check()
__host__ __device__ constexpr MfmaSelector()
{
static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
selected_mfma.num_regs_per_blk,
......@@ -533,8 +533,6 @@ struct MfmaSelector
"is_k_reduction wrong!");
}
__host__ __device__ constexpr MfmaSelector() { mfma_check(); }
static constexpr bool IsABroadcast()
{
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
......@@ -621,6 +619,8 @@ struct XdlopsGemm
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
}
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
......
......@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
static constexpr index_t vector_size = GetVectorSize();
__host__ __device__ static constexpr index_t GetNumVectors() { return N; }
__host__ __device__ static constexpr index_t GetNumElements()
{
return GetVectorSize() * GetNumVectors();
}
VecBaseType invalid_element_value_ = VecBaseType{0};
T invalid_vec_value_ = T{0};
......@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
return GetElement(i, true);
}
__host__ __device__ void Clear()
{
static_for<0, GetNumElements(), 1>{}(
[&](auto i) { GetElement(i, true) = invalid_element_value_; });
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
......
......@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 0
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment