Commit 795bea35 authored by Umang Yadav's avatar Umang Yadav
Browse files

remove unnecessary changes

parent 8216854a
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -232,7 +229,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -232,7 +229,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -308,5 +305,3 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -308,5 +305,3 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -705,7 +702,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -705,7 +702,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
c_element_op}; c_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -783,5 +780,3 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -783,5 +780,3 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -219,7 +216,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -219,7 +216,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
false, // AThreadTransferSrcResetCoordinateAfterRun, false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM, ABlockLdsAddExtraM,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockBufferSize, BBlockBufferSize,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
...@@ -466,7 +463,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -466,7 +463,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
c_element_op}; c_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -532,5 +529,3 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -532,5 +529,3 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -284,7 +281,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -284,7 +281,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
KBatch}; KBatch};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -331,5 +328,3 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -331,5 +328,3 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -37,20 +34,19 @@ template <typename GridwiseGemm, ...@@ -37,20 +34,19 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) __launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_waveletmodel_cshuffle( kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_b_grid,
const ABDataType* __restrict__ p_b_grid, EDataType* __restrict__ p_e_grid,
EDataType* __restrict__ p_e_grid, const AElementwiseOperation a_element_op,
const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op,
const BElementwiseOperation b_element_op, const EElementwiseOperation e_element_op,
const EElementwiseOperation e_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -533,5 +529,3 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -533,5 +529,3 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -31,14 +28,14 @@ template <typename GridwiseGemm, ...@@ -31,14 +28,14 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_contraction_multiple_d_xdl_cshuffle( kernel_grouped_contraction_multiple_d_xdl_cshuffle(
const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, const void CK_CONSTANT_ADDRESS_SPACE* contraction_args,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -850,7 +847,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -850,7 +847,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -915,5 +912,3 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -915,5 +912,3 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -113,25 +110,25 @@ template <typename GridwiseGemm, ...@@ -113,25 +110,25 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map, const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -963,7 +960,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -963,7 +960,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, // bias ds_g_n_c_wis_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, // bias ds_g_n_c_wis_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
...@@ -997,7 +994,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -997,7 +994,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -1010,7 +1007,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -1010,7 +1007,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, // bias ds_g_n_c_wis_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, // bias ds_g_n_c_wis_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
...@@ -1080,5 +1077,3 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -1080,5 +1077,3 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -62,18 +59,18 @@ template <typename GridwiseGemm, ...@@ -62,18 +59,18 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_dlops_bwd_weight( kernel_batched_gemm_dlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const index_t batch_count, const index_t batch_count,
const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -789,7 +786,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -789,7 +786,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/, const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/, const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/, const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
...@@ -1127,7 +1124,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1127,7 +1124,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1145,7 +1142,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1145,7 +1142,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
p_out_grid, p_out_grid,
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
...@@ -1159,7 +1156,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1159,7 +1156,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
split_k}; split_k};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -1168,7 +1165,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1168,7 +1165,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const void* p_out_grid, const void* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1186,7 +1183,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1186,7 +1183,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
...@@ -1228,5 +1225,3 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1228,5 +1225,3 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -64,22 +61,21 @@ template <typename GridwiseGemm, ...@@ -64,22 +61,21 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_xdlops_bwd_weight( kernel_batched_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_c_grid, const AElementwiseOperation a_element_op,
const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op,
const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op,
const CElementwiseOperation c_element_op, const index_t batch_count,
const index_t batch_count, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map,
const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -1109,7 +1105,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1109,7 +1105,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1412,7 +1408,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1412,7 +1408,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1430,7 +1426,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1430,7 +1426,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
p_out_grid, p_out_grid,
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
...@@ -1446,7 +1442,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1446,7 +1442,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
split_k}; split_k};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -1455,7 +1451,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1455,7 +1451,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const void* p_out_grid, const void* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1473,7 +1469,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1473,7 +1469,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
...@@ -1526,5 +1522,3 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1526,5 +1522,3 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -16,7 +13,6 @@ ...@@ -16,7 +13,6 @@
#include "ck/host_utility/io.hpp" #include "ck/host_utility/io.hpp"
#endif #endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -121,23 +117,23 @@ template <typename GridwiseGemm, ...@@ -121,23 +117,23 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_dl_multiple_d( kernel_grouped_conv_fwd_dl_multiple_d(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
...@@ -898,7 +894,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -898,7 +894,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
...@@ -970,5 +966,3 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -970,5 +966,3 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -96,18 +93,18 @@ template <typename GridwiseGemm, ...@@ -96,18 +93,18 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_dl( kernel_grouped_conv_fwd_dl(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid, CDataType* __restrict__ p_c_grid,
const index_t batch_count, const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)) defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__))
...@@ -777,10 +774,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -777,10 +774,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
...@@ -856,5 +852,3 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -856,5 +852,3 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -78,5 +75,3 @@ struct DeviceGroupedConvFwdMultipleDMultipleR : public BaseOperator ...@@ -78,5 +75,3 @@ struct DeviceGroupedConvFwdMultipleDMultipleR : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -134,29 +131,29 @@ template <typename GridwiseGemm, ...@@ -134,29 +131,29 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batch_gemm_multiple_d_xdl_cshuffle( kernel_batch_gemm_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
RsPointer p_rs_grid, RsPointer p_rs_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const QsElementwiseOperation qs_element_op, const QsElementwiseOperation qs_element_op,
const RsElementwiseOperation rs_element_op, const RsElementwiseOperation rs_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
const Block2ETileMap block_2_ctile_map, const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -1029,7 +1026,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -1029,7 +1026,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
rs_element_op}; rs_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -1119,5 +1116,3 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -1119,5 +1116,3 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -779,7 +776,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -779,7 +776,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -856,5 +853,3 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -856,5 +853,3 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -118,25 +115,25 @@ template <typename GridwiseGemm, ...@@ -118,25 +115,25 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle( kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map, const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -888,7 +885,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -888,7 +885,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -968,5 +965,3 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -968,5 +965,3 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
#pragma once #pragma once
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -34,13 +31,13 @@ template <typename GridwiseGemm, ...@@ -34,13 +31,13 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
...@@ -712,7 +709,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -712,7 +709,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op}; p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -768,5 +765,3 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -768,5 +765,3 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -35,16 +32,16 @@ template <typename GridwiseGemm, ...@@ -35,16 +32,16 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op, const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op) const CElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -812,7 +809,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -812,7 +809,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -897,5 +894,3 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -897,5 +894,3 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
#pragma once #pragma once
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -33,13 +30,13 @@ template <typename GridwiseGemm, ...@@ -33,13 +30,13 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op) const CDEElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -661,7 +658,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -661,7 +658,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -723,5 +720,3 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -723,5 +720,3 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -32,10 +29,10 @@ template <typename GridwiseGemm, ...@@ -32,10 +29,10 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation> InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count) const index_t group_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -561,7 +558,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -561,7 +558,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return Argument{p_As, p_Bs, p_Es, gemm_descs}; return Argument{p_As, p_Bs, p_Es, gemm_descs};
} }
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif #endif
// polymorphic // polymorphic
...@@ -634,5 +631,3 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -634,5 +631,3 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -596,5 +593,3 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -596,5 +593,3 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment