Commit 20e6bc9d authored by Jing Zhang

clean code

parent 8f3c4d86
@@ -28,7 +28,7 @@ __global__ void
     const FloatAB* __restrict__ p_a_grid,
     const FloatAB* __restrict__ p_b_grid,
     FloatC* __restrict__ p_c_grid,
-    const AGridDesc_E0_E1_K0_K1_E2 A_E0_E1_K0_K1_E2_grid_desc,
+    const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
     const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
     const CGridDesc_K_N_H0_H1_H2_W0_W1_W2 c_k_n_h0_h1_h2_w0_w1_w2_grid_desc,
     const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
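Note that the grid descriptors in this signature appear to be received by value (this is the CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE path that shows up further down). A minimal CUDA-style sketch of that general pattern, using a made-up TileDesc type rather than the library's descriptor classes:

// Hypothetical descriptor; trivially copyable so it can travel in the kernel-argument buffer.
#include <cstdio>
#include <cuda_runtime.h>

struct TileDesc
{
    int lengths[4];
    int strides[4];
};

__global__ void print_desc(const TileDesc desc) // passed by value, no device allocation needed
{
    if(blockIdx.x == 0 && threadIdx.x == 0)
        printf("lengths[0]=%d strides[0]=%d\n", desc.lengths[0], desc.strides[0]);
}

int main()
{
    const TileDesc desc{{8, 1, 32, 32}, {1024, 1024, 32, 1}};
    print_desc<<<1, 64>>>(desc); // the struct is copied into the kernel arguments
    cudaDeviceSynchronize();
    return 0;
}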
@@ -114,6 +114,7 @@ template <index_t BlockSize,
           typename CGridDesc_K_N_Ho_Wo,
           index_t E1_,
           index_t E2_,
+          index_t K2_,
           index_t KPerBlock,
           index_t HoPerBlock,
           index_t WoPerBlock,
@@ -152,10 +153,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto E1 = Number<E1_>{};
    static constexpr auto E2 = Number<E2_>{};
+   static constexpr auto K2 = Number<K2_>{};
    static constexpr auto NPerBlock = I1;
-   static constexpr auto K2 = 2;
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
@@ -181,12 +183,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
        const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
        const auto K0 = K / KPerBlock;
        const auto N0 = N / NPerBlock;
-       const auto Ho0 = Ho / HoPerBlock;
-       const auto Wo0 = Wo / WoPerBlock;
-       const index_t grid_size = K0 * N0 * Ho0 * Wo0;
+       const auto H0 = Ho / HoPerBlock;
+       const auto W0 = Wo / WoPerBlock;
+       const index_t grid_size = K0 * N0 * H0 * W0;
        return grid_size;
    }
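This grid-size calculation launches one workgroup per output tile. A small host-side sketch of the same arithmetic, with assumed (illustrative) problem and tile sizes rather than values taken from the patch:

#include <cstdio>

int main()
{
    // Assumed sizes for illustration only; the real values come from the convolution problem
    // and the tuning parameters (NPerBlock is fixed to 1 in the gridwise GEMM above).
    const int K = 256, N = 1, Ho = 32, Wo = 32;
    const int KPerBlock = 16, NPerBlock = 1, HoPerBlock = 16, WoPerBlock = 16;

    const int K0 = K / KPerBlock;   // 16
    const int N0 = N / NPerBlock;   // 1
    const int H0 = Ho / HoPerBlock; // 2
    const int W0 = Wo / WoPerBlock; // 2

    std::printf("grid_size = %d\n", K0 * N0 * H0 * W0); // 16 * 1 * 2 * 2 = 64 workgroups
    return 0;
}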
@@ -314,13 +316,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
        const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
        const auto K0 = K / KPerBlock;
        const auto N0 = N / NPerBlock;
-       const auto Ho0 = Ho / HoPerBlock;
-       const auto Wo0 = Wo / WoPerBlock;
+       const auto H0 = Ho / HoPerBlock;
+       const auto W0 = Wo / WoPerBlock;
        const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-           make_tuple(make_merge_transform(make_tuple(K0, N0, Ho0, Wo0))),
+           make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
            make_tuple(Sequence<0, 1, 2, 3>{}),
            make_tuple(Sequence<0>{}));
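This adaptor merges the (K0, N0, H0, W0) block-cluster grid into a single flat block id. As a rough standalone sketch of what that mapping amounts to (assuming the merge is row-major with W0 varying fastest; this is an illustration, not the library code), the inverse decomposition looks like:

#include <array>
#include <cstdio>

// Hypothetical helper: split a flat block id back into (k0, n0, h0, w0) cluster coordinates.
// K0 is implicit: whatever remains after dividing out W0, H0, N0 indexes the K0 dimension.
std::array<int, 4> decompose_block_id(int block_id, int N0, int H0, int W0)
{
    const int w0 = block_id % W0; block_id /= W0;
    const int h0 = block_id % H0; block_id /= H0;
    const int n0 = block_id % N0; block_id /= N0;
    const int k0 = block_id;
    return {k0, n0, h0, w0};
}

int main()
{
    // With K0=16, N0=1, H0=2, W0=2 (the sizes from the sketch above), block 37 maps to:
    const auto idx = decompose_block_id(37, 1, 2, 2);
    std::printf("k0=%d n0=%d h0=%d w0=%d\n", idx[0], idx[1], idx[2], idx[3]); // 9 0 0 1
    return 0;
}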
...
@@ -43,7 +43,7 @@ __global__ void
        p_shared_block,
        a_k0_m_k1_grid_desc,
        b_k0_n_k1_grid_desc,
-       c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+       c_m0_m1_m2_n_grid_desc,
        c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...
@@ -124,6 +124,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
    constexpr index_t E1 = 2 * 9;
    constexpr index_t E2 = 1;
+   constexpr index_t K2 = 2;
    constexpr index_t E1PerBlock = 2;
    constexpr index_t KPerThread = 8;
@@ -151,6 +152,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
        TOut,
        E1,
        E2,
+       K2,
        KPerBlock,
        HoPerBlock,
        WoPerBlock,
...
@@ -12,6 +12,7 @@ template <ck::index_t BlockSize,
          typename FloatC,
          ck::index_t E1_,
          ck::index_t E2_,
+         ck::index_t K2_,
          ck::index_t KPerBlock,
          ck::index_t HoPerBlock,
          ck::index_t WoPerBlock,
@@ -96,6 +97,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
        constexpr auto E1 = Number<E1_>{};
        constexpr auto E2 = Number<E2_>{};
+       constexpr auto K2 = Number<K2_>{};
        static_assert(E2 == C1, "");
@@ -181,7 +183,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
            throw std::runtime_error("wrong! GEMM size no divisible");
        }
-       // hack to control index calculation when iterating over a_k_m_global tensor
+       // clang-format off
+       // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
        constexpr auto a_e0_e1_k_e2_global_step_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
@@ -197,579 +201,37 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
        constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
-       [the same b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks values as below, previously spelled out one Sequence element per line across several hundred lines]
+       // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
+       constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
+           make_tuple(
+               make_tuple(
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
+               make_tuple(
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
+           );

        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
-           [the same move-slice-window Sequence, previously spelled out one element per line]
+           Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

-       // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-       // hack for NKHW format
+       // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
        constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
@@ -789,6 +251,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
+       // clang-format on
        static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
        static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
@@ -806,6 +269,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
            decltype(c_k_n_hop_wop_grid_desc),
            E1,
            E2,
+           K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
@@ -864,30 +328,57 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
        float ave_time = 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-       const auto kernel =
-           kernel_gemm_dlops_v2<GridwiseGemm,
-                                FloatAB,
-                                FloatC,
-                                remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
-                                remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
-                                remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
-                                remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
-                                has_main_e0_block_loop,
-                                has_main_e1_block_loop,
-                                has_double_tail_e1_block_loop>;
-       ave_time = launch_and_time_kernel(kernel,
-                                         nrepeat,
-                                         dim3(grid_size),
-                                         dim3(BlockSize),
-                                         0,
-                                         p_a_grid,
-                                         p_b_grid,
-                                         p_c_grid,
-                                         a_e0_e1_k0_k1_e2_grid_desc,
-                                         b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
-                                         c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
-                                         c_blockid_to_k_n_h_w_block_cluster_adaptor);
+       if(has_main_e0_block_loop)
+       {
+           const auto kernel =
+               kernel_gemm_dlops_v2<GridwiseGemm,
+                                    FloatAB,
+                                    FloatC,
+                                    remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
+                                    remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
+                                    remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
+                                    remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
+                                    true>;
+           ave_time = launch_and_time_kernel(kernel,
+                                             nrepeat,
+                                             dim3(grid_size),
+                                             dim3(BlockSize),
+                                             0,
+                                             p_a_grid,
+                                             p_b_grid,
+                                             p_c_grid,
+                                             a_e0_e1_k0_k1_e2_grid_desc,
+                                             b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+                                             c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+                                             c_blockid_to_k_n_h_w_block_cluster_adaptor);
+       }
+       else
+       {
+           const auto kernel =
+               kernel_gemm_dlops_v2<GridwiseGemm,
+                                    FloatAB,
+                                    FloatC,
+                                    remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
+                                    remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
+                                    remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
+                                    remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
+                                    false>;
+           ave_time = launch_and_time_kernel(kernel,
+                                             nrepeat,
+                                             dim3(grid_size),
+                                             dim3(BlockSize),
+                                             0,
+                                             p_a_grid,
+                                             p_b_grid,
+                                             p_c_grid,
+                                             a_e0_e1_k0_k1_e2_grid_desc,
+                                             b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+                                             c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+                                             c_blockid_to_k_n_h_w_block_cluster_adaptor);
+       }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
        DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
...
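The new launch path above selects between two instantiations of kernel_gemm_dlops_v2 with the boolean baked in as a literal true/false. The same bool-to-template-parameter dispatch can be written generically; a hedged standalone sketch (dispatch_bool and run_stub are made-up names, not library APIs):

#include <cstdio>
#include <type_traits>
#include <utility>

template <bool HasMainE0BlockLoop>
void run_stub()
{
    // Stand-in for instantiating and launching the kernel template with this flag baked in.
    std::printf("instantiated with has_main_e0_block_loop = %d\n", int(HasMainE0BlockLoop));
}

// Map a (possibly runtime) bool onto a compile-time constant by branching once.
template <typename F>
void dispatch_bool(bool flag, F&& f)
{
    if(flag)
        std::forward<F>(f)(std::integral_constant<bool, true>{});
    else
        std::forward<F>(f)(std::integral_constant<bool, false>{});
}

int main()
{
    const bool has_main_e0_block_loop = true; // stand-in; the driver derives this from the GEMM decomposition
    dispatch_bool(has_main_e0_block_loop,
                  [](auto flag) { run_stub<decltype(flag)::value>(); });
    return 0;
}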
@@ -3,7 +3,7 @@
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
-#include <half.hpp>
+//#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
...