Unverified Commit ce30621d authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents e8450b71 17ed368f
@@ -155,7 +155,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
 // LDS direct loads using inline assembly
-#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
+#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
 // set stochastic rounding as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 1
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
             independent_filter_strides,
             conv_filter_dilations,
             input_left_pads_with_offset,
-            input_right_pads);
+            input_right_pads,
+            N);
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -93,12 +93,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
@@ -54,12 +54,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
     __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
...
@@ -66,12 +66,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
     __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
...
@@ -59,12 +59,9 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -116,12 +113,9 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
    const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
@@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 arg.Conv_G_;
             std::array<index_t, I1> in_out_batch_strides = {
-                arg.compute_ptr_offset_of_batch_.BatchStrideC_};
+                static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
             const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
                                                            ck::Tuple<CElementwiseGridDesc_M_N>,
...
@@ -61,12 +61,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
     __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -97,12 +97,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
@@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
                 conv_filter_strides,
                 conv_filter_dilations,
                 input_left_pads,
-                input_right_pads);
+                input_right_pads,
+                a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
                             const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                        e_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+                e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc =
             matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
                 conv_filter_strides,
                 conv_filter_dilations,
                 input_left_pads,
-                input_right_pads);
+                input_right_pads,
+                a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
                             const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
-                                                                        c_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
+                c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc =
             matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
           typename AGridDesc_AK0_M_K1,
           typename BGridDesc_BK0_N_K1,
           typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename ComputePtrOffsetOfBatch,
+          typename ComputePtrOffset,
           bool HasMainKBlockLoop,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           index_t MinimumOccupancy = 1,
@@ -69,26 +69,28 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_grouped_conv_fwd_xdl_cshuffle_v3(
-        typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const index_t batch_count)
+    kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
+                                            const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+                                            const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+                                            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                            const ComputePtrOffset compute_ptr_offset_of_groups,
+                                            const ComputePtrOffset compute_ptr_offset_of_n,
+                                            const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
-    const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+    const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+    const index_t& num_blocks_per_n = groups_count;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+    const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+    const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -97,9 +99,9 @@ __global__ void
                                CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                HasMainKBlockLoop,
                                CGlobalMemoryDataOperation,
-                               TailNum>(karg.p_a_grid + a_batch_offset,
-                                        karg.p_b_grid + b_batch_offset,
-                                        karg.p_c_grid + e_batch_offset,
+                               TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
+                                        karg.p_b_grid + b_batch_offset,
+                                        karg.p_c_grid + e_batch_offset + e_n_offset,
                                         p_shared,
                                         karg,
                                         a_grid_desc_ak0_m_ak1,
@@ -114,7 +116,7 @@ template <typename GridwiseGemm,
           typename AGridDesc_AK0_M_K1,
           typename BGridDesc_BK0_N_K1,
           typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename ComputePtrOffsetOfBatch,
+          typename ComputePtrOffset,
           bool HasMainKBlockLoop,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           index_t MinimumOccupancy = 1,
@@ -129,20 +131,23 @@ __global__ void
         const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
         const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const index_t batch_count)
+        const ComputePtrOffset compute_ptr_offset_of_groups,
+        const ComputePtrOffset compute_ptr_offset_of_n,
+        const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
-    const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+    const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+    const index_t& num_blocks_per_n = groups_count;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+    const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+    const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);
     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
@@ -154,9 +159,9 @@ __global__ void
                                CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                HasMainKBlockLoop,
                                CGlobalMemoryDataOperation,
-                               TailNum>(karg.p_a_grid + a_batch_offset,
-                                        karg.p_b_grid + b_batch_offset,
-                                        karg.p_c_grid + e_batch_offset,
+                               TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
+                                        karg.p_b_grid + b_batch_offset,
+                                        karg.p_c_grid + e_batch_offset + e_n_offset,
                                         p_shared_0,
                                         p_shared_1,
                                         karg,
@@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         const std::array<index_t, NDimSpatial>& conv_filter_strides,
         const std::array<index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<index_t, NDimSpatial>& input_left_pads,
-        const std::array<index_t, NDimSpatial>& input_right_pads)
+        const std::array<index_t, NDimSpatial>& input_right_pads,
+        const index_t Conv_N)
     {
         const auto in_gemmmraw_gemmkraw_desc =
             conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
@@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                                                                         conv_filter_strides,
                                                                         conv_filter_dilations,
                                                                         input_left_pads,
-                                                                        input_right_pads);
+                                                                        input_right_pads,
+                                                                        Conv_N);
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     template <typename ELay>
     static auto
     MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+                            const index_t Conv_N)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                        e_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+                e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
         const auto out_gemmm_gemmn_desc =
             matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
@@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     }
     // desc for problem definition
-    using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+    using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
 #define GridwiseGemmV3TemplateParams                                                       \
     tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor,                       \
@@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     // desc for blockwise copy
     using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
     using BGridDesc_BK0_N_BK1 =
         remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
@@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
               p_b_grid_{},
               p_e_grid_{static_cast<EDataType*>(p_e)},
               num_group_{a_g_n_c_wis_lengths[0]},
+              conv_N_per_block_{
+                  conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
+                      a_g_n_c_wis_lengths,
+                      a_g_n_c_wis_strides,
+                      e_g_n_k_wos_lengths,
+                      e_g_n_k_wos_strides)},
               a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
                                                                             a_g_n_c_wis_strides,
                                                                             b_g_k_c_xs_lengths,
@@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                                                                             conv_filter_strides,
                                                                             conv_filter_dilations,
                                                                             input_left_pads,
-                                                                            input_right_pads)},
+                                                                            input_right_pads,
+                                                                            conv_N_per_block_)},
               b_grid_desc_bk0_n_bk1_{
                   MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
-              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-                                                                          e_g_n_k_wos_strides)},
+              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
+                  e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
               e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              compute_ptr_offset_of_batch_{},
+              compute_ptr_offset_of_groups_{},
+              compute_ptr_offset_of_n_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op},
@@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
               input_left_pads_{input_left_pads},
               input_right_pads_{input_right_pads}
         {
-            // A/B/E Batch Stride
-            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+            // A/B/E Batch/N Stride
+            compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+            compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+            compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
             // p_as and p_bs are pointers
             p_a_grid_ = static_cast<const ADataType*>(p_as);
             p_b_grid_ = static_cast<const BDataType*>(p_bs);
-            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+            compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+            compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
             e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
@@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         // tensor descriptors for problem definiton
         index_t num_group_;
+        index_t conv_N_per_block_;
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
         // element-wise op
         AElementwiseOperation a_element_op_;
@@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
             const index_t GemmK =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+            const index_t num_workgroups_per_Conv_N =
+                arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
             index_t gdx, gdy, gdz;
             std::tie(gdx, gdy, gdz) =
                 GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
-            gdy *= arg.num_group_;
+            gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
             index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
             const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
@@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.compute_ptr_offset_of_batch_,
+                    arg.compute_ptr_offset_of_groups_,
+                    arg.compute_ptr_offset_of_n_,
                     arg.num_group_);
             }
             else
@@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.compute_ptr_offset_of_batch_,
+                    arg.compute_ptr_offset_of_groups_,
+                    arg.compute_ptr_offset_of_n_,
                     arg.num_group_);
             }
         };
...
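The net effect of the changes in this device op is that the launch grid's y dimension now enumerates GEMM tiles, groups, and N-slices together, and each workgroup adds both a group offset and an N-slice offset to its A and E base pointers. Below is a minimal host-side sketch of that launch and offset arithmetic; the values and names (tiles_per_gemm, n_per_slice, the strides) are illustrative placeholders, not the kernel's actual variables.

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Illustrative problem: G groups, N images, conv_N_per_block images per slice.
        const int64_t G = 2, N = 8, n_per_slice = 2;
        const int64_t tiles_per_gemm = 4;              // stands in for GridwiseGemm::CalculateGridSize
        const int64_t n_slices       = N / n_per_slice;

        // Mirrors "gdy *= arg.num_group_ * num_workgroups_per_Conv_N;"
        const int64_t grid_y = tiles_per_gemm * G * n_slices;

        // Per-group and per-slice strides are accumulated in 64-bit so the
        // resulting offsets stay valid for tensors larger than 2 GB.
        const int64_t group_stride_a = int64_t{3} << 30; // per-group stride (BatchStrideA_ of the groups offsetter)
        const int64_t slice_stride_a = int64_t{1} << 29; // stride(N) * conv_N_per_block

        const int64_t g_idx = 1, n_idx = 3;              // example workgroup coordinates
        const int64_t a_offset = g_idx * group_stride_a + n_idx * slice_stride_a;
        std::printf("grid.y = %lld, A offset for (g=1, n-slice=3) = %lld elements\n",
                    (long long)grid_y, (long long)a_offset);
        return 0;
    }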
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
                                                                         conv_filter_strides,
                                                                         conv_filter_dilations,
                                                                         input_left_pads,
-                                                                        input_right_pads);
+                                                                        input_right_pads,
+                                                                        a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
                             const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                        e_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+                e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc =
             matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
                                                                         conv_filter_strides,
                                                                         conv_filter_dilations,
                                                                         input_left_pads,
-                                                                        input_right_pads);
+                                                                        input_right_pads,
+                                                                        a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
                             const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                        e_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+                e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc =
             matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...
@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
 struct ComputePtrOffsetOfStridedBatch<NumATensor,
                                       NumBTensor,
                                       NumDTensor,
-                                      ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
+                                      enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
 {
     ComputePtrOffsetOfStridedBatch() = default;
-    ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
-                                   Array<ck::index_t, NumBTensor>& BatchStrideBs,
-                                   Array<ck::index_t, NumDTensor>& BatchStrideDs,
-                                   index_t BatchStrideE)
+    ComputePtrOffsetOfStridedBatch(Array<long_index_t, NumATensor>& BatchStrideAs,
+                                   Array<long_index_t, NumBTensor>& BatchStrideBs,
+                                   Array<long_index_t, NumDTensor>& BatchStrideDs,
+                                   long_index_t BatchStrideE)
         : BatchStrideA_(BatchStrideAs),
           BatchStrideB_(BatchStrideBs),
          BatchStrideDs_(BatchStrideDs),
@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     {
         Array<long_index_t, NumATensor> as_offset;
         static_for<0, NumATensor, 1>{}(
-            [&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
+            [&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
         return as_offset;
     }
@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     {
         Array<long_index_t, NumBTensor> bs_offset;
         static_for<0, NumBTensor, 1>{}(
-            [&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
+            [&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
         return bs_offset;
     }
@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     {
         Array<long_index_t, NumDTensor> ds_offset;
         static_for<0, NumDTensor, 1>{}(
-            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+            [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
         return ds_offset;
     }
     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }
     // alias for kernels without multiple D
     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }
-    Array<ck::index_t, NumATensor> BatchStrideA_;
-    Array<ck::index_t, NumBTensor> BatchStrideB_;
-    Array<ck::index_t, NumDTensor> BatchStrideDs_;
-    index_t BatchStrideE_;
-    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+    Array<long_index_t, NumATensor> BatchStrideA_;
+    Array<long_index_t, NumBTensor> BatchStrideB_;
+    Array<long_index_t, NumDTensor> BatchStrideDs_;
+    long_index_t BatchStrideE_;
+    long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };
 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
 struct ComputePtrOffsetOfStridedBatch<NumATensor,
                                       NumBTensor,
                                       NumDTensor,
-                                      ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
+                                      enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
 {
     ComputePtrOffsetOfStridedBatch() = default;
-    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
-                                   index_t BatchStrideB,
-                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
-                                   index_t BatchStrideE)
+    ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA,
+                                   long_index_t BatchStrideB,
+                                   Array<long_index_t, NumDTensor> BatchStrideDs,
+                                   long_index_t BatchStrideE)
         : BatchStrideA_(BatchStrideA),
          BatchStrideB_(BatchStrideB),
          BatchStrideDs_(BatchStrideDs),
@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideA_;
     }
     __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideB_;
     }
     __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
     {
         Array<long_index_t, NumDTensor> ds_offset;
         static_for<0, NumDTensor, 1>{}(
-            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+            [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
         return ds_offset;
     }
     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }
     // alias for kernels without multiple D
     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }
-    ck::index_t BatchStrideA_;
-    ck::index_t BatchStrideB_;
-    Array<ck::index_t, NumDTensor> BatchStrideDs_;
-    index_t BatchStrideE_;
-    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+    long_index_t BatchStrideA_;
+    long_index_t BatchStrideB_;
+    Array<long_index_t, NumDTensor> BatchStrideDs_;
+    long_index_t BatchStrideE_;
+    long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };
 template <bool isTuple, typename Tensors>
...
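This struct is also why the kernels above no longer wrap their offsets in __builtin_amdgcn_readfirstlane: the batch strides are now stored as long_index_t and the offset product is formed in 64-bit arithmetic, while the readfirstlane intrinsic traffics in 32-bit values. A standalone sketch of the overflow being avoided (the stride value is made up for illustration):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        using index_t      = int32_t;
        using long_index_t = int64_t;

        const index_t      g_idx  = 3;
        const long_index_t stride = 3'000'000'000; // e.g. a per-group stride of a >2 GB tensor

        // A 32-bit stride field cannot even hold this value (the result below is
        // implementation-defined / wrapped), which is why the members changed type.
        const index_t truncated = static_cast<index_t>(stride);

        // The updated accessors widen g_idx first and multiply in 64-bit.
        const long_index_t offset = static_cast<long_index_t>(g_idx) * stride;

        std::printf("stride=%lld truncated=%d offset=%lld\n",
                    (long long)stride, truncated, (long long)offset);
        return 0;
    }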
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl
             conv_filter_strides,
             conv_filter_dilations,
             input_left_pads,
-            input_right_pads);
+            input_right_pads,
+            N);
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -60,12 +60,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
@@ -155,12 +152,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -20,6 +20,71 @@ struct TransformConvFwdToGemm
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
+    static long_index_t
+    calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
+                                      const std::array<index_t, NDimSpatial + 3>& strides,
+                                      index_t i)
+    {
+        long_index_t acc = 1;
+        for(; i < (NDimSpatial + 3); i++)
+        {
+            acc +=
+                static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
+        }
+        return acc;
+    }
+    template <typename ADataType, typename CDataType>
+    static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
+                                   const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
+                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+    {
+        const long_index_t a_element_space_size =
+            calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
+        const long_index_t c_element_space_size =
+            calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
+        const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
+                                                          c_element_space_size * sizeof(CDataType));
+        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+        const index_t N = a_g_n_c_wis_lengths[I1];
+        if(element_space_size > TwoGB)
+        {
+            // Minimum divisor of N to not exceed 2GB
+            const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
+            if(divisor <= static_cast<double>(N))
+            {
+                // Find least divisor of N larger than element_space_size / TwoGB
+                // Iterate up to sqrt(N). There are no divisors above this value.
+                for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
+                    least_divisor++)
+                {
+                    if(N % least_divisor == 0)
+                    {
+                        return N / least_divisor;
+                    }
+                }
+                // Not found, process one Convolution N per block
+                return 1;
+            }
+            else
+            {
+                // Not possible to support even after split N.
+                // Too large tensor.
+                return N;
+            }
+        }
+        else
+        {
+            // Split N is not needed.
+            return N;
+        }
+    }
     // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
     // properties
     template <typename ALayout,
@@ -38,9 +103,9 @@ struct TransformConvFwdToGemm
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                       const std::array<index_t, NDimSpatial>& input_right_pads)
+                       const std::array<index_t, NDimSpatial>& input_right_pads,
+                       const index_t N)
     {
-        const index_t N = a_g_n_c_wis_lengths[1];
         const index_t C = a_g_n_c_wis_lengths[2];
         const index_t Wi = a_g_n_c_wis_lengths[3];
@@ -151,9 +216,10 @@ struct TransformConvFwdToGemm
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                       const std::array<index_t, NDimSpatial>& input_right_pads)
+                       const std::array<index_t, NDimSpatial>& input_right_pads,
+                       const index_t N)
     {
-        const index_t N = a_g_n_c_wis_lengths[1];
         const index_t C = a_g_n_c_wis_lengths[2];
         const index_t Hi = a_g_n_c_wis_lengths[3];
@@ -276,13 +342,14 @@ struct TransformConvFwdToGemm
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                       const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
+                       const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                       const std::array<index_t, NDimSpatial>& input_right_pads)
+                       const std::array<index_t, NDimSpatial>& input_right_pads,
+                       const index_t N)
     {
-        const index_t N = a_g_n_c_wis_lengths[1];
         const index_t C = a_g_n_c_wis_lengths[2];
         const index_t Di = a_g_n_c_wis_lengths[3];
@@ -478,9 +545,9 @@ struct TransformConvFwdToGemm
                                      bool>::type = false>
     static auto
     MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
+                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
+                        const index_t N)
     {
-        const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
         const index_t NHoWo =
@@ -502,9 +569,9 @@ struct TransformConvFwdToGemm
                                          is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                                      bool>::type = false>
     static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+                                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
+                                    const index_t N)
     {
-        const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
         const auto KStride = I1;
@@ -525,9 +592,9 @@ struct TransformConvFwdToGemm
               typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                       bool>::type = false>
     static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+                                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
+                                    const index_t N)
     {
-        const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
         const index_t KStride = c_g_n_k_wos_strides[2];
...
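GetSplitedNSize above decides how many convolution images a single launch may cover: if the A or C tensor would exceed 2 GB of addressable element space, N is split by its smallest divisor that brings one slice back under the limit, and the descriptors are then built with that per-slice N. The following is a host-side sketch of that policy under simplified assumptions (element space is passed in bytes, and the helper name split_n is illustrative, not the library API):

    #include <cstdint>
    #include <cstdio>

    // Sketch: return how many images (out of n) one slice should process so that
    // a slice stays under 2 GB, mirroring the divisor search shown in the diff.
    static int64_t split_n(int64_t n, int64_t element_space_bytes)
    {
        constexpr int64_t two_gb = int64_t{1} << 31;
        if(element_space_bytes <= two_gb)
            return n;                                   // no split needed
        const int64_t min_divisor = (element_space_bytes + two_gb - 1) / two_gb;
        if(min_divisor > n)
            return n;                                   // cannot help even with one image per slice
        for(int64_t d = min_divisor; d * d <= n; ++d)   // least divisor of n >= min_divisor
            if(n % d == 0)
                return n / d;
        return 1;                                       // fall back to one image per slice
    }

    int main()
    {
        // A 6 GB tensor with N = 64 ends up with 16 images per slice (divisor 4).
        std::printf("N per slice: %lld\n", (long long)split_n(64, int64_t{6} << 30));
        return 0;
    }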
@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
 CK_TILE_DEVICE void block_sync_lds()
 {
 #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
-    asm volatile("\
-    s_waitcnt lgkmcnt(0) \n \
-    s_barrier \
-    " ::);
+    // asm volatile("\
+    // s_waitcnt lgkmcnt(0) \n \
+    // s_barrier \
+    // " ::);
+    __builtin_amdgcn_s_waitcnt(0xc07f);
+    __builtin_amdgcn_s_barrier();
 #else
     __syncthreads();
 #endif
...
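For reference, the magic immediate passed to __builtin_amdgcn_s_waitcnt above appears to encode the same wait condition as the removed inline assembly: lgkmcnt(0), with vmcnt and expcnt left at their "no wait" maxima. A small compile-time check of that reading (the bit layout is taken from the gfx9 encoding of s_waitcnt and should be treated as an assumption rather than documentation):

    // s_waitcnt simm16 on gfx9: vmcnt[3:0] in bits 3:0, expcnt in bits 6:4,
    // lgkmcnt in bits 11:8, vmcnt[5:4] in bits 15:14 (assumed layout).
    constexpr unsigned kWaitLgkmcnt0 =
        (0x3u << 14) | (0x0u << 8) | (0x7u << 4) | 0xFu; // vmcnt=0x3F, expcnt=7, lgkmcnt=0
    static_assert(kWaitLgkmcnt0 == 0xc07f, "matches the literal used in block_sync_lds");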
@@ -167,6 +167,10 @@
 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
 #endif
+#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
+#define CK_TILE_USE_PK_FP16_TILE_CAST 0
+#endif
 // TODO: better solve this inside compiler
 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
...
@@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
 namespace impl {
 // TODO: this is ugly
 template <typename OutDataType, typename InTensor>
-CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
+CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 {
 #if defined(__gfx94__)
     // This API is designed to use the _pk_ serious of function
@@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
 #endif
 }
+template <typename OutDataType, typename InTensor>
+CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
+{
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
+    // This API is designed to use the _pk_ serious of function
+    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
+    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
+    static_assert(thread_buffer_size % 2 == 0);
+    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
+    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
+    // TODO: this is rtz cvt, need be very careful
+    for(index_t i = 0; i < thread_buffer_size_pk; i++)
+    {
+        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
+                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
+    }
+    return out_dstr_tensor;
+#else
+    // fallback
+    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
+                               in_dstr_tensors);
+#endif
+}
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
 // this function assume either src or dst (or both) date type is under 1 dword
 // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
@@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
                             float> &&
                      (SrcTensor::get_thread_buffer_size() % 4 == 0))
     {
-        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
+        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
     }
+#if CK_TILE_USE_PK_FP16_TILE_CAST
+    else if constexpr(std::is_same_v<DstType, fp16_t> &&
+                      std::is_same_v<typename SrcTensor::DataType, float> &&
+                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
+    {
+        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
+    }
+#endif
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
     else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
     {
...
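The new cast_tile path is off by default (CK_TILE_USE_PK_FP16_TILE_CAST=0) because __builtin_amdgcn_cvt_pkrtz packs two floats into halves with round-toward-zero, which can differ from the usual round-to-nearest-even conversion by one ulp. Below is a host-side sketch of that difference, emulating fp16's 10-bit mantissa for values in [1, 2) rather than calling the device builtin:

    #include <cmath>
    #include <cstdio>

    // Valid for x in [1, 2): fp16 spacing there is 2^-10, so scale by 1024.
    static float to_fp16_rtz(float x) { return std::floor(x * 1024.0f) / 1024.0f; }
    static float to_fp16_rne(float x) { return std::nearbyint(x * 1024.0f) / 1024.0f; }

    int main()
    {
        const float x = 1.0f + 0.75f / 1024.0f; // three quarters of an fp16 ulp above 1.0
        std::printf("x=%.10f  rtz=%.10f  rne=%.10f\n", x, to_fp16_rtz(x), to_fp16_rne(x));
        return 0;
    }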