"configs/vscode:/vscode.git/clone" did not exist on "b3f5d9e421ba5021f607c67a48a599221148e6a7"
Commit 35a57947 authored by Jing Zhang's avatar Jing Zhang
Browse files

add conv_out

parent 3e298e42
...@@ -17,6 +17,7 @@ template <typename GridwiseGemm, ...@@ -17,6 +17,7 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2, typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2, typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2, typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop> bool HasMainE0BlockLoop>
...@@ -28,9 +29,11 @@ __global__ void ...@@ -28,9 +29,11 @@ __global__ void
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid, const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_d_grid, FloatC* __restrict__ p_d_grid,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor) const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{ {
...@@ -42,10 +45,12 @@ __global__ void ...@@ -42,10 +45,12 @@ __global__ void
GridwiseGemm::Run(p_a_grid, GridwiseGemm::Run(p_a_grid,
p_b_grid, p_b_grid,
p_bias_grid, p_bias_grid,
p_c_grid,
p_d_grid, p_d_grid,
p_shared_block, p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor, c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}); integral_constant<bool, HasMainE0BlockLoop>{});
...@@ -59,6 +64,7 @@ template <typename GridwiseGemm, ...@@ -59,6 +64,7 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename AGridDesc_E0_E1_K0_K1_E2, typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2, typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2, typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop> bool HasMainE0BlockLoop>
...@@ -69,11 +75,12 @@ __global__ void ...@@ -69,11 +75,12 @@ __global__ void
kernel_gemm_dlops_v2_add(const FloatAB* __restrict__ p_a_grid, kernel_gemm_dlops_v2_add(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid, const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_d_grid, FloatC* __restrict__ p_d_grid,
const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc, const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const void CONSTANT* p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const void CONSTANT* p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{ {
// first cast void CONSTANT void* to void* // first cast void CONSTANT void* to void*
...@@ -84,6 +91,9 @@ __global__ void ...@@ -84,6 +91,9 @@ __global__ void
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
*reinterpret_cast<const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2*>( *reinterpret_cast<const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2*>(
cast_pointer_to_generic_address_space(p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc)); cast_pointer_to_generic_address_space(p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc));
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
*reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
*reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2*>( *reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2*>(
cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc)); cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc));
...@@ -99,10 +109,12 @@ __global__ void ...@@ -99,10 +109,12 @@ __global__ void
GridwiseGemm::Run(p_a_grid, GridwiseGemm::Run(p_a_grid,
p_b_grid, p_b_grid,
p_bias_grid, p_bias_grid,
p_c_grid,
p_d_grid, p_d_grid,
p_shared_block, p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor, c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}); integral_constant<bool, HasMainE0BlockLoop>{});
...@@ -116,8 +128,8 @@ template <index_t BlockSize, ...@@ -116,8 +128,8 @@ template <index_t BlockSize,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2, typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename DGridDesc_K_N_Hox2_Wox2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
typename DGridDesc_K_N_Hox2_Wox2,
index_t E1_, index_t E1_,
index_t E2_, index_t E2_,
index_t K2_, index_t K2_,
...@@ -146,6 +158,7 @@ template <index_t BlockSize, ...@@ -146,6 +158,7 @@ template <index_t BlockSize,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGlobalStepHacks, typename AGlobalStepHacks,
typename BGlobalStepHacks, typename BGlobalStepHacks,
typename CGlobalStepHacks,
typename DGlobalStepHacks, typename DGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks, typename BGlobalMoveSliceWindowStepHacks,
...@@ -283,6 +296,37 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -283,6 +296,37 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc; return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc;
} }
__host__ __device__ static constexpr auto
MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto H2 = Number<HoPerThread>{};
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Ho / (H1 * H2);
const auto W2 = Number<WoPerThread>{};
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wo / (W1 * W2);
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor(
c_k_n_ho_wo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
}
__host__ __device__ static constexpr auto MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor( __host__ __device__ static constexpr auto MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(
const DGridDesc_K_N_Hox2_Wox2& d_k_n_hox2_wox2_grid_desc) const DGridDesc_K_N_Hox2_Wox2& d_k_n_hox2_wox2_grid_desc)
{ {
...@@ -339,8 +383,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -339,8 +383,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{})); decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{})); decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 = using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
decltype(MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(DGridDesc_K_N_Hox2_Wox2{})); decltype(MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(DGridDesc_K_N_Hox2_Wox2{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W = using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
...@@ -358,10 +405,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -358,10 +405,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
Run(const FloatAB* __restrict__ p_a_global, Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global, const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global,
FloatC* __restrict__ p_d_global, FloatC* __restrict__ p_d_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2& d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2& d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor, const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>) integral_constant<bool, HasMainE0BlockLoop>)
...@@ -382,6 +431,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -382,6 +431,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_d_global, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetElementSpaceSize()); p_d_global, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...@@ -826,6 +877,56 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -826,6 +877,56 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
#endif #endif
} }
#if 1
// Output
{
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
// tensor
constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{};
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{}));
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc),
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThread, I1, I1, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_global_buf,
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks);
}
#endif
// Resize_Add // Resize_Add
{ {
constexpr auto HoPerThreadx2 = HoPerThread * 2; constexpr auto HoPerThreadx2 = HoPerThread * 2;
......
...@@ -27,6 +27,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0 ...@@ -27,6 +27,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
const Tensor<TInWei>& in_n_c0_hi_wi_c1, const Tensor<TInWei>& in_n_c0_hi_wi_c1,
const Tensor<TInWei>& wei_k_c0_y_x_c1, const Tensor<TInWei>& wei_k_c0_y_x_c1,
const Tensor<TOut>& bias_k0_k1, const Tensor<TOut>& bias_k0_k1,
Tensor<TOut>& out_n_k0_ho_wo_k1,
const Tensor<TOut>& add_n_k0_hox2_wox2_k1, const Tensor<TOut>& add_n_k0_hox2_wox2_k1,
Tensor<TOut>& add_n_k0_hox2_wox2_k1_out, Tensor<TOut>& add_n_k0_hox2_wox2_k1_out,
ck::index_t nrepeat) ck::index_t nrepeat)
...@@ -63,6 +64,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0 ...@@ -63,6 +64,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace()); DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) * DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) *
add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace()); add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace());
...@@ -177,8 +180,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0 ...@@ -177,8 +180,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
const auto ave_time = const auto ave_time =
conv_driver.Run(wei_k_c0_y_x_c1_desc, conv_driver.Run(wei_k_c0_y_x_c1_desc,
in_n_c0_hi_wi_c1_desc, in_n_c0_hi_wi_c1_desc,
add_n_k0_hox2_wox2_k1_desc,
out_n_k0_ho_wo_k1_desc, out_n_k0_ho_wo_k1_desc,
add_n_k0_hox2_wox2_k1_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
...@@ -188,6 +191,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0 ...@@ -188,6 +191,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()), static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()), static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
nrepeat); nrepeat);
...@@ -204,8 +208,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0 ...@@ -204,8 +208,8 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
conv_driver.Run(wei_k_c0_y_x_c1_desc, conv_driver.Run(wei_k_c0_y_x_c1_desc,
in_n_c0_hi_wi_c1_desc, in_n_c0_hi_wi_c1_desc,
add_n_k0_hox2_wox2_k1_desc,
out_n_k0_ho_wo_k1_desc, out_n_k0_ho_wo_k1_desc,
add_n_k0_hox2_wox2_k1_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
...@@ -215,8 +219,10 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0 ...@@ -215,8 +219,10 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()), static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()), static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
0); 0);
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data()); add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data());
} }
...@@ -40,8 +40,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -40,8 +40,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
typename InRightPads> typename InRightPads>
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc, __host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc, const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
const ck::TensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ck::TensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
const ConvStrides& conv_strides, const ConvStrides& conv_strides,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
...@@ -49,6 +49,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -49,6 +49,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid, const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_d_grid, FloatC* __restrict__ p_d_grid,
const int nrepeat) const const int nrepeat) const
{ {
...@@ -247,6 +248,26 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -247,6 +248,26 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack = constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks = constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
...@@ -282,8 +303,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -282,8 +303,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
InMemoryDataOperationEnum_t::Add, InMemoryDataOperationEnum_t::Add,
decltype(a_e0_e1_k_e2_grid_desc), decltype(a_e0_e1_k_e2_grid_desc),
decltype(b_e0_e1_n_ho_wo_e2_grid_desc), decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
decltype(d_k_n_hopx2_wopx2_grid_desc),
decltype(c_k_n_hop_wop_grid_desc), decltype(c_k_n_hop_wop_grid_desc),
decltype(d_k_n_hopx2_wopx2_grid_desc),
E1, E1,
E2, E2,
K2, K2,
...@@ -313,6 +334,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -313,6 +334,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
CThreadTransferDstScalarPerVector_K, CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_e2_global_step_hacks), decltype(a_e0_e1_k_e2_global_step_hacks),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks), decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks), decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks),
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack), decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
...@@ -322,12 +344,15 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -322,12 +344,15 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc); GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc); GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
GridwiseGemm::MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(d_k_n_hopx2_wopx2_grid_desc); GridwiseGemm::MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 = using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc); decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
...@@ -355,6 +380,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -355,6 +380,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>, remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>, remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
true>; true>;
...@@ -367,9 +393,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -367,9 +393,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_bias_grid, p_bias_grid,
p_c_grid,
p_d_grid, p_d_grid,
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor); c_blockid_to_k_n_h_w_block_cluster_adaptor);
} }
...@@ -381,6 +409,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -381,6 +409,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>, remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>, remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
false>; false>;
...@@ -393,9 +422,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -393,9 +422,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_bias_grid, p_bias_grid,
p_c_grid,
p_d_grid, p_d_grid,
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor); c_blockid_to_k_n_h_w_block_cluster_adaptor);
} }
...@@ -404,6 +435,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -404,6 +435,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2)); DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
DeviceMem b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf( DeviceMem b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf(
sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2)); sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2));
DeviceMem c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf(
sizeof(CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2));
DeviceMem d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf( DeviceMem d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf(
sizeof(DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2)); sizeof(DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2));
DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf( DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf(
...@@ -412,6 +445,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -412,6 +445,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc); a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.ToDevice( b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.ToDevice(
&b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc); &b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.ToDevice(
&c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.ToDevice( d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.ToDevice(
&d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc); &d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice( c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice(
...@@ -426,6 +461,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -426,6 +461,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>, remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>, remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
true>; true>;
...@@ -439,11 +475,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -439,11 +475,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_bias_grid, p_bias_grid,
p_c_grid,
p_d_grid, p_d_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()), d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -458,6 +497,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -458,6 +497,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatC, FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>, remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>, remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>, remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
false>; false>;
...@@ -471,11 +511,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -471,11 +511,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_bias_grid, p_bias_grid,
p_c_grid,
p_d_grid, p_d_grid,
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()), b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()), d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
......
...@@ -308,6 +308,7 @@ int main(int argc, char* argv[]) ...@@ -308,6 +308,7 @@ int main(int argc, char* argv[])
in, in,
wei, wei,
bias, bias,
out_device,
add, add,
add_device, add_device,
nrepeat); nrepeat);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment