Unverified Commit dc1e9c5d authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Support large tensors in grouped conv fwd (#1332)

* Support large tensors in grouped conv fwd

* Multi ABD fixes

* Fix calculate element space size
parent 37a347e3
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl ...@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
independent_filter_strides, independent_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads_with_offset, input_left_pads_with_offset,
input_right_pads); input_right_pads,
N);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -93,12 +93,9 @@ __global__ void ...@@ -93,12 +93,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
...@@ -54,12 +54,9 @@ __global__ void ...@@ -54,12 +54,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
......
...@@ -66,12 +66,9 @@ __global__ void ...@@ -66,12 +66,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
......
...@@ -59,12 +59,9 @@ __global__ void ...@@ -59,12 +59,9 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -116,12 +113,9 @@ __global__ void ...@@ -116,12 +113,9 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
// Pass two lds pointer is the key to tell compiler that ds_read/write // Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy // operate on different lds chunk at same time without order dependecy
...@@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.Conv_G_; arg.Conv_G_;
std::array<index_t, I1> in_out_batch_strides = { std::array<index_t, I1> in_out_batch_strides = {
arg.compute_ptr_offset_of_batch_.BatchStrideC_}; static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwise, const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
ck::Tuple<CElementwiseGridDesc_M_N>, ck::Tuple<CElementwiseGridDesc_M_N>,
......
...@@ -61,12 +61,9 @@ __global__ void ...@@ -61,12 +61,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -97,12 +97,9 @@ __global__ void ...@@ -97,12 +97,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...@@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_strides); e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
c_g_n_k_wos_strides); c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
...@@ -60,7 +60,7 @@ template <typename GridwiseGemm, ...@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1, typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1, typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch, typename ComputePtrOffset,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1, index_t MinimumOccupancy = 1,
...@@ -69,26 +69,28 @@ __global__ void ...@@ -69,26 +69,28 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif #endif
kernel_grouped_conv_fwd_xdl_cshuffle_v3( kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffset compute_ptr_offset_of_groups,
const index_t batch_count) const ComputePtrOffset compute_ptr_offset_of_n,
const index_t groups_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count); const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -97,9 +99,9 @@ __global__ void ...@@ -97,9 +99,9 @@ __global__ void
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop, HasMainKBlockLoop,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset, TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset, karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared, p_shared,
karg, karg,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
...@@ -114,7 +116,7 @@ template <typename GridwiseGemm, ...@@ -114,7 +116,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1, typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1, typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch, typename ComputePtrOffset,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1, index_t MinimumOccupancy = 1,
...@@ -129,20 +131,23 @@ __global__ void ...@@ -129,20 +131,23 @@ __global__ void
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffset compute_ptr_offset_of_groups,
const index_t batch_count) const ComputePtrOffset compute_ptr_offset_of_n,
const index_t groups_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count); const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
// Pass two lds pointer is the key to tell compiler that ds_read/write // Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy // operate on different lds chunk at same time without order dependecy
...@@ -154,9 +159,9 @@ __global__ void ...@@ -154,9 +159,9 @@ __global__ void
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop, HasMainKBlockLoop,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset, TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset, karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared_0, p_shared_0,
p_shared_1, p_shared_1,
karg, karg,
...@@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
...@@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
Conv_N);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename ELay> template <typename ELay>
static auto static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_strides); e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
} }
// desc for problem definition // desc for problem definition
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
#define GridwiseGemmV3TemplateParams \ #define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
...@@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>; remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
...@@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{}, p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths, a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
...@@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads,
conv_N_per_block_)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
e_g_n_k_wos_strides)}, e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{}, compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
...@@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
// A/B/E Batch Stride // A/B/E Batch/N Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
// p_as and p_bs are pointers // p_as and p_bs are pointers
p_a_grid_ = static_cast<const ADataType*>(p_as); p_a_grid_ = static_cast<const ADataType*>(p_as);
p_b_grid_ = static_cast<const BDataType*>(p_bs); p_b_grid_ = static_cast<const BDataType*>(p_bs);
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
e_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
...@@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
index_t conv_N_per_block_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const index_t GemmK = const index_t GemmK =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
gdy *= arg.num_group_; gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
...@@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_, arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_,
arg.num_group_); arg.num_group_);
} }
else else
...@@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_, arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_,
arg.num_group_); arg.num_group_);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_strides); e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_strides); e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
...@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor> ...@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor, struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor, NumBTensor,
NumDTensor, NumDTensor,
ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>> enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
{ {
ComputePtrOffsetOfStridedBatch() = default; ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs, ComputePtrOffsetOfStridedBatch(Array<long_index_t, NumATensor>& BatchStrideAs,
Array<ck::index_t, NumBTensor>& BatchStrideBs, Array<long_index_t, NumBTensor>& BatchStrideBs,
Array<ck::index_t, NumDTensor>& BatchStrideDs, Array<long_index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE) long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideAs), : BatchStrideA_(BatchStrideAs),
BatchStrideB_(BatchStrideBs), BatchStrideB_(BatchStrideBs),
BatchStrideDs_(BatchStrideDs), BatchStrideDs_(BatchStrideDs),
...@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{ {
Array<long_index_t, NumATensor> as_offset; Array<long_index_t, NumATensor> as_offset;
static_for<0, NumATensor, 1>{}( static_for<0, NumATensor, 1>{}(
[&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); }); [&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
return as_offset; return as_offset;
} }
...@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{ {
Array<long_index_t, NumBTensor> bs_offset; Array<long_index_t, NumBTensor> bs_offset;
static_for<0, NumBTensor, 1>{}( static_for<0, NumBTensor, 1>{}(
[&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); }); [&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
return bs_offset; return bs_offset;
} }
...@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{ {
Array<long_index_t, NumDTensor> ds_offset; Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); }); [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset; return ds_offset;
} }
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
// alias for kernels without multiple D // alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
Array<ck::index_t, NumATensor> BatchStrideA_; Array<long_index_t, NumATensor> BatchStrideA_;
Array<ck::index_t, NumBTensor> BatchStrideB_; Array<long_index_t, NumBTensor> BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_; Array<long_index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_; long_index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
}; };
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor> template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor, struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor, NumBTensor,
NumDTensor, NumDTensor,
ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>> enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
{ {
ComputePtrOffsetOfStridedBatch() = default; ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA,
index_t BatchStrideB, long_index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs, Array<long_index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE) long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideA), : BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB), BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs), BatchStrideDs_(BatchStrideDs),
...@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); return static_cast<long_index_t>(g_idx) * BatchStrideA_;
} }
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); return static_cast<long_index_t>(g_idx) * BatchStrideB_;
} }
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{ {
Array<long_index_t, NumDTensor> ds_offset; Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); }); [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset; return ds_offset;
} }
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
// alias for kernels without multiple D // alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
ck::index_t BatchStrideA_; long_index_t BatchStrideA_;
ck::index_t BatchStrideB_; long_index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_; Array<long_index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_; long_index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
}; };
template <bool isTuple, typename Tensors> template <bool isTuple, typename Tensors>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl ...@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
N);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -60,12 +60,9 @@ __global__ void ...@@ -60,12 +60,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...@@ -155,12 +152,9 @@ __global__ void ...@@ -155,12 +152,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -20,6 +20,71 @@ struct TransformConvFwdToGemm ...@@ -20,6 +20,71 @@ struct TransformConvFwdToGemm
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static long_index_t
calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
const std::array<index_t, NDimSpatial + 3>& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
{
acc +=
static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
}
return acc;
}
template <typename ADataType, typename CDataType>
static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const index_t N = a_g_n_c_wis_lengths[I1];
if(element_space_size > TwoGB)
{
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
if(divisor <= static_cast<double>(N))
{
// Find least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N). There are no divisors above this value.
for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
{
return N / least_divisor;
}
}
// Not found, process one Convolution N per block
return 1;
}
else
{
// Not possible to support even after split N.
// Too large tensor.
return N;
}
}
else
{
// Split N is not needed.
return N;
}
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties // properties
template <typename ALayout, template <typename ALayout,
...@@ -38,9 +103,9 @@ struct TransformConvFwdToGemm ...@@ -38,9 +103,9 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
{ {
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3]; const index_t Wi = a_g_n_c_wis_lengths[3];
...@@ -151,9 +216,10 @@ struct TransformConvFwdToGemm ...@@ -151,9 +216,10 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
{ {
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3]; const index_t Hi = a_g_n_c_wis_lengths[3];
...@@ -276,13 +342,14 @@ struct TransformConvFwdToGemm ...@@ -276,13 +342,14 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */, const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */, const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
{ {
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3]; const index_t Di = a_g_n_c_wis_lengths[3];
...@@ -478,9 +545,9 @@ struct TransformConvFwdToGemm ...@@ -478,9 +545,9 @@ struct TransformConvFwdToGemm
bool>::type = false> bool>::type = false>
static auto static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */) const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const index_t N)
{ {
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo = const index_t NHoWo =
...@@ -502,9 +569,9 @@ struct TransformConvFwdToGemm ...@@ -502,9 +569,9 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>, is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false> bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{ {
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const auto KStride = I1; const auto KStride = I1;
...@@ -525,9 +592,9 @@ struct TransformConvFwdToGemm ...@@ -525,9 +592,9 @@ struct TransformConvFwdToGemm
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>, typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false> bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{ {
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2]; const index_t KStride = c_g_n_k_wos_strides[2];
......
...@@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK> ...@@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>, std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>; std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>;
using KernelTypes2dLargeCases = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>>;
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple> class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple>
{ {
...@@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple> ...@@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple>
{ {
}; };
template <typename Tuple>
class TestGroupedConvndFwd2dLargeCases : public TestGroupedConvndFwd<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d);
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST_SUITE(TestGroupedConvndFwd2dLargeCases, KernelTypes2dLargeCases);
TYPED_TEST(TestGroupedConvndFwd1d, Test1D) TYPED_TEST(TestGroupedConvndFwd1d, Test1D)
{ {
...@@ -131,3 +139,11 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D) ...@@ -131,3 +139,11 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>(); this->template Run<3>();
} }
TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
{
// Case larger than 2GB
this->conv_params.push_back(
{2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {0, 0}, {0, 0}, {0, 0}});
this->template Run<2>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment