Commit ed305f6b authored by Umang Yadav's avatar Umang Yadav
Browse files

formatting

parent 9f4e3544
...@@ -34,21 +34,21 @@ template <typename GridwiseGemm, ...@@ -34,21 +34,21 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_dl_multiple_d( kernel_gemm_dl_multiple_d(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
......
...@@ -37,30 +37,31 @@ template <typename GridwiseGemmWelford, ...@@ -37,30 +37,31 @@ template <typename GridwiseGemmWelford,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle( kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EMeanVarDataType* __restrict__ p_e_grid, EMeanVarDataType* __restrict__ p_e_grid,
EMeanVarDataType* __restrict__ p_welford_mean_grid, EMeanVarDataType* __restrict__ p_welford_mean_grid,
EMeanVarDataType* __restrict__ p_welford_var_grid, EMeanVarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count_grid, int32_t* __restrict__ p_welford_count_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_grid_desc_mblock_mperblock_nblock, mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock, const CountGridDescriptor_MBlock_MPerBlock_NBlock
const Block2ETileMap block_2_etile_map, count_grid_desc_mblock_mperblock_nblock,
index_t NRaw) const Block2ETileMap block_2_etile_map,
index_t NRaw)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -121,26 +122,26 @@ template <typename GridwiseWelfordLayernorm, ...@@ -121,26 +122,26 @@ template <typename GridwiseWelfordLayernorm,
typename HElementwiseOperation> typename HElementwiseOperation>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_welford_layernorm2d_second_half( kernel_welford_layernorm2d_second_half(
const EMeanVarDataType* __restrict__ p_e_grid, const EMeanVarDataType* __restrict__ p_e_grid,
const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
const EMeanVarDataType* __restrict__ p_in_welford_var_grid, const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
const int32_t* __restrict__ p_in_welford_count_grid, const int32_t* __restrict__ p_in_welford_count_grid,
const GammaDataType* __restrict__ p_gamma_grid, const GammaDataType* __restrict__ p_gamma_grid,
const BetaDataType* __restrict__ p_beta_grid, const BetaDataType* __restrict__ p_beta_grid,
HDataType* __restrict__ p_h_grid, HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n, const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n, const EHGridDesc_M_N h_grid_desc_m_n,
const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N, index_t numMeanVarCountBlockTileIteration_N,
index_t NBlockClusterLength, index_t NBlockClusterLength,
ComputeDataType epsilon, ComputeDataType epsilon,
HElementwiseOperation h_element_op) HElementwiseOperation h_element_op)
{ {
GridwiseWelfordLayernorm::Run(p_e_grid, GridwiseWelfordLayernorm::Run(p_e_grid,
p_in_welford_mean_grid, p_in_welford_mean_grid,
......
...@@ -38,27 +38,27 @@ template <typename GridwiseGemm, ...@@ -38,27 +38,27 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid, FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid, FloatE* __restrict__ p_e_grid,
FloatRsPointer p_rs_grid, FloatRsPointer p_rs_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const QsElementwiseOperation qs_element_op, const QsElementwiseOperation qs_element_op,
const RsElementwiseOperation rs_element_op, const RsElementwiseOperation rs_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -36,22 +36,22 @@ template <typename GridwiseGemm, ...@@ -36,22 +36,22 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid, kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -111,7 +111,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -111,7 +111,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
......
...@@ -216,7 +216,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -216,7 +216,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
false, // AThreadTransferSrcResetCoordinateAfterRun, false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM, ABlockLdsAddExtraM,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockBufferSize, BBlockBufferSize,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
......
...@@ -34,19 +34,20 @@ template <typename GridwiseGemm, ...@@ -34,19 +34,20 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) __launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid, kernel_gemm_xdl_waveletmodel_cshuffle(
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_a_grid,
EDataType* __restrict__ p_e_grid, const ABDataType* __restrict__ p_b_grid,
const AElementwiseOperation a_element_op, EDataType* __restrict__ p_e_grid,
const BElementwiseOperation b_element_op, const AElementwiseOperation a_element_op,
const EElementwiseOperation e_element_op, const BElementwiseOperation b_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const EElementwiseOperation e_element_op,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
e_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const Block2ETileMap block_2_etile_map) e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -28,14 +28,14 @@ template <typename GridwiseGemm, ...@@ -28,14 +28,14 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_contraction_multiple_d_xdl_cshuffle( kernel_grouped_contraction_multiple_d_xdl_cshuffle(
const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, const void CK_CONSTANT_ADDRESS_SPACE* contraction_args,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -110,25 +110,25 @@ template <typename GridwiseGemm, ...@@ -110,25 +110,25 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map, const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -960,7 +960,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -960,7 +960,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, // bias ds_g_n_c_wis_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, // bias ds_g_n_c_wis_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
...@@ -1007,7 +1007,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -1007,7 +1007,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, // bias ds_g_n_c_wis_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, // bias ds_g_n_c_wis_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
......
...@@ -59,18 +59,18 @@ template <typename GridwiseGemm, ...@@ -59,18 +59,18 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_dlops_bwd_weight( kernel_batched_gemm_dlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const index_t batch_count, const index_t batch_count,
const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1,
const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -786,7 +786,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -786,7 +786,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/, const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/, const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/, const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
...@@ -1124,7 +1124,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1124,7 +1124,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1142,7 +1142,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1142,7 +1142,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
p_out_grid, p_out_grid,
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
...@@ -1165,7 +1165,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1165,7 +1165,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const void* p_out_grid, const void* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1183,7 +1183,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1183,7 +1183,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
......
...@@ -61,21 +61,22 @@ template <typename GridwiseGemm, ...@@ -61,21 +61,22 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid, kernel_batched_gemm_xdlops_bwd_weight(
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_a_grid,
FloatC* __restrict__ p_c_grid, const FloatAB* __restrict__ p_b_grid,
const AElementwiseOperation a_element_op, FloatC* __restrict__ p_c_grid,
const BElementwiseOperation b_element_op, const AElementwiseOperation a_element_op,
const CElementwiseOperation c_element_op, const BElementwiseOperation b_element_op,
const index_t batch_count, const CElementwiseOperation c_element_op,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const index_t batch_count,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
const Block2CTileMap block_2_ctile_map, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -1105,7 +1106,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1105,7 +1106,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1408,7 +1409,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1408,7 +1409,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1426,7 +1427,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1426,7 +1427,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
p_out_grid, p_out_grid,
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
...@@ -1451,7 +1452,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1451,7 +1452,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const void* p_out_grid, const void* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
...@@ -1469,7 +1470,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1469,7 +1470,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
a_g_n_c_wis_lengths, // input a_g_n_c_wis_lengths, // input
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, // weight b_g_k_c_xs_lengths, // weight
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, // output e_g_n_k_wos_lengths, // output
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
......
...@@ -117,23 +117,23 @@ template <typename GridwiseGemm, ...@@ -117,23 +117,23 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_dl_multiple_d( kernel_grouped_conv_fwd_dl_multiple_d(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
......
...@@ -93,18 +93,18 @@ template <typename GridwiseGemm, ...@@ -93,18 +93,18 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_dl( kernel_grouped_conv_fwd_dl(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid, CDataType* __restrict__ p_c_grid,
const index_t batch_count, const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)) defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__))
......
...@@ -131,29 +131,29 @@ template <typename GridwiseGemm, ...@@ -131,29 +131,29 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batch_gemm_multiple_d_xdl_cshuffle( kernel_batch_gemm_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
RsPointer p_rs_grid, RsPointer p_rs_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const QsElementwiseOperation qs_element_op, const QsElementwiseOperation qs_element_op,
const RsElementwiseOperation rs_element_op, const RsElementwiseOperation rs_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
const Block2ETileMap block_2_ctile_map, const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -115,25 +115,25 @@ template <typename GridwiseGemm, ...@@ -115,25 +115,25 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle( kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_, e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map, const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -31,13 +31,13 @@ template <typename GridwiseGemm, ...@@ -31,13 +31,13 @@ template <typename GridwiseGemm,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
......
...@@ -32,16 +32,16 @@ template <typename GridwiseGemm, ...@@ -32,16 +32,16 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op, const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op) const CElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -30,13 +30,13 @@ template <typename GridwiseGemm, ...@@ -30,13 +30,13 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op) const CDEElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -29,10 +29,10 @@ template <typename GridwiseGemm, ...@@ -29,10 +29,10 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation> InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count) const index_t group_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
...@@ -36,25 +36,25 @@ template <typename GridwiseGemm, ...@@ -36,25 +36,25 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_contraction_multiple_d_xdl_cshuffle( kernel_contraction_multiple_d_xdl_cshuffle(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid, FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid, FloatE* __restrict__ p_e_grid,
const index_t batch_count, const index_t batch_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1,
const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment