Commit 20e6bc9d authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent 8f3c4d86
......@@ -28,7 +28,7 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K0_K1_E2 A_E0_E1_K0_K1_E2_grid_desc,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K_N_H0_H1_H2_W0_W1_W2 c_k_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
......@@ -114,6 +114,7 @@ template <index_t BlockSize,
typename CGridDesc_K_N_Ho_Wo,
index_t E1_,
index_t E2_,
index_t K2_,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
......@@ -152,10 +153,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto E1 = Number<E1_>{};
static constexpr auto E2 = Number<E2_>{};
static constexpr auto E1 = Number<E1_>{};
static constexpr auto E2 = Number<E2_>{};
static constexpr auto K2 = Number<K2_>{};
static constexpr auto NPerBlock = I1;
static constexpr auto K2 = 2;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
......@@ -181,12 +183,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto Ho0 = Ho / HoPerBlock;
const auto Wo0 = Wo / WoPerBlock;
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto H0 = Ho / HoPerBlock;
const auto W0 = Wo / WoPerBlock;
const index_t grid_size = K0 * N0 * Ho0 * Wo0;
const index_t grid_size = K0 * N0 * H0 * W0;
return grid_size;
}
......@@ -314,13 +316,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto Ho0 = Ho / HoPerBlock;
const auto Wo0 = Wo / WoPerBlock;
const auto K0 = K / KPerBlock;
const auto N0 = N / NPerBlock;
const auto H0 = Ho / HoPerBlock;
const auto W0 = Wo / WoPerBlock;
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, Ho0, Wo0))),
make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
......
......@@ -43,7 +43,7 @@ __global__ void
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
......
......@@ -124,6 +124,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t E1 = 2 * 9;
constexpr index_t E2 = 1;
constexpr index_t K2 = 2;
constexpr index_t E1PerBlock = 2;
constexpr index_t KPerThread = 8;
......@@ -151,6 +152,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
TOut,
E1,
E2,
K2,
KPerBlock,
HoPerBlock,
WoPerBlock,
......
......@@ -3,7 +3,7 @@
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
//#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment