"...composable_kernel_rocm.git" did not exist on "6014185ac65e75f2a84cb67ef6ba83b48ae0fcb3"
Commit 20e6bc9d authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent 8f3c4d86
...@@ -28,7 +28,7 @@ __global__ void ...@@ -28,7 +28,7 @@ __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K0_K1_E2 A_E0_E1_K0_K1_E2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K_N_H0_H1_H2_W0_W1_W2 c_k_n_h0_h1_h2_w0_w1_w2_grid_desc, const CGridDesc_K_N_H0_H1_H2_W0_W1_W2 c_k_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor) const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
...@@ -114,6 +114,7 @@ template <index_t BlockSize, ...@@ -114,6 +114,7 @@ template <index_t BlockSize,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
index_t E1_, index_t E1_,
index_t E2_, index_t E2_,
index_t K2_,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
index_t WoPerBlock, index_t WoPerBlock,
...@@ -152,10 +153,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -152,10 +153,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto E1 = Number<E1_>{}; static constexpr auto E1 = Number<E1_>{};
static constexpr auto E2 = Number<E2_>{}; static constexpr auto E2 = Number<E2_>{};
static constexpr auto K2 = Number<K2_>{};
static constexpr auto NPerBlock = I1; static constexpr auto NPerBlock = I1;
static constexpr auto K2 = 2;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -181,12 +183,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -181,12 +183,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = K / KPerBlock; const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock; const auto N0 = N / NPerBlock;
const auto Ho0 = Ho / HoPerBlock; const auto H0 = Ho / HoPerBlock;
const auto Wo0 = Wo / WoPerBlock; const auto W0 = Wo / WoPerBlock;
const index_t grid_size = K0 * N0 * Ho0 * Wo0; const index_t grid_size = K0 * N0 * H0 * W0;
return grid_size; return grid_size;
} }
...@@ -314,13 +316,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -314,13 +316,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = K / KPerBlock; const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock; const auto N0 = N / NPerBlock;
const auto Ho0 = Ho / HoPerBlock; const auto H0 = Ho / HoPerBlock;
const auto Wo0 = Wo / WoPerBlock; const auto W0 = Wo / WoPerBlock;
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor( const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, Ho0, Wo0))), make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
......
...@@ -43,7 +43,7 @@ __global__ void ...@@ -43,7 +43,7 @@ __global__ void
p_shared_block, p_shared_block,
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor); c_block_cluster_adaptor);
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
......
...@@ -124,6 +124,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -124,6 +124,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t E1 = 2 * 9; constexpr index_t E1 = 2 * 9;
constexpr index_t E2 = 1; constexpr index_t E2 = 1;
constexpr index_t K2 = 2;
constexpr index_t E1PerBlock = 2; constexpr index_t E1PerBlock = 2;
constexpr index_t KPerThread = 8; constexpr index_t KPerThread = 8;
...@@ -151,6 +152,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -151,6 +152,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
TOut, TOut,
E1, E1,
E2, E2,
K2,
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> //#include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment