Unverified commit dc1e9c5d authored by Bartłomiej Kocot, committed by GitHub

Support large tensors in grouped conv fwd (#1332)

* Support large tensors in grouped conv fwd

* Multi ABD fixes

* Fix calculate element space size
parent 37a347e3
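
The commit message is terse, so a rough mental model helps before reading the hunks: the forward-convolution device ops now split the N (batch) dimension into chunks of conv_N_per_block_, launch one extra gridDim.y factor per chunk, and hand each workgroup two offset computers, compute_ptr_offset_of_groups for the G dimension and compute_ptr_offset_of_n for the N chunk, so no single descriptor or 32-bit offset has to cover the whole tensor. The sketch below only illustrates the chunk-size idea with made-up sizes; GetSplitedNSize itself is not part of this diff, so treat the selection loop as an assumption rather than the actual CK logic.

#include <cstdint>
#include <cstdio>

using index_t      = int32_t;
using long_index_t = int64_t;

int main()
{
    // Illustrative sizes only (not taken from the PR): a grouped conv whose
    // input holds more than 2^31 elements per group.
    const index_t G = 2, N = 64, C = 192, Hi = 512, Wi = 512;
    const long_index_t elems_per_image = long_index_t{Hi} * Wi * C;

    // Pick an N chunk whose slice still fits a 32-bit element-space size.
    // This only mirrors what GetSplitedNSize presumably does; the real
    // criterion lives in the transformer and is not shown in this diff.
    index_t conv_N_per_block = N;
    while(conv_N_per_block > 1 &&
          conv_N_per_block * elems_per_image > long_index_t{INT32_MAX})
    {
        conv_N_per_block /= 2; // the real code must pick a divisor of N
    }

    const index_t num_workgroups_per_Conv_N = N / conv_N_per_block;
    std::printf("G=%d N=%d -> conv_N_per_block=%d, N chunks per group=%d\n",
                G, N, conv_N_per_block, num_workgroups_per_Conv_N);
}
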
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
     independent_filter_strides,
     conv_filter_dilations,
     input_left_pads_with_offset,
-    input_right_pads);
+    input_right_pads,
+    N);
 const auto in_gemmm_gemmk_desc =
     matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -93,12 +93,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
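
A note on the pattern removed above (and repeated in several hunks below): the old code routed a long_index_t batch offset through __builtin_amdgcn_readfirstlane, a builtin that has historically operated on 32-bit values, so offsets beyond 2^31 elements could be truncated on the way through. The commit does not state this explicitly, so the connection is an inference from the title; the stand-alone sketch below only demonstrates the truncation hazard with a stand-in 32-bit helper and is not CK code.

#include <cstdint>
#include <cstdio>

using long_index_t = int64_t;

// Stand-in for a lane-broadcast helper that only handles 32-bit values,
// the way the readfirstlane builtin historically does.
static int32_t broadcast_32(int32_t v) { return v; }

int main()
{
    const long_index_t batch_stride = long_index_t{1} << 33; // > 2^31 elements
    const long_index_t g_idx        = 3;
    const long_index_t offset       = g_idx * batch_stride;

    // Old pattern: the 64-bit offset is squeezed through the 32-bit helper.
    const long_index_t truncated = broadcast_32(static_cast<int32_t>(offset));
    // New pattern: the offset is taken from the offset computer directly.
    const long_index_t direct = offset;

    std::printf("truncated = %lld, direct = %lld\n",
                static_cast<long long>(truncated),
                static_cast<long long>(direct));
}
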
@@ -54,12 +54,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
     __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
...
@@ -66,12 +66,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
     __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
...
@@ -59,12 +59,9 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -116,12 +113,9 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
@@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
         arg.Conv_G_;
     std::array<index_t, I1> in_out_batch_strides = {
-        arg.compute_ptr_offset_of_batch_.BatchStrideC_};
+        static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
     const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
                                                    ck::Tuple<CElementwiseGridDesc_M_N>,
...
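
The static_cast<index_t> added in the last hunk above follows from the stride members of ComputePtrOffsetOfStridedBatch being widened to long_index_t elsewhere in this diff: the surrounding code still builds a std::array<index_t, I1> of strides, and brace-initializing an index_t element from a non-constant long_index_t is a narrowing conversion that C++ rejects without an explicit cast. A minimal illustration of the language rule (not CK code):

#include <array>
#include <cstdint>

using index_t      = int32_t;
using long_index_t = int64_t;

// The stride arrives as a runtime 64-bit value, as it would from the
// widened ComputePtrOffsetOfStridedBatch member.
std::array<index_t, 1> make_strides(long_index_t batch_stride_c)
{
    // return {batch_stride_c};                    // error: narrowing conversion
    return {static_cast<index_t>(batch_stride_c)}; // explicit cast required
}

int main() { return make_strides(1024)[0] == 1024 ? 0 : 1; }
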
@@ -61,12 +61,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
     __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -97,12 +97,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
@@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
     conv_filter_strides,
     conv_filter_dilations,
     input_left_pads,
-    input_right_pads);
+    input_right_pads,
+    a_g_n_c_wis_lengths[I1]);
 const auto in_gemmm_gemmk_desc =
     matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
     const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
 {
     const auto out_gemmmraw_gemmnraw_desc =
-        conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                    e_g_n_k_wos_strides);
+        conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+            e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
     const auto out_gemmm_gemmn_desc =
         matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
     conv_filter_strides,
     conv_filter_dilations,
     input_left_pads,
-    input_right_pads);
+    input_right_pads,
+    a_g_n_c_wis_lengths[I1]);
 const auto in_gemmm_gemmk_desc =
     matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
     const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
 {
     const auto out_gemmmraw_gemmnraw_desc =
-        conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
-                                                                    c_g_n_k_wos_strides);
+        conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
+            c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
     const auto out_gemmm_gemmn_desc =
         matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
@@ -69,7 +69,8 @@ template <typename GridwiseGemm,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2ETileMap,
-          typename ComputePtrOffsetOfBatch,
+          typename ComputePtrOffsetOfG,
+          typename ComputePtrOffsetOfN,
           bool HasMainKBlockLoop,
           bool isMultiA,
           bool isMultiB>
@@ -85,7 +86,7 @@ __global__ void
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CDEElementwiseOperation cde_element_op,
-        const index_t batch_count,
+        const index_t groups_count,
         const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
         const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
         const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -93,18 +94,22 @@ __global__ void
         const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
             e_grid_desc_mblock_mperblock_nblock_nperblock_,
         const Block2ETileMap block_2_ctile_map,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+        const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
+        const ComputePtrOffsetOfN compute_ptr_offset_of_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
     // offset base pointer for each work-group
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
-    const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
+    const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+    const index_t& num_blocks_per_n = groups_count;
+    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+    const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -121,13 +126,28 @@ __global__ void
     AsPointer p_as_grid_grp;
     BsPointer p_bs_grid_grp;
-    const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
+    const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
+    // compute_ptr_offset_of_n_ not need BatchStrideB so
+    // in case of MultiA is false but isMultiB is true
+    // BatchStrideA_ is not tuple.
+    if constexpr(isMultiA)
+    {
+        const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx);
     static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
-    static_for<0, NumATensor, 1>{}(
-        [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
+        static_for<0, NumATensor, 1>{}([&](auto i) {
+            p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
+        });
+    }
+    else
+    {
+        const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+        static_for<0, 1, 1>{}(
+            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
+    }
-    const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
+    const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
     static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
     static_for<0, NumBTensor, 1>{}(
@@ -137,7 +157,7 @@ __global__ void
             p_as_grid_grp,
             p_bs_grid_grp,
             p_ds_grid_grp,
-            p_e_grid + e_batch_offset,
+            p_e_grid + e_batch_offset + e_n_offset,
             p_shared,
             a_element_op,
             b_element_op,
@@ -150,16 +170,16 @@ __global__ void
     }
     else
     {
-        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+        const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+        const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+        const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
         GridwiseGemm::template Run<HasMainKBlockLoop>(
-            p_as_grid + a_batch_offset,
+            p_as_grid + a_batch_offset + a_n_offset,
             p_bs_grid + b_batch_offset,
             p_ds_grid_grp,
-            p_e_grid + e_batch_offset,
+            p_e_grid + e_batch_offset + e_n_offset,
             p_shared,
             a_element_op,
             b_element_op,
@@ -175,7 +195,7 @@ __global__ void
     ignore = p_bs_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = batch_count;
+    ignore = groups_count;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -183,7 +203,8 @@ __global__ void
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
-   ignore = compute_ptr_offset_of_batch;
+   ignore = compute_ptr_offset_of_groups;
+   ignore = compute_ptr_offset_of_n;
    ignore = block_2_ctile_map;
 #endif
 }
@@ -309,7 +330,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
-       const std::array<index_t, NDimSpatial>& input_right_pads)
+       const std::array<index_t, NDimSpatial>& input_right_pads,
+       const index_t Conv_N)
    {
        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
@@ -321,7 +343,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
-               input_right_pads);
+               input_right_pads,
+               Conv_N);
        const auto in_gemmm_gemmk_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -347,11 +370,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    template <typename ELay>
    static auto
    MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-                           const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+                           const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+                           const index_t Conv_N)
    {
        const auto out_gemmmraw_gemmnraw_desc =
-           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                       e_g_n_k_wos_strides);
+           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+               e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
        const auto out_gemmm_gemmn_desc =
            matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
@@ -363,24 +387,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    // Pass e_g_n_k_wos_lengths for logical broadcast.
    static auto MakeDsGridDescriptor_M_N(
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-       const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
+       const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
+       const index_t Conv_N)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-               return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
-                                                                 ds_g_n_k_wos_strides[i]);
+               return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
+                   e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N);
            },
            Number<NumDTensor>{});
    }
    // desc for problem definition
    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
-       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
    using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
-   using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
-   using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+   using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, 1))>;
+   using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
    // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
    // it to it
@@ -468,6 +493,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
          p_ds_grid_{},
          p_e_grid_{static_cast<EDataType*>(p_e)},
          num_group_{a_g_n_c_wis_lengths[0]},
+         conv_N_per_block_{
+             conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
+                 a_g_n_c_wis_lengths,
+                 a_g_n_c_wis_strides,
+                 e_g_n_k_wos_lengths,
+                 e_g_n_k_wos_strides)},
          a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
                                                                      a_g_n_c_wis_strides,
                                                                      b_g_k_c_xs_lengths,
@@ -477,12 +508,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                                                      conv_filter_strides,
                                                                      conv_filter_dilations,
                                                                      input_left_pads,
-                                                                     input_right_pads)},
+                                                                     input_right_pads,
+                                                                     conv_N_per_block_)},
          b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
                                                                      b_g_k_c_xs_strides)},
          ds_grid_desc_m_n_{},
-         e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-                                                                     e_g_n_k_wos_strides)},
+         e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
+             e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
          a_grid_desc_ak0_m_ak1_{
              GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
          b_grid_desc_bk0_n_bk1_{
@@ -490,7 +522,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
          ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
          e_grid_desc_mblock_mperblock_nblock_nperblock_{},
          block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
-         compute_ptr_offset_of_batch_{},
+         compute_ptr_offset_of_groups_{},
+         compute_ptr_offset_of_n_{},
          a_element_op_{a_element_op},
          b_element_op_{b_element_op},
          cde_element_op_{cde_element_op},
@@ -511,8 +544,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        if constexpr(isMultiA || isMultiB)
        {
            static_for<0, NumATensor, 1>{}([&](auto i) {
-               // Init compute_ptr_offset_of_batch_ for multiple AB
-               compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+               // Init compute_ptr_offset_of_groups_ for multiple AB
+               compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
                // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
                // type is not tuple)
@@ -524,16 +557,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                {
                    // p_as is tuple
                    p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
+                   // compute_ptr_offset_of_n_ not need BatchStrideB so
+                   // in case of MultiA is false but isMultiB is true
+                   // BatchStrideA_ is not tuple.
+                   compute_ptr_offset_of_n_.BatchStrideA_(i) =
+                       a_g_n_c_wis_strides[1] * conv_N_per_block_;
                }
                else
                {
                    // if MultiB and not MultiA then p_as is single pointer
                    p_as_grid_(i) = static_cast<const DataType*>(p_as);
+                   compute_ptr_offset_of_n_.BatchStrideA_ =
+                       a_g_n_c_wis_strides[1] * conv_N_per_block_;
                }
            });
            static_for<0, NumBTensor, 1>{}([&](auto i) {
-               // Init compute_ptr_offset_of_batch_ for multiple AB
-               compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+               // Init compute_ptr_offset_of_groups_ for multiple AB
+               compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
                using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                // It is possible that one of the AB is a pointer and one is a tuple.
@@ -553,8 +593,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        }
        else
        {
-           compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-           compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+           compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+           compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+           compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
            // p_as and p_bs are pointers
            p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
@@ -570,13 +611,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
                // D batch stride
-               compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+               compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+               compute_ptr_offset_of_n_.BatchStrideDs_(i) =
+                   ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
                // D desc
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-                   e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
+                   e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_);
            });
-           compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+           compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+           compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
            // populate desc for Ds/E
            if constexpr(isMultiA || isMultiB)
@@ -638,6 +682,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        // tensor descriptors for problem definiton
        index_t num_group_;
+       index_t conv_N_per_block_;
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
@@ -655,7 +701,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        // for computing batch offset
        ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
-           compute_ptr_offset_of_batch_;
+           compute_ptr_offset_of_groups_;
+       ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor> compute_ptr_offset_of_n_;
        // element-wise op
        AElementwiseOperation a_element_op_;
@@ -689,8 +736,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            arg.Print();
        }
-       const index_t grid_size =
-           arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
+       const index_t num_workgroups_per_Conv_N =
+           arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
+       const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
+       const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
+       const index_t gdz = 1;
        const auto K =
            arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -721,6 +772,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
            Block2ETileMap,
            ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
+           ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
            has_main_loop,
            isMultiA,
            isMultiB>;
@@ -728,7 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        return launch_and_time_kernel(
            stream_config,
            kernel,
-           dim3(grid_size),
+           dim3(gdx, gdy, gdz),
            dim3(BlockSize),
            0,
            arg.p_as_grid_,
@@ -744,7 +796,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg.block_2_etile_map_,
-           arg.compute_ptr_offset_of_batch_);
+           arg.compute_ptr_offset_of_groups_,
+           arg.compute_ptr_offset_of_n_);
        }
        else
        {
@@ -763,6 +816,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
            Block2ETileMap,
            ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
+           ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
            has_main_loop,
            isMultiA,
            isMultiB>;
@@ -770,7 +824,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        return launch_and_time_kernel(
            stream_config,
            kernel,
-           dim3(grid_size),
+           dim3(gdx, gdy, gdz),
            dim3(BlockSize),
            0,
            arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
@@ -786,7 +840,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
            arg.block_2_etile_map_,
-           arg.compute_ptr_offset_of_batch_);
+           arg.compute_ptr_offset_of_groups_,
+           arg.compute_ptr_offset_of_n_);
        }
    };
...
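
To make the launch math in the Run() hunks above concrete: the grid is now two-dimensional, with the block-to-E-tile map on gridDim.x and groups times N-chunks on gridDim.y. A worked example with made-up numbers (none of these values come from the PR):

#include <cstdint>
#include <cstdio>

using index_t = int32_t;

int main()
{
    // Hypothetical problem: 4 groups, N = 48, and the transformer decided
    // that 16 images per chunk keep the descriptors within 32-bit range.
    const index_t num_group        = 4;
    const index_t N                = 48;
    const index_t conv_N_per_block = 16;
    const index_t tiles_per_output = 120; // stand-in for CalculateGridSize(e_grid_desc_m_n_)

    const index_t num_workgroups_per_Conv_N = N / conv_N_per_block; // 3
    const index_t gdx = tiles_per_output;                           // 120
    const index_t gdy = num_group * num_workgroups_per_Conv_N;      // 12
    const index_t gdz = 1;

    // Old launch: dim3(tiles * num_group) in a single dimension.
    // New launch: dim3(gdx, gdy, gdz); each blockIdx.y maps to one
    // (group, N-chunk) pair, resolved in-kernel via the two offset computers.
    std::printf("grid = (%d, %d, %d), total blocks = %d\n", gdx, gdy, gdz, gdx * gdy * gdz);
}
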
@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-         typename ComputePtrOffsetOfBatch,
+         typename ComputePtrOffset,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
@@ -69,26 +69,28 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-       kernel_grouped_conv_fwd_xdl_cshuffle_v3(
-           typename GridwiseGemm::Argument karg,
+       kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
            const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
-           const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-           const index_t batch_count)
+           const ComputePtrOffset compute_ptr_offset_of_groups,
+           const ComputePtrOffset compute_ptr_offset_of_n,
+           const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // offset base pointer for each work-group
-   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+   const index_t& num_blocks_per_n = groups_count;
    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+   const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-   const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-   const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-   const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+   const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+   const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+   const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+   const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+   const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -97,9 +99,9 @@ __global__ void
                CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                HasMainKBlockLoop,
                CGlobalMemoryDataOperation,
-               TailNum>(karg.p_a_grid + a_batch_offset,
+               TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
                         karg.p_b_grid + b_batch_offset,
-                        karg.p_c_grid + e_batch_offset,
+                        karg.p_c_grid + e_batch_offset + e_n_offset,
                         p_shared,
                         karg,
                         a_grid_desc_ak0_m_ak1,
@@ -114,7 +116,7 @@ template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-         typename ComputePtrOffsetOfBatch,
+         typename ComputePtrOffset,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
@@ -129,20 +131,23 @@ __global__ void
            const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
-           const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-           const index_t batch_count)
+           const ComputePtrOffset compute_ptr_offset_of_groups,
+           const ComputePtrOffset compute_ptr_offset_of_n,
+           const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // offset base pointer for each work-group
-   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+   const index_t& num_blocks_per_n = groups_count;
    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+   const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-   const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-   const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-   const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+   const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+   const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+   const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+   const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+   const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
    // Pass two lds pointer is the key to tell compiler that ds_read/write
    // operate on different lds chunk at same time without order dependecy
@@ -154,9 +159,9 @@ __global__ void
                CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                HasMainKBlockLoop,
                CGlobalMemoryDataOperation,
-               TailNum>(karg.p_a_grid + a_batch_offset,
+               TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
                         karg.p_b_grid + b_batch_offset,
-                        karg.p_c_grid + e_batch_offset,
+                        karg.p_c_grid + e_batch_offset + e_n_offset,
                         p_shared_0,
                         p_shared_1,
                         karg,
@@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
-       const std::array<index_t, NDimSpatial>& input_right_pads)
+       const std::array<index_t, NDimSpatial>& input_right_pads,
+       const index_t Conv_N)
    {
        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
@@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
-               input_right_pads);
+               input_right_pads,
+               Conv_N);
        const auto in_gemmm_gemmk_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
    template <typename ELay>
    static auto
    MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-                           const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+                           const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+                           const index_t Conv_N)
    {
        const auto out_gemmmraw_gemmnraw_desc =
-           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                       e_g_n_k_wos_strides);
+           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+               e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
        const auto out_gemmm_gemmn_desc =
            matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
@@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
    }
    // desc for problem definition
-   using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+   using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
 #define GridwiseGemmV3TemplateParams \
    tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
@@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
    using BGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
@@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
          p_b_grid_{},
          p_e_grid_{static_cast<EDataType*>(p_e)},
          num_group_{a_g_n_c_wis_lengths[0]},
+         conv_N_per_block_{
+             conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
+                 a_g_n_c_wis_lengths,
+                 a_g_n_c_wis_strides,
+                 e_g_n_k_wos_lengths,
+                 e_g_n_k_wos_strides)},
          a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
                                                                        a_g_n_c_wis_strides,
                                                                        b_g_k_c_xs_lengths,
@@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                                                                        conv_filter_strides,
                                                                        conv_filter_dilations,
                                                                        input_left_pads,
-                                                                       input_right_pads)},
+                                                                       input_right_pads,
+                                                                       conv_N_per_block_)},
          b_grid_desc_bk0_n_bk1_{
              MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
-         e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-                                                                     e_g_n_k_wos_strides)},
+         e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
+             e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
          e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-         compute_ptr_offset_of_batch_{},
+         compute_ptr_offset_of_groups_{},
+         compute_ptr_offset_of_n_{},
          a_element_op_{a_element_op},
          b_element_op_{b_element_op},
          cde_element_op_{cde_element_op},
@@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
          input_left_pads_{input_left_pads},
          input_right_pads_{input_right_pads}
    {
-       // A/B/E Batch Stride
-       compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-       compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+       // A/B/E Batch/N Stride
+       compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+       compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+       compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
        // p_as and p_bs are pointers
        p_a_grid_ = static_cast<const ADataType*>(p_as);
        p_b_grid_ = static_cast<const BDataType*>(p_bs);
-       compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+       compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+       compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
        e_grid_desc_mblock_mperblock_nblock_nperblock_ =
            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
@@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
        // tensor descriptors for problem definiton
        index_t num_group_;
+       index_t conv_N_per_block_;
        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
        // for computing batch offset
-       ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
+       ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
+       ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
        // element-wise op
        AElementwiseOperation a_element_op_;
@@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
            const index_t GemmK =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+           const index_t num_workgroups_per_Conv_N =
+               arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) =
                GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
-           gdy *= arg.num_group_;
+           gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
            index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
@@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                arg.a_grid_desc_ak0_m_ak1_,
                arg.b_grid_desc_bk0_n_bk1_,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-               arg.compute_ptr_offset_of_batch_,
+               arg.compute_ptr_offset_of_groups_,
+               arg.compute_ptr_offset_of_n_,
                arg.num_group_);
        }
        else
@@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                arg.a_grid_desc_ak0_m_ak1_,
                arg.b_grid_desc_bk0_n_bk1_,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-               arg.compute_ptr_offset_of_batch_,
+               arg.compute_ptr_offset_of_groups_,
+               arg.compute_ptr_offset_of_n_,
                arg.num_group_);
        }
    };
...
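
The V3 path wires the same split through compute_ptr_offset_of_n_: its stride is the per-image N stride multiplied by conv_N_per_block_, and in the kernels the chunk offset is simply added on top of the group offset before the pointers reach the GEMM. A simplified host-side sketch of that arithmetic, using a hypothetical PtrOffset helper and made-up strides rather than the real offset structs:

#include <cstdint>
#include <cstdio>

using index_t      = int32_t;
using long_index_t = int64_t;

// Strides kept in 64-bit so idx * stride cannot overflow for large tensors.
struct PtrOffset
{
    long_index_t stride; // elements between consecutive groups (or N chunks)
    long_index_t at(index_t idx) const { return static_cast<long_index_t>(idx) * stride; }
};

int main()
{
    // Hypothetical element strides: one group of A, and one input image (N step).
    const long_index_t stride_G = 3'221'225'472; // deliberately > 2^31
    const long_index_t stride_N = 50'331'648;
    const index_t conv_N_per_block = 16;

    const PtrOffset offset_of_groups{stride_G};
    const PtrOffset offset_of_n{stride_N * conv_N_per_block};

    // Mirrors "p_a_grid + a_batch_offset + a_n_offset" in the kernels above,
    // expressed as a single 64-bit element offset.
    const index_t g_idx = 1, n_idx = 2;
    const long_index_t a_offset = offset_of_groups.at(g_idx) + offset_of_n.at(n_idx);

    std::printf("A element offset = %lld\n", static_cast<long long>(a_offset));
}
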
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_strides); e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
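Note that MakeADescriptor_M_K and MakeCDescriptor_M_N now take the batch size N as an explicit argument instead of reading a_g_n_c_wis_lengths[1] / c_g_n_k_wos_lengths[1] internally. Device ops that do not split the batch, as in the hunks above, simply forward the corresponding lengths[I1] value, presumably so that the large-tensor path can pass a per-split N instead.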
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_strides); e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
...@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor> ...@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor, struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor, NumBTensor,
NumDTensor, NumDTensor,
ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>> enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
{ {
ComputePtrOffsetOfStridedBatch() = default; ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs, ComputePtrOffsetOfStridedBatch(Array<long_index_t, NumATensor>& BatchStrideAs,
Array<ck::index_t, NumBTensor>& BatchStrideBs, Array<long_index_t, NumBTensor>& BatchStrideBs,
Array<ck::index_t, NumDTensor>& BatchStrideDs, Array<long_index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE) long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideAs), : BatchStrideA_(BatchStrideAs),
BatchStrideB_(BatchStrideBs), BatchStrideB_(BatchStrideBs),
BatchStrideDs_(BatchStrideDs), BatchStrideDs_(BatchStrideDs),
...@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{ {
Array<long_index_t, NumATensor> as_offset; Array<long_index_t, NumATensor> as_offset;
static_for<0, NumATensor, 1>{}( static_for<0, NumATensor, 1>{}(
[&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); }); [&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
return as_offset; return as_offset;
} }
...@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{ {
Array<long_index_t, NumBTensor> bs_offset; Array<long_index_t, NumBTensor> bs_offset;
static_for<0, NumBTensor, 1>{}( static_for<0, NumBTensor, 1>{}(
[&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); }); [&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
return bs_offset; return bs_offset;
} }
...@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{ {
Array<long_index_t, NumDTensor> ds_offset; Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); }); [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset; return ds_offset;
} }
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
// alias for kernels without multiple D // alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
Array<ck::index_t, NumATensor> BatchStrideA_; Array<long_index_t, NumATensor> BatchStrideA_;
Array<ck::index_t, NumBTensor> BatchStrideB_; Array<long_index_t, NumBTensor> BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_; Array<long_index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_; long_index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
}; };
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor> template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor, struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor, NumBTensor,
NumDTensor, NumDTensor,
ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>> enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
{ {
ComputePtrOffsetOfStridedBatch() = default; ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA,
index_t BatchStrideB, long_index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs, Array<long_index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE) long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideA), : BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB), BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs), BatchStrideDs_(BatchStrideDs),
...@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor, ...@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); return static_cast<long_index_t>(g_idx) * BatchStrideA_;
} }
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); return static_cast<long_index_t>(g_idx) * BatchStrideB_;
} }
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{ {
Array<long_index_t, NumDTensor> ds_offset; Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); }); [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset; return ds_offset;
} }
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
// alias for kernels without multiple D // alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return static_cast<long_index_t>(g_idx) * BatchStrideE_;
} }
ck::index_t BatchStrideA_; long_index_t BatchStrideA_;
ck::index_t BatchStrideB_; long_index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_; Array<long_index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_; long_index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
}; };
template <bool isTuple, typename Tensors> template <bool isTuple, typename Tensors>
......
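A small self-contained illustration (plain C++ types, not the CK ones) of why both the stored batch strides and the multiply are widened to 64 bits before computing a per-group pointer offset:

#include <cstdint>
#include <cstdio>

int main()
{
    const int32_t g_idx      = 3;          // group index
    const int64_t stride_i64 = 800000000;  // elements per group, now stored as a 64-bit stride

    // Promoting g_idx before the multiply keeps the full 64-bit product.
    const int64_t offset = static_cast<int64_t>(g_idx) * stride_i64; // 2'400'000'000

    // With 32-bit storage and a 32-bit multiply the product would exceed
    // INT32_MAX (2'147'483'647) and overflow before any widening assignment.
    std::printf("group offset = %lld elements\n", static_cast<long long>(offset));
    return 0;
}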
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl ...@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
N);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -60,12 +60,9 @@ __global__ void ...@@ -60,12 +60,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...@@ -155,12 +152,9 @@ __global__ void ...@@ -155,12 +152,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
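The __builtin_amdgcn_readfirstlane wrappers around the offset getters are dropped here, presumably because the builtin operates on 32-bit lane values and would truncate a long_index_t offset, which is exactly the range this change needs to preserve; g_idx itself is still made wave-uniform above, so the resulting offsets stay uniform across the wave.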
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -20,6 +20,71 @@ struct TransformConvFwdToGemm ...@@ -20,6 +20,71 @@ struct TransformConvFwdToGemm
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static long_index_t
calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
const std::array<index_t, NDimSpatial + 3>& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
{
acc +=
static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
}
return acc;
}
template <typename ADataType, typename CDataType>
static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const index_t N = a_g_n_c_wis_lengths[I1];
if(element_space_size > TwoGB)
{
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
if(divisor <= static_cast<double>(N))
{
// Find least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N). There are no divisors above this value.
for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
least_divisor++)
{
if(N % least_divisor == 0)
{
return N / least_divisor;
}
}
// Not found, process one Convolution N per block
return 1;
}
else
{
// Not possible to support even after split N.
// Too large tensor.
return N;
}
}
else
{
// Split N is not needed.
return N;
}
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimension as // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimension as
// properties // properties
template <typename ALayout, template <typename ALayout,
...@@ -38,9 +103,9 @@ struct TransformConvFwdToGemm ...@@ -38,9 +103,9 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
{ {
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3]; const index_t Wi = a_g_n_c_wis_lengths[3];
...@@ -151,9 +216,10 @@ struct TransformConvFwdToGemm ...@@ -151,9 +216,10 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
{ {
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3]; const index_t Hi = a_g_n_c_wis_lengths[3];
...@@ -276,13 +342,14 @@ struct TransformConvFwdToGemm ...@@ -276,13 +342,14 @@ struct TransformConvFwdToGemm
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */, const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */, const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads) const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
{ {
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2]; const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3]; const index_t Di = a_g_n_c_wis_lengths[3];
...@@ -478,9 +545,9 @@ struct TransformConvFwdToGemm ...@@ -478,9 +545,9 @@ struct TransformConvFwdToGemm
bool>::type = false> bool>::type = false>
static auto static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */) const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const index_t N)
{ {
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo = const index_t NHoWo =
...@@ -502,9 +569,9 @@ struct TransformConvFwdToGemm ...@@ -502,9 +569,9 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>, is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false> bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{ {
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const auto KStride = I1; const auto KStride = I1;
...@@ -525,9 +592,9 @@ struct TransformConvFwdToGemm ...@@ -525,9 +592,9 @@ struct TransformConvFwdToGemm
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>, typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false> bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{ {
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2]; const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2]; const index_t KStride = c_g_n_k_wos_strides[2];
......
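The split-N selection in TransformConvFwdToGemm can be exercised in isolation. The following stand-alone sketch uses plain std types and assumed helper names, but follows the same search for the smallest divisor of N that keeps one launch under 2 GiB:

#include <array>
#include <cstddef>
#include <cstdint>

// Offset of the last element plus one, i.e. 1 + sum_i (len_i - 1) * stride_i,
// skipping the leading G dimension as the diff does.
int64_t element_space_size(const std::array<int64_t, 5>& lengths,
                           const std::array<int64_t, 5>& strides)
{
    int64_t acc = 1;
    for(std::size_t i = 1; i < lengths.size(); ++i)
        acc += (lengths[i] - 1) * strides[i];
    return acc;
}

// Returns how many images (N) one kernel launch should process so that the
// largest of the input/output buffers stays below 2 GiB.
// Typical use: split_n(N, element_space_size(lengths, strides) * sizeof(float)).
int64_t split_n(int64_t n, int64_t space_size_bytes)
{
    constexpr int64_t two_gib = int64_t{1} << 31;
    if(space_size_bytes <= two_gib)
        return n;                                   // no split needed
    const int64_t min_divisor = (space_size_bytes + two_gib - 1) / two_gib;
    if(min_divisor > n)
        return n;                                   // cannot be fixed by splitting N alone
    for(int64_t d = min_divisor; d * d <= n; ++d)   // least divisor of N that is >= min_divisor
        if(n % d == 0)
            return n / d;
    return 1;                                       // fall back to one image per launch
}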
...@@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK> ...@@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>, std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>; std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>;
using KernelTypes2dLargeCases = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>>;
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple> class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple>
{ {
...@@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple> ...@@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple>
{ {
}; };
template <typename Tuple>
class TestGroupedConvndFwd2dLargeCases : public TestGroupedConvndFwd<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d);
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST_SUITE(TestGroupedConvndFwd2dLargeCases, KernelTypes2dLargeCases);
TYPED_TEST(TestGroupedConvndFwd1d, Test1D) TYPED_TEST(TestGroupedConvndFwd1d, Test1D)
{ {
...@@ -131,3 +139,11 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D) ...@@ -131,3 +139,11 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>(); this->template Run<3>();
} }
TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
{
// Case larger than 2GB
this->conv_params.push_back(
{2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {0, 0}, {0, 0}, {0, 0}});
this->template Run<2>();
}
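For reference (rough arithmetic, assuming a packed float NHWGC input): this case holds G·N·C·Hi·Wi = 1·64·192·224·224 ≈ 6.17e8 elements, roughly 2.3 GiB, just past the 2 GiB limit, so it exercises the new split-N path.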