Commit 17cd5c7d authored by Jing Zhang's avatar Jing Zhang
Browse files

split h and w

parent 0e77b53e
...@@ -16,22 +16,22 @@ template <typename GridwiseGemm, ...@@ -16,22 +16,22 @@ template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2, typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop> bool HasMainE0BlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid, kernel_gemm_dlops_v2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2 A_E0_E1_K0_K1_E2_grid_desc,
const BGridDesc_E0_E1_N_Ho_Wo_E2 b_e0_e1_n_ho_wo_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K_N_Ho_Wo c_k_n_ho_wo_grid_desc, const CGridDesc_K_N_H0_H1_H2_W0_W1_W2 c_k_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -43,9 +43,9 @@ __global__ void ...@@ -43,9 +43,9 @@ __global__ void
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor, c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}); integral_constant<bool, HasMainE0BlockLoop>{});
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...@@ -56,9 +56,9 @@ template <typename GridwiseGemm, ...@@ -56,9 +56,9 @@ template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2, typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop> bool HasMainE0BlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -68,22 +68,24 @@ __global__ void ...@@ -68,22 +68,24 @@ __global__ void
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc, const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_ho_wo_e2_grid_desc, const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const void CONSTANT* p_c_k_n_ho_wo_grid_desc, const void CONSTANT* p_c_k_n_h0_h1_h2_w0_w1_w2_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{ {
// first cast void CONSTANT void* to void* // first cast void CONSTANT void* to void*
// second cast void* to Desc* // second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4) // the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_e0_e1_k0_k1_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K0_K1_E2*>( const auto a_e0_e1_k0_k1_e2_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K0_K1_E2*>(
cast_pointer_to_generic_address_space(p_a_e0_e1_k0_k1_e2_grid_desc)); cast_pointer_to_generic_address_space(p_a_e0_e1_k0_k1_e2_grid_desc));
const auto b_e0_e1_n_ho_wo_e2_grid_desc = *reinterpret_cast<const BGridDesc_E0_E1_N_Ho_Wo_E2*>( const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
cast_pointer_to_generic_address_space(p_b_e0_e1_n_ho_wo_e2_grid_desc)); *reinterpret_cast<const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2*>(
const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CGridDesc_K_N_Ho_Wo*>( cast_pointer_to_generic_address_space(p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc));
cast_pointer_to_generic_address_space(p_c_k_n_ho_wo_grid_desc)); const auto c_k_n_h0_h1_h2_w0_w1_w2_grid_desc =
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = *reinterpret_cast<const CGridDesc_K_N_H0_H1_H2_W0_W1_W2*>(
*reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo*>( cast_pointer_to_generic_address_space(p_c_k_n_h0_h1_h2_w0_w1_w2_grid_desc));
cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor)); const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -95,9 +97,9 @@ __global__ void ...@@ -95,9 +97,9 @@ __global__ void
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor, c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}); integral_constant<bool, HasMainE0BlockLoop>{});
} }
#endif #endif
...@@ -232,7 +234,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -232,7 +234,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCK0K1NHoWoGridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{ {
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
...@@ -242,19 +244,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -242,19 +244,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto K1 = Number<KPerBlock>{}; const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1; const auto K0 = K / K1;
const auto c_k0_k1_n_ho_wo_grid_desc = transform_tensor_descriptor( const auto H2 = Number<HoPerThread>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Ho / (H1 * H2);
const auto W2 = Number<WoPerThread>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wo / (W1 * W2);
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor(
c_k_n_ho_wo_grid_desc, c_k_n_ho_wo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_pass_through_transform(Ho), make_unmerge_transform(make_tuple(H0, H1, H2)),
make_pass_through_transform(Wo)), make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return c_k0_k1_n_ho_wo_grid_desc; return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
} }
__host__ __device__ static constexpr auto MakeBE0E1NH0H1W0W1E2GridDescriptor( __host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(
const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc) const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc)
{ {
const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0); const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0);
...@@ -264,19 +274,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -264,19 +274,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4); const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4);
// const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5); // const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5);
const auto H1 = Number<HoPerBlock>{}; const auto H2 = Number<HoPerThread>{};
const auto H0 = Ho / H1; const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Ho / (H1 * H2);
const auto W1 = Number<WoPerBlock>{}; const auto W2 = Number<WoPerThread>{};
const auto W0 = Wo / W1; const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wo / (W1 * W2);
const auto b_e0_e1_n_h0_h1_w0_w1_e2_grid_desc = const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc, transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc,
make_tuple(make_pass_through_transform(E0), make_tuple(make_pass_through_transform(E0),
make_pass_through_transform(E1), make_pass_through_transform(E1),
make_pass_through_transform(N), make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1)), make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1)), make_unmerge_transform(make_tuple(W0, W1, W2)),
make_pass_through_transform(E2)), make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
...@@ -287,11 +299,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -287,11 +299,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
Sequence<3, 4>{}, Sequence<3, 4, 5>{},
Sequence<5, 6>{}, Sequence<6, 7, 8>{},
Sequence<7>{})); Sequence<9>{}));
return b_e0_e1_n_h0_h1_w0_w1_e2_grid_desc; return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -317,8 +329,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -317,8 +329,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
using AGridDesc_E0_E1_K0_K1_E2 = using AGridDesc_E0_E1_K0_K1_E2 =
decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{})); decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
using CGridDesc_K0_K1_N_Ho_Wo = decltype(MakeCK0K1NHoWoGridDescriptor(CGridDesc_K_N_Ho_Wo{})); using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
using CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo = decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
template <bool HasMainE0BlockLoop> template <bool HasMainE0BlockLoop>
...@@ -328,84 +343,70 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -328,84 +343,70 @@ struct GridwiseGemmDlops_km_kn_mn_v3
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_Ho_Wo& c_k0_k1_n_ho_wo_grid_desc, const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo& c_blockid_to_k_n_ho_wo_block_cluster_adaptor, const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>) integral_constant<bool, HasMainE0BlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_ho_wo_e2_grid_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k0_k1_n_ho_wo_grid_desc.GetElementSpaceSize()); p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop(); constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop();
constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop(); constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop();
// const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3); // const auto Ho = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I3);
// const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4); // const auto Wo = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I4);
const auto c_k_n_ho_wo_block_cluster_idx = const auto c_k_n_h_w_block_cluster_idx =
c_blockid_to_k_n_ho_wo_block_cluster_adaptor.CalculateBottomIndex( c_blockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id())); make_multi_index(get_block_1d_id()));
const index_t k_block_work_id = const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(c_k_n_ho_wo_block_cluster_idx[I0]); __builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I0]);
const index_t n_block_work_id = const index_t n_block_work_id =
__builtin_amdgcn_readfirstlane(c_k_n_ho_wo_block_cluster_idx[I1]); __builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I1]);
const index_t ho_block_work_id = const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(c_k_n_ho_wo_block_cluster_idx[I2]); __builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I2]);
const index_t wo_block_work_id = const index_t wo_block_work_id =
__builtin_amdgcn_readfirstlane(c_k_n_ho_wo_block_cluster_idx[I3]); __builtin_amdgcn_readfirstlane(c_k_n_h_w_block_cluster_idx[I3]);
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{}; constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
// B matrix in thread, dst of blockwise copy constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
constexpr auto b_e1_n_ho_wo_e2_block_desc = make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
constexpr auto b_e1_n_h_w_e2_block_gemm_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<E1PerBlock>{}, make_naive_tensor_descriptor_packed(make_tuple(Number<E1PerBlock>{},
Number<1>{}, I1,
Number<HoPerBlock>{}, Number<HoPerBlock>{},
Number<WoPerBlock>{}, Number<WoPerBlock>{},
Number<E2>{})); Number<E2>{}));
// c_thread_mtx definition: this is a mess constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed(
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k1_n_ho_wo_thread_gemm_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{})); make_tuple(Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_e1_k1_e2_block_gemm_desc), decltype(a_e1_k1_e2_block_gemm_desc),
decltype(b_e1_n_ho_wo_e2_block_desc), decltype(b_e1_n_h_w_e2_block_gemm_desc),
decltype(c_k1_n_ho_wo_thread_gemm_desc), decltype(c_k1_n_h2_w2_thread_gemm_desc),
EPerThread, EPerThread,
K2>{}; K2>{};
auto c_thread_mtx_index = auto c_thread_mtx_index =
blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
// const auto k_thread_id = c_thread_mtx_index[I0]; const auto k_thread_id = c_thread_mtx_index[I0];
const auto ho_thread_id = c_thread_mtx_index[I2]; const auto ho_thread_id = c_thread_mtx_index[I2];
const auto wo_thread_id = c_thread_mtx_index[I3]; const auto wo_thread_id = c_thread_mtx_index[I3];
// const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
// const index_t n_block_data_on_global = n_block_work_id * HoPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t n_thread_data_on_global = 0;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<I1>{}, Number<E1>{}, I1, Number<KPerBlock>{}, Number<E2>{}), make_tuple(Number<I1>{}, Number<E1>{}, I1, Number<KPerBlock>{}, Number<E2>{}),
max_lds_align); max_lds_align);
...@@ -438,30 +439,38 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -438,30 +439,38 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0);
constexpr auto b_e0_e1_n_ho_wo_e2_thread_desc = constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<E1PerBlock>{}, Number<E1PerBlock>{},
Number<1>{}, I1,
I1,
I1,
Number<HoPerThread>{}, Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{}, Number<WoPerThread>{},
Number<E2>{})); Number<E2>{}));
auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2< auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_e0_e1_n_ho_wo_e2_grid_desc), decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc),
decltype(b_e0_e1_n_ho_wo_e2_thread_desc), decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc),
Sequence<I1, E1PerBlock, NPerBlock, HoPerThread, WoPerThread, E2>, Sequence<I1, E1PerBlock, I1, I1, I1, HoPerThread, I1, I1, WoPerThread, E2>,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>(b_e0_e1_n_ho_wo_e2_grid_desc, true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
make_multi_index(0, make_multi_index(0,
0, 0,
n_thread_data_on_global, n_block_work_id,
ho_thread_data_on_global, ho_block_work_id,
wo_thread_data_on_global, ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0,
0)); 0));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
...@@ -470,35 +479,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -470,35 +479,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc, FloatAcc,
c_k1_n_ho_wo_thread_gemm_desc.GetElementSpaceSize(), c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true> true>
c_thread_buf; c_thread_buf;
// initialize output thread tensor // initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k1_n_ho_wo_thread_gemm_desc), decltype(c_k1_n_h2_w2_thread_gemm_desc),
Sequence<KPerThread, NPerBlock, HoPerThread, WoPerThread>>{} Sequence<KPerThread, NPerBlock, HoPerThread, WoPerThread>>{}
.Run(c_k1_n_ho_wo_thread_gemm_desc, .Run(c_k1_n_h2_w2_thread_gemm_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
FloatAcc{0}); FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(0, E1PerBlock, 0, 0, 0, 0); constexpr auto b_thread_slice_copy_step =
make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{}; constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = BGlobalStepHacks{}; constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};
// double regsiter buffer for b // double regsiter buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e0_e1_n_ho_wo_e2_thread_desc.GetElementSpaceSize(), b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
true> true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
if constexpr(HasMainE0BlockLoop) if constexpr(HasMainE0BlockLoop)
{ {
const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0); const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0);
index_t e0_block_data_begin = 0; index_t e0_block_data_begin = 0;
...@@ -509,12 +519,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -509,12 +519,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
} }
...@@ -530,32 +540,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -530,32 +540,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do do
{ {
// even iteration // even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step, b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{}); BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step, b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{}); BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...@@ -570,16 +584,17 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -570,16 +584,17 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{ {
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step, b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{}); BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
...@@ -601,7 +616,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -601,7 +616,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step, b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{}); BGlobalMoveSliceWindowStepHacks{});
...@@ -616,12 +631,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -616,12 +631,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
} }
...@@ -637,32 +652,34 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -637,32 +652,34 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do do
{ {
// even iteration // even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step, b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{}); BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step, b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{}); BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
...@@ -677,16 +694,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -677,16 +694,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{ {
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step, b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{}); BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_grid_desc, b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
...@@ -705,7 +722,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -705,7 +722,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// activ // activ
{ {
static_for<0, c_k1_n_ho_wo_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1) if constexpr(activ_type == 1)
{ {
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0; c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0;
...@@ -726,36 +743,50 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -726,36 +743,50 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// output: register to global memory // output: register to global memory
{ {
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor // hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{}; // tensor
constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{};
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{}));
constexpr auto c_k0_k1_n_ho_wo_thread_copy_desc = const index_t k_thread_data_on_global = k_thread_id * KPerThread;
make_naive_tensor_descriptor_packed(make_tuple(
I1, Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{}));
ThreadwiseTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_k0_k1_n_ho_wo_thread_copy_desc), decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc),
decltype(c_k0_k1_n_ho_wo_grid_desc), decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc),
Sequence<I1, KPerThread, I1, HoPerThread, WoPerThread>, Sequence<I1, KPerThread, I1, I1, I1, HoPerThread, I1, I1, WoPerThread>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>(c_k0_k1_n_ho_wo_grid_desc, true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
make_multi_index(k_block_work_id, make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0, 0,
n_thread_data_on_global, wo_block_work_id,
ho_thread_data_on_global, wo_thread_id,
wo_thread_data_on_global)) 0))
.Run(c_k0_k1_n_ho_wo_thread_copy_desc, .Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
c_k0_k1_n_ho_wo_grid_desc, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_global_buf, c_global_buf,
c_k_n_ho_wo_global_tensor_step_hacks); c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks);
} }
} }
}; };
......
...@@ -119,14 +119,14 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -119,14 +119,14 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32; constexpr index_t WoPerBlock = 32;
constexpr index_t E1 = 2 * 9; constexpr index_t E1 = 2 * 9;
constexpr index_t E2 = 1; constexpr index_t E2 = 1;
constexpr index_t E1PerBlock = 2; constexpr index_t E1PerBlock = 2;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1; constexpr index_t EPerThread = 1;
......
...@@ -197,35 +197,597 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -197,35 +197,597 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack = constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = make_tuple( constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), 0,
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); 1,
0,
constexpr auto b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack = 0,
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; 0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{}),
make_tuple(Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
2,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
2,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{},
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{}));
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
Sequence<0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format // hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), ""); static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
...@@ -260,29 +822,32 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -260,29 +822,32 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
ABlockTransferSrcScalarPerVector_E2, ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2, ABlockTransferDstScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<0, 1, 2, 3, 4, 5>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
5, 9,
BThreadTransferSrcScalarPerVector_E2, BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<0, 1, 2, 3, 4>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2
1, 1,
CThreadTransferDstScalarPerVector_K, CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_e2_global_step_hacks), decltype(a_e0_e1_k_e2_global_step_hacks),
decltype(b_e0_e1_n_ho_wo_e2_global_step_hacks), decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
decltype(c_k_n_ho_wo_global_tensor_step_hacks), decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack), decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
activ_type>; activ_type>;
const auto a_e0_e1_k0_k1_e2_grid_desc = const auto a_e0_e1_k0_k1_e2_grid_desc =
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
const auto c_k0_k1_n_hop_wop_grid_desc = const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
GridwiseGemm::MakeCK0K1NHoWoGridDescriptor(c_k_n_hop_wop_grid_desc); GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc); using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
using CGridDesc_K0_K1_N_Ho_Wo = decltype(c_k0_k1_n_hop_wop_grid_desc); decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
...@@ -290,11 +855,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -290,11 +855,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl; std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc); GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
using CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo = using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(c_blockid_to_k_n_ho_wo_block_cluster_adaptor); decltype(c_blockid_to_k_n_h_w_block_cluster_adaptor);
float ave_time = 0; float ave_time = 0;
...@@ -304,9 +869,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -304,9 +869,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>, remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>, remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
has_main_e0_block_loop, has_main_e0_block_loop,
has_main_e1_block_loop, has_main_e1_block_loop,
has_double_tail_e1_block_loop>; has_double_tail_e1_block_loop>;
...@@ -320,22 +885,26 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -320,22 +885,26 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_hop_wop_grid_desc, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_h_w_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2)); DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2)); DeviceMem b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf(
DeviceMem c_k0_k1_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K0_K1_N_Ho_Wo)); sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2));
DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf( DeviceMem c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo)); sizeof(CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2));
DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_H_W));
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc); a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_e2_grid_desc); b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.ToDevice(
c_k0_k1_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k0_k1_n_hop_wop_grid_desc); &b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice( c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor); &c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_h_w_block_cluster_adaptor);
if(has_main_e0_block_loop) if(has_main_e0_block_loop)
{ {
...@@ -345,9 +914,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -345,9 +914,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>, remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>, remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -362,11 +931,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -362,11 +931,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k0_k1_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else else
{ {
...@@ -376,9 +945,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -376,9 +945,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>, remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_Ho_Wo>, remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -393,11 +962,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -393,11 +962,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k0_k1_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment