Commit 1e6d6782 authored by Jing Zhang

refactor conv_add for InMem::add

parent f66a71c7
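This change folds the former stand-alone C output of the fused conv+add kernel into the resized D output: the result is now committed through the new `InMemoryDataOperationEnum_t::Add` path in `ThreadwiseTensorSliceTransfer_v1r3`, so the separate `p_c_grid` pointer, its grid descriptor, and the C step hacks can be dropped. Below is a minimal, hypothetical sketch (not the library code; `InMemOp` and `store_with_op` are illustrative names only) of the read-modify-write semantics the `Add` destination operation performs, as opposed to the plain overwrite done by `Set`:

```cpp
#include <cstddef>

// Illustrative stand-in for InMemoryDataOperationEnum_t.
enum class InMemOp
{
    Set, // overwrite destination
    Add  // accumulate into destination (read-modify-write)
};

// Hypothetical helper: commit n source values into dst with the chosen op.
template <typename T>
void store_with_op(T* dst, const T* src, std::size_t n, InMemOp op)
{
    for(std::size_t i = 0; i < n; ++i)
    {
        if(op == InMemOp::Add)
            dst[i] += src[i]; // what the Add branch does per destination element
        else
            dst[i] = src[i]; // what Set does
    }
}
```

The driver and the example below follow the same idea: the C device buffers disappear and the test now verifies `add_host` against `add_device` only.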
@@ -18,7 +18,6 @@ template <typename GridwiseGemm,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2,
-typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__global__ void
@@ -29,11 +28,9 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_d_grid,
-FloatC* __restrict__ p_c_grid,
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
@@ -44,12 +41,10 @@ __global__ void
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_d_grid,
-p_c_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{});
}
@@ -63,7 +58,6 @@ template <typename GridwiseGemm,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2,
-typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__global__ void
@@ -73,7 +67,6 @@ __global__ void
kernel_gemm_dlops_v2_add(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_d_grid,
-FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const void CONSTANT* p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
@@ -91,9 +84,6 @@ __global__ void
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
*reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2*>(
cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc));
-const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
-*reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
-cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));
@@ -106,12 +96,10 @@ __global__ void
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_d_grid,
-p_c_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{});
}
@@ -154,7 +142,6 @@ template <index_t BlockSize,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalStepHacks,
typename BGlobalStepHacks,
-typename CGlobalStepHacks,
typename DGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks,
@@ -292,37 +279,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc;
}
-__host__ __device__ static constexpr auto
-MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
-{
-const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0);
-const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1);
-const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
-const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
-const auto K1 = Number<KPerBlock>{};
-const auto K0 = K / K1;
-const auto H2 = Number<HoPerThread>{};
-const auto H1 = Number<HoPerBlock / HoPerThread>{};
-const auto H0 = Ho / (H1 * H2);
-const auto W2 = Number<WoPerThread>{};
-const auto W1 = Number<WoPerBlock / WoPerThread>{};
-const auto W0 = Wo / (W1 * W2);
-const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor(
-c_k_n_ho_wo_grid_desc,
-make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
-make_pass_through_transform(N),
-make_unmerge_transform(make_tuple(H0, H1, H2)),
-make_unmerge_transform(make_tuple(W0, W1, W2))),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
-return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
-}
__host__ __device__ static constexpr auto MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(
const DGridDesc_K_N_Hox2_Wox2& d_k_n_hox2_wox2_grid_desc)
{
@@ -334,26 +290,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
-const auto HoPerBlockx2 = HoPerBlock * 2;
-const auto WoPerBlockx2 = WoPerBlock * 2;
-const auto HoPerThreadx2 = HoPerThread * 2;
-const auto WoPerThreadx2 = WoPerThread * 2;
-const auto H2x2 = Number<HoPerThreadx2>{};
-const auto H1 = Number<HoPerBlockx2 / HoPerThreadx2>{};
-const auto H0 = Hox2 / (H1 * H2x2);
-const auto W2x2 = Number<WoPerThreadx2>{};
-const auto W1 = Number<WoPerBlockx2 / WoPerThreadx2>{};
-const auto W0 = Wox2 / (W1 * W2x2);
+const auto H2 = Number<HoPerThread * 2>{};
+const auto H1 = Number<HoPerBlock / HoPerThread>{};
+const auto H0 = Number<Hox2 / (H1 * H2)>{};
+const auto W2 = Number<WoPerThread * 2>{};
+const auto W1 = Number<WoPerBlock / WoPerThread>{};
+const auto W0 = Number<Wox2 / (W1 * W2)>{};
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = transform_tensor_descriptor(
d_k_n_hox2_wox2_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
-make_unmerge_transform(make_tuple(H0, H1, H2x2)),
-make_unmerge_transform(make_tuple(W0, W1, W2x2))),
+make_unmerge_transform(make_tuple(H0, H1, H2)),
+make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
@@ -385,8 +335,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
-using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
-decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
decltype(MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(DGridDesc_K_N_Hox2_Wox2{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
@@ -397,12 +345,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_d_global,
-FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2& d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>)
{
@@ -410,9 +356,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
-(void)c_global_buf;
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_d_global, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetElementSpaceSize());
@@ -806,10 +749,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
// Resize_Add
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
+#if 1
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
@@ -827,90 +770,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
true>
d_thread_buf;
-// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
-constexpr auto d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
-DGlobalStepHacks{};
-const index_t k_thread_data_on_global = k_thread_id * KPerThread;
-#if 1
-ThreadwiseTensorSliceTransfer_v2<
-FloatC,
-FloatC,
-decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc),
-decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc),
-Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
-CThreadTransferSrcDstAccessOrder,
-CThreadTransferSrcDstVectorDim,
-CThreadTransferDstScalarPerVector,
-1,
-true>(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-make_multi_index(k_block_work_id,
-k_thread_data_on_global,
-n_block_work_id,
-ho_block_work_id,
-ho_thread_id,
-0,
-wo_block_work_id,
-wo_thread_id,
-0))
-.Run(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-d_global_buf,
-d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc,
-make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
-d_thread_buf,
-d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks);
-#endif
static_for<0, KPerThread, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
d_thread_buf(
Number<d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc.CalculateOffset(
-make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) +=
+make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) =
c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}];
});
});
});
-ThreadwiseTensorSliceTransfer_v1r3<
-FloatC,
-FloatC,
-decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc),
-decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc),
-Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
-CThreadTransferSrcDstAccessOrder,
-CThreadTransferSrcDstVectorDim,
-CThreadTransferDstScalarPerVector,
-CGlobalMemoryDataOperation,
-1,
-true>(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-make_multi_index(k_block_work_id,
-k_thread_data_on_global,
-n_block_work_id,
-ho_block_work_id,
-ho_thread_id,
-0,
-wo_block_work_id,
-wo_thread_id,
-0))
-.Run(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc,
-make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
-d_thread_buf,
-d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-d_global_buf,
-d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks);
-}
-#if 1
-// output: register to global memory
-{
-// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
-// tensor
-constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{};
-constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc =
+#else
+constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
@@ -921,20 +794,58 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
I1,
Number<WoPerThread>{}));
+constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc =
+transform_tensor_descriptor(
+c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_desc,
+make_tuple(make_pass_through_transform(I1),
+make_pass_through_transform(Number<KPerThread>{}),
+make_pass_through_transform(I1),
+make_pass_through_transform(I1),
+make_pass_through_transform(I1),
+make_embed_transform(make_tuple(I2, Number<HoPerThread>{}),
+make_tuple(I0, I1)),
+make_pass_through_transform(I1),
+make_pass_through_transform(I1),
+make_embed_transform(make_tuple(I2, Number<WoPerThread>{}),
+make_tuple(I0, I1))),
+make_tuple(Sequence<0>{},
+Sequence<1>{},
+Sequence<2>{},
+Sequence<3>{},
+Sequence<4>{},
+Sequence<5>{},
+Sequence<6>{},
+Sequence<7>{},
+Sequence<8>{}),
+make_tuple(Sequence<0>{},
+Sequence<1>{},
+Sequence<2>{},
+Sequence<3>{},
+Sequence<4>{},
+Sequence<5, 6>{},
+Sequence<7>{},
+Sequence<8>{},
+Sequence<9, 10>{}));
+#endif
+// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
+constexpr auto d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
+DGlobalStepHacks{};
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<
-FloatAcc,
FloatC,
-decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc),
-decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc),
-Sequence<I1, KPerThread, I1, I1, I1, HoPerThread, I1, I1, WoPerThread>,
+FloatC,
+decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc),
+decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc),
+Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
-CGlobalMemoryDataOperation,
+InMemoryDataOperationEnum_t::Add, // CGlobalMemoryDataOperation,
1,
-true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+true>(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
@@ -944,14 +855,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
wo_block_work_id,
wo_thread_id,
0))
-.Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc,
+.Run(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
-c_thread_buf,
+d_thread_buf,
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-c_global_buf,
+d_global_buf,
-c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks);
+d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks);
}
-#endif
}
};
......
@@ -217,6 +217,22 @@ struct ThreadwiseTensorSliceTransfer_v1r3
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
+else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Add)
+{
+typename vector_type_maker<DstData, DstScalarPerVector>::type tmp;
+tmp.template AsType<dst_vector_t>()(Number<0>{}) =
+dst_buf.template Get<dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid);
+static_for<0, DstScalarPerVector, 1>{}([&](auto t) {
+dst_vector.template AsType<DstData>()(t) += tmp.template AsType<DstData>()[t];
+});
+dst_buf.template Set<dst_vector_t>(
+dst_coord_.GetOffset(),
+is_dst_valid,
+dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
+}
constexpr auto move_on_dim = [&]() constexpr
{
......
@@ -124,7 +124,8 @@ namespace ck {
enum InMemoryDataOperationEnum_t
{
Set,
-AtomicAdd
+AtomicAdd,
+Add
};
enum ActivTypeEnum_t
......
@@ -247,27 +247,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
-// hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
-constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
-make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
-make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
@@ -329,13 +308,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
-Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2
+Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2
1,
CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_e2_global_step_hacks),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks),
-decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
activ_type>;
@@ -346,13 +324,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
GridwiseGemm::MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
-const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
-GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
-using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
@@ -381,7 +356,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
-remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
true>;
@@ -393,11 +367,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid,
p_b_grid,
p_d_grid,
-p_c_grid,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor);
}
else
@@ -409,7 +381,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
-remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
false>;
@@ -421,11 +392,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid,
p_b_grid,
p_d_grid,
-p_c_grid,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor);
}
@@ -435,8 +404,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2));
DeviceMem d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf(
sizeof(DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2));
-DeviceMem c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf(
-sizeof(CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2));
DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_H_W));
@@ -445,8 +412,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
&b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.ToDevice(
&d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.ToDevice(
-&c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_h_w_block_cluster_adaptor);
@@ -460,7 +425,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
-remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
true>;
@@ -473,15 +437,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid,
p_b_grid,
p_d_grid,
-p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()),
-cast_pointer_to_constant_address_space(
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
@@ -495,7 +456,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
-remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
false>;
@@ -508,15 +468,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
p_a_grid,
p_b_grid,
p_d_grid,
-p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()),
-cast_pointer_to_constant_address_space(
-c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
......
@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
constexpr index_t activ_type = 0;
-#if 1
+#if 0
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<1080>{};
constexpr auto Wi = Number<1920>{};
@@ -135,7 +135,7 @@ int main(int argc, char* argv[])
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
-#elif 0
+#elif 1
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<32>{};
constexpr auto Wi = Number<32>{};
@@ -235,27 +235,22 @@ int main(int argc, char* argv[])
break;
case 1:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-add.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 2:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-add.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 3:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-add.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 4:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-add.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 5:
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
-add.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
break;
default:
@@ -267,6 +262,8 @@ int main(int argc, char* argv[])
wei.GenerateTensorValue(gen_wei, num_thread);
}
+add.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
auto f_make_for_device_nchwc = [&]() {
const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1);
const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1);
@@ -326,15 +323,15 @@ int main(int argc, char* argv[])
make_tuple(in_right_pad_h, in_right_pad_w),
activ_type);
-check_error(out_host, out_device);
check_error(add_host, add_device);
if(do_log)
{
-LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
-LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
-LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
-LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
+// LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
+// LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
+// LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
+// LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") <<
+// std::endl;
LogRangeAsType<float>(std::cout << "add_host: ", add_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "add_device: ", add_device.mData, ",") << std::endl;
}
......