"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7c35e13337a611258e0403298a268224816881f8"
Commit d798c9b8 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed c_buffer alloc

parent 4041850f
...@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr, StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, 16>, vector_type<FloatAcc, xdlops_gemm.GetRegSizePerXdlops()>,
MRepeat * NRepeat, MRepeat * NRepeat,
true> true>
c_thread_buf_; c_thread_buf_;
...@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
} }
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
......
...@@ -557,6 +557,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -557,6 +557,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// output: register to global memory // output: register to global memory
{ {
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
...@@ -569,10 +572,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -569,10 +572,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
make_naive_tensor_descriptor_packed(make_tuple(
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
......
...@@ -507,7 +507,7 @@ struct MfmaSelector ...@@ -507,7 +507,7 @@ struct MfmaSelector
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__host__ __device__ static constexpr void mfma_check() __host__ __device__ constexpr MfmaSelector()
{ {
static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
selected_mfma.num_regs_per_blk, selected_mfma.num_regs_per_blk,
...@@ -533,8 +533,6 @@ struct MfmaSelector ...@@ -533,8 +533,6 @@ struct MfmaSelector
"is_k_reduction wrong!"); "is_k_reduction wrong!");
} }
__host__ __device__ constexpr MfmaSelector() { mfma_check(); }
static constexpr bool IsABroadcast() static constexpr bool IsABroadcast()
{ {
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
...@@ -621,6 +619,8 @@ struct XdlopsGemm ...@@ -621,6 +619,8 @@ struct XdlopsGemm
return MPerXdlops * NPerXdlops / mfma_instr.wave_size; return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
} }
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 0 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment