Commit 90276e6b authored by Jing Zhang
Browse files

add e1 with bugs

parent 10fdada7
...@@ -109,6 +109,7 @@ template <index_t BlockSize, ...@@ -109,6 +109,7 @@ template <index_t BlockSize,
typename AGlobalDesc_E0_E1_K, typename AGlobalDesc_E0_E1_K,
typename BGlobalDesc_E0_E1_N_Ho_Wo, typename BGlobalDesc_E0_E1_N_Ho_Wo,
typename CGlobalDesc_K_N_Ho_Wo, typename CGlobalDesc_K_N_Ho_Wo,
index_t E1,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
index_t WoPerBlock, index_t WoPerBlock,
...@@ -139,7 +140,11 @@ template <index_t BlockSize, ...@@ -139,7 +140,11 @@ template <index_t BlockSize,
typename BGlobalMoveSliceWindowStepHacks> typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3 struct GridwiseGemmDlops_km_kn_mn_v3
{ {
static constexpr auto E = EPerBlock; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -148,12 +153,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -148,12 +153,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e0_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_e1_k_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_e0_e1_k_block_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB); return a_block_space_size * sizeof(FloatAB);
} }
...@@ -169,11 +174,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -169,11 +174,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k_global_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...@@ -220,15 +220,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -220,15 +220,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e0_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<I1>{}, Number<E1>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e2_k_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align); make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_e2_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto b_e1_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{})); Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
...@@ -240,8 +240,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -240,8 +240,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_e2_k_block_desc), decltype(a_e1_k_block_desc),
decltype(b_e2_n_ho_wo_block_desc), decltype(b_e1_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
EPerThread, EPerThread,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
...@@ -266,47 +266,47 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -266,47 +266,47 @@ struct GridwiseGemmDlops_km_kn_mn_v3
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>, Sequence<I1, E1, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_e0_e1_k_global_desc), decltype(a_e0_e1_k_global_desc),
decltype(a_e1_k_block_desc), decltype(a_e0_e1_k_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1>, Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
1, 2,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K, ABlockTransferDstScalarPerVector_K,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e0_e1_k_global_desc, true>(a_e0_e1_k_global_desc,
make_multi_index(0, k_block_data_on_global), make_multi_index(0, 0, k_block_data_on_global),
a_e1_k_block_desc, a_e0_e1_k_block_desc,
make_multi_index(0, 0)); make_multi_index(0, 0, 0));
constexpr auto b_e2_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( constexpr auto b_e0_e1_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); I1, Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB, ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB, FloatAB,
decltype(b_e0_e1_n_ho_wo_global_desc), decltype(b_e0_e1_n_ho_wo_global_desc),
decltype(b_e2_n_ho_wo_thread_desc), decltype(b_e0_e1_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>, Sequence<I1, EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
1, 1,
true>( true>(
b_e0_e1_n_ho_wo_global_desc, b_e0_e1_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); make_multi_index(0, 0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e1_k_block_desc.GetElementSpaceSize()); p_shared_block, a_e0_e1_k_block_desc.GetElementSpaceSize());
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
...@@ -321,7 +321,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -321,7 +321,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{} Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(0, EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e0_e1_k_global_step_hacks = AGlobalStepHacks{}; constexpr auto a_e0_e1_k_global_step_hacks = AGlobalStepHacks{};
...@@ -330,7 +330,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -330,7 +330,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB, FloatAB,
b_e2_n_ho_wo_thread_desc.GetElementSpaceSize(), b_e0_e1_n_ho_wo_thread_desc.GetElementSpaceSize(),
true> true>
b_thread_even_buf, b_thread_odd_buf; b_thread_even_buf, b_thread_odd_buf;
...@@ -341,12 +341,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -341,12 +341,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e2_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e1_k_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_e0_e1_k_block_desc, a_block_buf);
} }
__syncthreads(); __syncthreads();
...@@ -365,8 +365,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -365,8 +365,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e2_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
...@@ -381,8 +381,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -381,8 +381,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e2_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_even_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
...@@ -393,7 +393,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -393,7 +393,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
e_block_data_begin += 2 * EPerBlock; e_block_data_begin += 2 * EPerBlock;
} while(e_block_data_begin < E - 2 * EPerBlock); } while(e_block_data_begin < E1 - 2 * EPerBlock);
} }
// LDS double buffer: tail // LDS double buffer: tail
...@@ -404,8 +404,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -404,8 +404,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf, b_global_buf,
b_e2_n_ho_wo_thread_desc, b_e0_e1_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_odd_buf,
b_e0_e1_n_ho_wo_global_step_hacks); b_e0_e1_n_ho_wo_global_step_hacks);
......
...@@ -104,6 +104,8 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -104,6 +104,8 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 8; constexpr index_t WoPerBlock = 8;
constexpr index_t E1 = 16;
constexpr index_t EPerBlock = 16; constexpr index_t EPerBlock = 16;
constexpr index_t KPerThread = KPerBlock; constexpr index_t KPerThread = KPerBlock;
...@@ -111,15 +113,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -111,15 +113,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 1; constexpr index_t WoPerThread = 1;
constexpr index_t EPerThread = EPerBlock; constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<4, 1>; using ABlockTransferThreadSliceLengths_E_K = Sequence<1, 4, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<4, 16>; using ABlockTransferThreadClusterLengths_E_K = Sequence<1, 4, 16>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 4; constexpr index_t ABlockTransferSrcScalarPerVector_E = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_E = 4; constexpr index_t BThreadTransferSrcScalarPerVector_E = 1;
constexpr index_t CThreadTransferDstScalarPerVector_K = 4; constexpr index_t CThreadTransferDstScalarPerVector_K = 1;
#endif #endif
constexpr auto conv_driver = constexpr auto conv_driver =
...@@ -128,6 +130,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( ...@@ -128,6 +130,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
E1,
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
......
...@@ -10,6 +10,7 @@ template <ck::index_t BlockSize, ...@@ -10,6 +10,7 @@ template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
ck::index_t E1,
ck::index_t KPerBlock, ck::index_t KPerBlock,
ck::index_t HoPerBlock, ck::index_t HoPerBlock,
ck::index_t WoPerBlock, ck::index_t WoPerBlock,
...@@ -92,6 +93,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -92,6 +93,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
<< std::endl; << std::endl;
const auto E = C0 * Y * X * C1;
const auto E0 = E / E1;
// weight tensor // weight tensor
const auto a_e_k_grid_desc = transform_tensor_descriptor( const auto a_e_k_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X * C1)), make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X * C1)),
...@@ -100,6 +104,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -100,6 +104,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto a_e0_e1_k_grid_desc = transform_tensor_descriptor(
a_e_k_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}));
// input tensor // input tensor
const auto in_n_c0_hip_wip_c1_global_desc = transform_tensor_descriptor( const auto in_n_c0_hip_wip_c1_global_desc = transform_tensor_descriptor(
in_n_c0_hi_wi_c1_global_desc, in_n_c0_hi_wi_c1_global_desc,
...@@ -132,6 +142,15 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -132,6 +142,15 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<1, 2, 4, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), make_tuple(Sequence<1, 2, 4, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto b_e0_e1_n_ho_wo_grid_desc = transform_tensor_descriptor(
b_e_n_ho_wo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
make_pass_through_transform(N),
make_pass_through_transform(Hop),
make_pass_through_transform(Wop)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// output tensor // output tensor
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor( const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
...@@ -142,8 +161,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -142,8 +161,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C0 * Y * X * C1;
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
...@@ -153,24 +170,28 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -153,24 +170,28 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
} }
// hack to control index calculation when iterating over a_k_m_global tensor // hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e_k_global_step_hacks = constexpr auto a_e0_e1_k_global_step_hacks = make_tuple(
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; constexpr auto a_e0_e1_k_global_move_slice_window_step_hack = Sequence<0, 0, 0, 0, 0>{};
constexpr auto b_e_n_ho_wo_global_step_hacks = constexpr auto b_e0_e1_n_ho_wo_global_step_hacks = make_tuple(
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = constexpr auto b_e0_e1_n_ho_wo_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format // hack for NKHW format
...@@ -191,9 +212,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -191,9 +212,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAcc, FloatAcc,
FloatC, FloatC,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
decltype(a_e_k_grid_desc), decltype(a_e0_e1_k_grid_desc),
decltype(b_e_n_ho_wo_grid_desc), decltype(b_e0_e1_n_ho_wo_grid_desc),
decltype(c_k_n_hop_wop_grid_desc), decltype(c_k_n_hop_wop_grid_desc),
E1,
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
...@@ -204,42 +226,42 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -204,42 +226,42 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
EPerThread, EPerThread,
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferThreadClusterLengths_E_K,
Sequence<1, 0>, Sequence<2, 0, 1>,
Sequence<1, 0>, Sequence<2, 0, 1>,
0, 1,
ABlockTransferSrcScalarPerVector_E, ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K, ABlockTransferDstScalarPerVector_K,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 1>, Sequence<0, 2, 3, 4, 1>,
0, 1,
BThreadTransferSrcScalarPerVector_E, BThreadTransferSrcScalarPerVector_E,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>, Sequence<0, 2, 3, 1>,
0, 0,
CThreadTransferDstScalarPerVector_K, CThreadTransferDstScalarPerVector_K,
decltype(a_e_k_global_step_hacks), decltype(a_e0_e1_k_global_step_hacks),
decltype(b_e_n_ho_wo_global_step_hacks), decltype(b_e0_e1_n_ho_wo_global_step_hacks),
decltype(c_k_n_ho_wo_global_tensor_step_hacks), decltype(c_k_n_ho_wo_global_tensor_step_hacks),
decltype(a_e_k_global_move_slice_window_step_hack), decltype(a_e0_e1_k_global_move_slice_window_step_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; decltype(b_e0_e1_n_ho_wo_global_move_slice_window_step_hack)>;
using AEKGridDesc = decltype(a_e_k_grid_desc); using AGridDesc_E0_E1_K = decltype(a_e0_e1_k_grid_desc);
using BENHoWoGridDesc = decltype(b_e_n_ho_wo_grid_desc); using BGridDesc_E0_E1_N_Ho_Wo = decltype(b_e0_e1_n_ho_wo_grid_desc);
using CKNHopWopGridDesc = decltype(c_k_n_hop_wop_grid_desc); using CGridDesc_K_N_Ho_Wo = decltype(c_k_n_hop_wop_grid_desc);
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; const bool has_main_k_block_loop = (E1 + EPerBlock) / (2 * EPerBlock) > 1;
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (E1 / EPerBlock) % 2 == 0;
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(I0, I0))), make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(I0, I0))),
make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
using CBlockIdToKNHopWopBlockClusterAdaptor = using CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo =
decltype(c_blockid_to_k_n_ho_wo_block_cluster_adaptor); decltype(c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
...@@ -251,10 +273,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -251,10 +273,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
true>; true>;
...@@ -266,8 +288,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -266,8 +288,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e_k_grid_desc, a_e0_e1_k_grid_desc,
b_e_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -277,10 +299,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -277,10 +299,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
false>; false>;
...@@ -292,8 +314,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -292,8 +314,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e_k_grid_desc, a_e0_e1_k_grid_desc,
b_e_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -303,10 +325,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -303,10 +325,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
true>; true>;
...@@ -318,8 +340,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -318,8 +340,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e_k_grid_desc, a_e0_e1_k_grid_desc,
b_e_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
...@@ -329,10 +351,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -329,10 +351,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
false>; false>;
...@@ -344,22 +366,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -344,22 +366,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_e_k_grid_desc, a_e0_e1_k_grid_desc,
b_e_n_ho_wo_grid_desc, b_e0_e1_n_ho_wo_grid_desc,
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
return ave_time; return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e_k_grid_desc_dev_buf(sizeof(AEKGridDesc)); DeviceMem a_e0_e1_k_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K));
DeviceMem b_e_n_ho_wo_grid_desc_dev_buf(sizeof(BENHoWoGridDesc)); DeviceMem b_e0_e1_n_ho_wo_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo));
DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CKNHopWopGridDesc)); DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo));
DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf( DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToKNHopWopBlockClusterAdaptor)); sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo));
a_e_k_grid_desc_dev_buf.ToDevice(&a_e_k_grid_desc); a_e0_e1_k_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_grid_desc);
b_e_n_ho_wo_grid_desc_dev_buf.ToDevice(&b_e_n_ho_wo_grid_desc); b_e0_e1_n_ho_wo_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_grid_desc);
c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc); c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice( c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor); &c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
...@@ -372,10 +394,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -372,10 +394,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
true>; true>;
...@@ -388,9 +410,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -388,9 +410,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_e_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -402,10 +425,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -402,10 +425,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
true, true,
false>; false>;
...@@ -418,9 +441,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -418,9 +441,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_e_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -432,10 +456,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -432,10 +456,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
true>; true>;
...@@ -448,9 +472,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -448,9 +472,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_e_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
...@@ -462,10 +487,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -462,10 +487,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AEKGridDesc>, remove_reference_t<AGridDesc_E0_E1_K>,
remove_reference_t<BENHoWoGridDesc>, remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo>,
remove_reference_t<CKNHopWopGridDesc>, remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToKNHopWopBlockClusterAdaptor>, remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
false, false,
false>; false>;
...@@ -478,9 +503,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -478,9 +503,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_e_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
b_e_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()), a_e0_e1_k_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()), c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment