Unverified commit dc1e9c5d authored by Bartłomiej Kocot, committed by GitHub

Support large tensors in grouped conv fwd (#1332)

* Support large tensors in grouped conv fwd

* Multi ABD fixes

* Fix calculate element space size
parent 37a347e3
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
independent_filter_strides,
conv_filter_dilations,
input_left_pads_with_offset,
-input_right_pads);
+input_right_pads,
+N);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -93,12 +93,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
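The dropped __builtin_amdgcn_readfirstlane wrapper is the crux of this recurring hunk: the intrinsic broadcasts a 32-bit value from the first active lane, so funnelling a long_index_t offset through it silently truncates the upper 32 bits once a tensor crosses the 2^31 boundary. A minimal host-side sketch of the failure mode (fake_readfirstlane is a hypothetical stand-in modelling the 32-bit round trip; the stride value is illustrative):

    #include <cstdint>
    #include <cstdio>

    using long_index_t = int64_t;

    // Models __builtin_amdgcn_readfirstlane, which takes and returns a 32-bit int.
    static int32_t fake_readfirstlane(int32_t v) { return v; }

    int main()
    {
        const long_index_t batch_stride = 3'000'000'000LL; // > 2^31: a "large tensor" stride
        const int32_t g_idx = 1;

        // Old pattern: the 64-bit offset is squeezed through the 32-bit intrinsic.
        const long_index_t truncated = static_cast<long_index_t>(
            fake_readfirstlane(static_cast<int32_t>(g_idx * batch_stride)));

        // New pattern: the whole computation stays in 64 bits.
        const long_index_t exact = static_cast<long_index_t>(g_idx) * batch_stride;

        std::printf("truncated = %lld, exact = %lld\n",
                    static_cast<long long>(truncated),
                    static_cast<long long>(exact)); // truncated = -1294967296, exact = 3000000000
    }

The same replacement is applied verbatim to every batched kernel below; the trade-off is giving up the scalar-register broadcast in exchange for correct 64-bit offsets.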
......@@ -54,12 +54,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
......
......@@ -66,12 +66,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
......
......@@ -59,12 +59,9 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -116,12 +113,9 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
// Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
// can operate on different LDS chunks at the same time without an ordering dependency
......@@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.Conv_G_;
std::array<index_t, I1> in_out_batch_strides = {
-arg.compute_ptr_offset_of_batch_.BatchStrideC_};
+static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
ck::Tuple<CElementwiseGridDesc_M_N>,
......
......@@ -61,12 +61,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -97,12 +97,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......@@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads);
+input_right_pads,
+a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
-conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides);
+conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads);
+input_right_pads,
+a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
-conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
-c_g_n_k_wos_strides);
+conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
+c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
......@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-typename ComputePtrOffsetOfBatch,
+typename ComputePtrOffset,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
......@@ -69,26 +69,28 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
-kernel_grouped_conv_fwd_xdl_cshuffle_v3(
-typename GridwiseGemm::Argument karg,
+kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const index_t batch_count)
+const ComputePtrOffset compute_ptr_offset_of_groups,
+const ComputePtrOffset compute_ptr_offset_of_n,
+const index_t groups_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
-const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -97,9 +99,9 @@ __global__ void
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
-TailNum>(karg.p_a_grid + a_batch_offset,
+TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
-karg.p_c_grid + e_batch_offset,
+karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
......@@ -114,7 +116,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-typename ComputePtrOffsetOfBatch,
+typename ComputePtrOffset,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
......@@ -129,20 +131,23 @@ __global__ void
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
-const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const index_t batch_count)
+const ComputePtrOffset compute_ptr_offset_of_groups,
+const ComputePtrOffset compute_ptr_offset_of_n,
+const index_t groups_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
-const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
// Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
// can operate on different LDS chunks at the same time without an ordering dependency
......@@ -154,9 +159,9 @@ __global__ void
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
-TailNum>(karg.p_a_grid + a_batch_offset,
+TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
-karg.p_c_grid + e_batch_offset,
+karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared_0,
p_shared_1,
karg,
......@@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
-const std::array<index_t, NDimSpatial>& input_right_pads)
+const std::array<index_t, NDimSpatial>& input_right_pads,
+const index_t Conv_N)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
......@@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads);
+input_right_pads,
+Conv_N);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+const index_t Conv_N)
{
const auto out_gemmmraw_gemmnraw_desc =
-conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides);
+conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// desc for problem definition
-using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
#define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
......@@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
......@@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
+conv_N_per_block_{
+conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
+a_g_n_c_wis_lengths,
+a_g_n_c_wis_strides,
+e_g_n_k_wos_lengths,
+e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
......@@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads)},
+input_right_pads,
+conv_N_per_block_)},
b_grid_desc_bk0_n_bk1_{
MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
-e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides)},
+e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
+e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-compute_ptr_offset_of_batch_{},
+compute_ptr_offset_of_groups_{},
+compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
-// A/B/E Batch Stride
-compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+// A/B/E Batch/N Stride
+compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
// p_as and p_bs are pointers
p_a_grid_ = static_cast<const ADataType*>(p_as);
p_b_grid_ = static_cast<const BDataType*>(p_bs);
-compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
......@@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definition
index_t num_group_;
+index_t conv_N_per_block_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// for computing batch offset
-ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
+ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
+ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOperation a_element_op_;
......@@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const index_t GemmK =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+const index_t num_workgroups_per_Conv_N =
+arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
-gdy *= arg.num_group_;
+gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
......@@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-arg.compute_ptr_offset_of_batch_,
+arg.compute_ptr_offset_of_groups_,
+arg.compute_ptr_offset_of_n_,
arg.num_group_);
}
else
......@@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-arg.compute_ptr_offset_of_batch_,
+arg.compute_ptr_offset_of_groups_,
+arg.compute_ptr_offset_of_n_,
arg.num_group_);
}
};
......
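On the launch side, the y-dimension of the grid is scaled by both the group count and the number of N splits, and each workgroup later recovers its group index and N-chunk index from blockIdx.y. A hypothetical host-side model of one such decomposition (the kernel derives these via num_blocks_per_batch and num_blocks_per_n; names and ordering here are illustrative only, not the exact layout):

    #include <cstdio>

    int main()
    {
        const int blocks_per_gemm            = 4; // tiles covering one (group, N-chunk) sub-GEMM
        const int num_group                  = 2; // grouped-convolution groups
        const int num_workgroups_per_Conv_N  = 3; // N / conv_N_per_block
        const int gdy = blocks_per_gemm * num_group * num_workgroups_per_Conv_N;

        for(int block_y = 0; block_y < gdy; ++block_y)
        {
            const int gemm_idx = block_y / blocks_per_gemm; // which sub-GEMM this block works on
            const int g_idx    = gemm_idx % num_group;      // conv group -> group-stride offset
            const int n_idx    = gemm_idx / num_group;      // N chunk    -> N-stride offset
            std::printf("block %2d -> g_idx %d, n_idx %d\n", block_y, g_idx, n_idx);
        }
    }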
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads);
+input_right_pads,
+a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
-conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides);
+conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads);
+input_right_pads,
+a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
-conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides);
+conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
......@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor,
NumDTensor,
-ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
+enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
{
ComputePtrOffsetOfStridedBatch() = default;
-ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
-Array<ck::index_t, NumBTensor>& BatchStrideBs,
-Array<ck::index_t, NumDTensor>& BatchStrideDs,
-index_t BatchStrideE)
+ComputePtrOffsetOfStridedBatch(Array<long_index_t, NumATensor>& BatchStrideAs,
+Array<long_index_t, NumBTensor>& BatchStrideBs,
+Array<long_index_t, NumDTensor>& BatchStrideDs,
+long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideAs),
BatchStrideB_(BatchStrideBs),
BatchStrideDs_(BatchStrideDs),
......@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{
Array<long_index_t, NumATensor> as_offset;
static_for<0, NumATensor, 1>{}(
-[&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
+[&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
return as_offset;
}
......@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{
Array<long_index_t, NumBTensor> bs_offset;
static_for<0, NumBTensor, 1>{}(
-[&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
+[&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
return bs_offset;
}
......@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
-[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+[&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset;
}
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
-return g_idx * static_cast<long_index_t>(BatchStrideE_);
+return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
-return g_idx * static_cast<long_index_t>(BatchStrideE_);
+return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
-Array<ck::index_t, NumATensor> BatchStrideA_;
-Array<ck::index_t, NumBTensor> BatchStrideB_;
-Array<ck::index_t, NumDTensor> BatchStrideDs_;
-index_t BatchStrideE_;
-index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+Array<long_index_t, NumATensor> BatchStrideA_;
+Array<long_index_t, NumBTensor> BatchStrideB_;
+Array<long_index_t, NumDTensor> BatchStrideDs_;
+long_index_t BatchStrideE_;
+long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor,
NumDTensor,
-ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
+enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
{
ComputePtrOffsetOfStridedBatch() = default;
-ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
-index_t BatchStrideB,
-Array<ck::index_t, NumDTensor> BatchStrideDs,
-index_t BatchStrideE)
+ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA,
+long_index_t BatchStrideB,
+Array<long_index_t, NumDTensor> BatchStrideDs,
+long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
......@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
-return g_idx * static_cast<long_index_t>(BatchStrideA_);
+return static_cast<long_index_t>(g_idx) * BatchStrideA_;
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
-return g_idx * static_cast<long_index_t>(BatchStrideB_);
+return static_cast<long_index_t>(g_idx) * BatchStrideB_;
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
-[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+[&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset;
}
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
-return g_idx * static_cast<long_index_t>(BatchStrideE_);
+return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
-return g_idx * static_cast<long_index_t>(BatchStrideE_);
+return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
-ck::index_t BatchStrideA_;
-ck::index_t BatchStrideB_;
-Array<ck::index_t, NumDTensor> BatchStrideDs_;
-index_t BatchStrideE_;
-index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+long_index_t BatchStrideA_;
+long_index_t BatchStrideB_;
+Array<long_index_t, NumDTensor> BatchStrideDs_;
+long_index_t BatchStrideE_;
+long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
template <bool isTuple, typename Tensors>
......
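Two related fixes coexist in this file: the stride members widen from index_t to long_index_t, so a per-group stride of a >2GB tensor can be represented at all, and the offset math promotes g_idx before the multiply so the product is formed in 64 bits. A condensed sketch of the fixed helper, assuming index_t is int32_t and long_index_t is int64_t (the struct name is simplified to a single tensor):

    #include <cstdint>
    #include <cstdio>

    using index_t      = int32_t;
    using long_index_t = int64_t;

    // Simplified, single-tensor version of the updated ComputePtrOffsetOfStridedBatch.
    struct ComputePtrOffsetSketch
    {
        long_index_t BatchStrideA_; // was index_t: could not even hold a stride >= 2^31

        long_index_t GetAPtrOffset(index_t g_idx) const
        {
            // Promote the 32-bit index first so the multiply happens in 64 bits.
            return static_cast<long_index_t>(g_idx) * BatchStrideA_;
        }
    };

    int main()
    {
        // A stride that no longer fits in index_t, e.g. one group of a >2GB tensor.
        ComputePtrOffsetSketch off{long_index_t{3} * 1024 * 1024 * 1024};
        std::printf("%lld\n", static_cast<long long>(off.GetAPtrOffset(5))); // 16106127360
    }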
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads);
+input_right_pads,
+N);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -60,12 +60,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......@@ -155,12 +152,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -20,6 +20,71 @@ struct TransformConvFwdToGemm
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
+static long_index_t
+calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
+const std::array<index_t, NDimSpatial + 3>& strides,
+index_t i)
+{
+long_index_t acc = 1;
+for(; i < (NDimSpatial + 3); i++)
+{
+acc +=
+static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
+}
+return acc;
+}
+template <typename ADataType, typename CDataType>
+static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
+const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
+const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+{
+const long_index_t a_element_space_size =
+calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
+const long_index_t c_element_space_size =
+calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
+const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
+c_element_space_size * sizeof(CDataType));
+constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+const index_t N = a_g_n_c_wis_lengths[I1];
+if(element_space_size > TwoGB)
+{
+// Smallest factor N must be divided by so that a split does not exceed 2GB
+const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
+if(divisor <= static_cast<double>(N))
+{
+// Find the least divisor of N that is not smaller than element_space_size / TwoGB.
+// Only candidates up to sqrt(N) are checked; if none of them divides N,
+// fall through to a full split.
+for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
+least_divisor++)
+{
+if(N % least_divisor == 0)
+{
+return N / least_divisor;
+}
+}
+// No divisor found in range: process one convolution image per split
+return 1;
+}
+else
+{
+// Even one image per split would exceed 2GB; the tensor is too large to support.
+return N;
+}
+}
+else
+{
+// Splitting N is not needed.
+return N;
+}
+}
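In short, GetSplitedNSize measures the larger of the per-group input and output footprints as (1 + Σ_i (len_i − 1) · stride_i) elements (summed over every dimension except G) times the element size, and when that exceeds 2 GB it looks for the smallest divisor of N that brings each chunk back under the limit, scanning candidates only up to sqrt(N) and otherwise falling back to one image per split. A standalone sketch of that search (integer comparison used where the original compares against a double; element_space_size is taken here as a precomputed byte count):

    #include <cstdint>

    using index_t      = int32_t;
    using long_index_t = int64_t;

    // Returns how many images (out of N) one kernel launch may cover so that the
    // addressed span stays under 2GB; mirrors GetSplitedNSize above.
    index_t split_n_size(long_index_t element_space_size, index_t N)
    {
        constexpr long_index_t TwoGB = long_index_t{1} << 31;
        if(element_space_size <= TwoGB)
            return N; // no split needed

        // Smallest factor N must be divided by to fit under 2GB.
        const long_index_t divisor = (element_space_size + TwoGB - 1) / TwoGB;
        if(divisor > N)
            return N; // even one image per split exceeds 2GB; instance is unsupported

        // Least divisor of N that is >= divisor; only candidates up to sqrt(N)
        // are scanned, so if none of them divides N we fall back to a full split.
        for(index_t d = static_cast<index_t>(divisor); d * d <= N; ++d)
            if(N % d == 0)
                return N / d;
        return 1; // one image per split (divisor N) always satisfies the bound
    }

For example, element_space_size = 5 GB with N = 64 gives divisor 3; the first divisor of 64 that is >= 3 is 4, so each launch covers 16 images (about 1.25 GB per split).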
// TODO: implement ck::tensor_layout::convolution that describes packed/strided dimensions as
// properties
template <typename ALayout,
......@@ -38,9 +103,9 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
-const std::array<index_t, NDimSpatial>& input_right_pads)
+const std::array<index_t, NDimSpatial>& input_right_pads,
+const index_t N)
{
-const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
......@@ -151,9 +216,10 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
-const std::array<index_t, NDimSpatial>& input_right_pads)
+const std::array<index_t, NDimSpatial>& input_right_pads,
+const index_t N)
{
-const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
......@@ -276,13 +342,14 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
+const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
-const std::array<index_t, NDimSpatial>& input_right_pads)
+const std::array<index_t, NDimSpatial>& input_right_pads,
+const index_t N)
{
-const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
......@@ -478,9 +545,9 @@ struct TransformConvFwdToGemm
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
+const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
+const index_t N)
{
-const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo =
......@@ -502,9 +569,9 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
+const index_t N)
{
-const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const auto KStride = I1;
......@@ -525,9 +592,9 @@ struct TransformConvFwdToGemm
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
+const index_t N)
{
-const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2];
......
......@@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>;
+using KernelTypes2dLargeCases = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>>;
template <typename Tuple>
class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple>
{
......@@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple>
{
};
+template <typename Tuple>
+class TestGroupedConvndFwd2dLargeCases : public TestGroupedConvndFwd<Tuple>
+{
+};
TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d);
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
+TYPED_TEST_SUITE(TestGroupedConvndFwd2dLargeCases, KernelTypes2dLargeCases);
TYPED_TEST(TestGroupedConvndFwd1d, Test1D)
{
......@@ -131,3 +139,11 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}
+TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
+{
+// Case larger than 2GB
+this->conv_params.push_back(
+{2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {0, 0}, {0, 0}, {0, 0}});
+this->template Run<2>();
+}
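For scale: reading the ConvParam list as {ndims, G, N, K, C, filter lengths, image lengths, strides, dilations, left pads, right pads} (an assumption about the test fixture's parameter order), the added case's float input holds 1 · 64 · 192 · 224 · 224 ≈ 6.17 · 10^8 elements, about 2.3 GiB, which is past the 2 GB threshold and therefore exercises the new N-split path.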