Commit 3fb903fa authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into gemm_activation

parents a8b539da 2cbb8976
...@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr, StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, 16>, vector_type<FloatAcc, xdlops_gemm.GetRegSizePerXdlops()>,
MRepeat * NRepeat, MRepeat * NRepeat,
true> true>
c_thread_buf_; c_thread_buf_;
...@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
} }
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
......
...@@ -552,6 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -552,6 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// main body // main body
index_t k0_block_data_begin = 0; index_t k0_block_data_begin = 0;
c_thread_buf.Clear();
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
do do
...@@ -591,6 +593,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -591,6 +593,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// output: register to global memory // output: register to global memory
{ {
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
...@@ -603,10 +608,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -603,10 +608,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
make_naive_tensor_descriptor_packed(make_tuple(
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
......
...@@ -507,7 +507,7 @@ struct MfmaSelector ...@@ -507,7 +507,7 @@ struct MfmaSelector
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__host__ __device__ static constexpr void mfma_check() __host__ __device__ constexpr MfmaSelector()
{ {
static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
selected_mfma.num_regs_per_blk, selected_mfma.num_regs_per_blk,
...@@ -533,8 +533,6 @@ struct MfmaSelector ...@@ -533,8 +533,6 @@ struct MfmaSelector
"is_k_reduction wrong!"); "is_k_reduction wrong!");
} }
__host__ __device__ constexpr MfmaSelector() { mfma_check(); }
static constexpr bool IsABroadcast() static constexpr bool IsABroadcast()
{ {
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
...@@ -621,6 +619,8 @@ struct XdlopsGemm ...@@ -621,6 +619,8 @@ struct XdlopsGemm
return MPerXdlops * NPerXdlops / mfma_instr.wave_size; return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
} }
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
......
...@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N> ...@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
static constexpr index_t vector_size = GetVectorSize(); static constexpr index_t vector_size = GetVectorSize();
__host__ __device__ static constexpr index_t GetNumVectors() { return N; }
__host__ __device__ static constexpr index_t GetNumElements()
{
return GetVectorSize() * GetNumVectors();
}
VecBaseType invalid_element_value_ = VecBaseType{0}; VecBaseType invalid_element_value_ = VecBaseType{0};
T invalid_vec_value_ = T{0}; T invalid_vec_value_ = T{0};
...@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N> ...@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
return GetElement(i, true); return GetElement(i, true);
} }
__host__ __device__ void Clear()
{
static_for<0, GetNumElements(), 1>{}(
[&](auto i) { GetElement(i, true) = invalid_element_value_; });
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 0 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment