"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "457ee9a19a32f0f0a2741f2bb244e5cf10f0885e"
Commit ed305f6b authored by Umang Yadav

formatting

parent 9f4e3544
@@ -50,9 +50,9 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
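The value selected above is the architecture-dependent configuration dword of the 128-bit buffer resource descriptor used for AMD buffer load/store addressing. As a rough, hedged sketch (the helper name and exact field layout below are illustrative assumptions, not taken from this commit), the macro would end up in the last word of such a descriptor:

#include <cstdint>

// Hypothetical helper: pack a 4-dword buffer resource, placing the per-architecture
// config word (CK_BUFFER_RESOURCE_3RD_DWORD, selected above) into dword 3.
inline void make_buffer_resource(int32_t dst[4], const void* base, uint32_t size_in_bytes)
{
    const std::uint64_t addr = reinterpret_cast<std::uintptr_t>(base);
    dst[0] = static_cast<int32_t>(addr & 0xffffffffu); // base address, low 32 bits
    dst[1] = static_cast<int32_t>(addr >> 32);         // base address, high 32 bits
    dst[2] = static_cast<int32_t>(size_in_bytes);      // addressable range in bytes
    dst[3] = CK_BUFFER_RESOURCE_3RD_DWORD;             // arch-specific config dword
}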
@@ -86,7 +86,7 @@
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_USE_AMD_WMMA
@@ -107,7 +107,7 @@
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
...
@@ -108,13 +108,13 @@ struct TensorAdaptor
    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -338,7 +338,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];

        // sequence in, sequence out
        constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
        {
            auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

            // shift hidden id so every dim id is unique
@@ -360,7 +361,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            });

            return low_dim_hidden_ids_1_mod_;
        }
        ();

        return generate_sequence_v2(
            [&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
@@ -382,7 +384,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];

        // sequence in, constexpr tuple out
        constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
        {
            auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

            // shift hidden id
@@ -391,7 +394,8 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            });

            return up_dim_hidden_ids_1_mod_;
        }
        ();

        // constexpr tuple to sequence
        return generate_sequence_v2(
...
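The change above only reflows an immediately-invoked constexpr lambda (the opening brace and the trailing "();" move to their own lines). A minimal, self-contained illustration of that pattern, using made-up data rather than CK's Sequence/Tuple types:

#include <array>

constexpr auto make_shifted_ids()
{
    // build a value inside an immediately-invoked lambda, then bind the result
    return []() {
        std::array<int, 4> ids{0, 1, 2, 3};
        for(auto& id : ids)
            id += 10; // shift every id so it stays unique after merging
        return ids;
    }();
}

static_assert(make_shifted_ids()[2] == 12);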
@@ -94,8 +94,10 @@ struct SpaceFillingCurve
        // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
        // idim-th element of multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
        {
            constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
            {
                auto res = idx_1d.value;
                auto id = 0;
...
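For intuition, the helper being reformatted above maps a flat 1D position along the curve back to one coordinate of a multi-dimensional index. A plain row-major version of that mapping (an illustrative assumption; the actual SpaceFillingCurve ordering in CK may differ):

#include <array>
#include <cstddef>

// Recover the idim-th coordinate from a flat index, given per-dimension lengths
// and row-major strides.
template <std::size_t N>
constexpr int coordinate(int idx_1d,
                         const std::array<int, N>& lengths,
                         const std::array<int, N>& strides,
                         std::size_t idim)
{
    return (idx_1d / strides[idim]) % lengths[idim];
}

// lengths {2, 3, 2}, strides {6, 2, 1}: flat index 11 corresponds to (1, 2, 1)
static_assert(coordinate<3>(11, {2, 3, 2}, {6, 2, 1}, 1) == 2);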
@@ -14,8 +14,8 @@ namespace device {
struct BaseArgument
{
    BaseArgument() = default;
    BaseArgument(const BaseArgument&) = default;
    BaseArgument& operator=(const BaseArgument&) = default;

    virtual ~BaseArgument() {}
@@ -26,8 +26,8 @@ struct BaseArgument
#ifndef __HIPCC_RTC__
struct BaseInvoker
{
    BaseInvoker() = default;
    BaseInvoker(const BaseInvoker&) = default;
    BaseInvoker& operator=(const BaseInvoker&) = default;

    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
@@ -41,13 +41,12 @@ struct BaseInvoker
struct BaseOperator
{
    BaseOperator() = default;
    BaseOperator(const BaseOperator&) = default;
    BaseOperator& operator=(const BaseOperator&) = default;

    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

#ifndef __HIPCC_RTC__
    virtual std::string GetTypeString() const { return ""; }
@@ -66,7 +65,7 @@ struct BaseOperator
    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
    {
        // assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
    }
...
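These hunks only re-align the defaulted special members and the commented-out assert. As a rough usage sketch of the workspace hook (the concrete types, the GetWorkSpaceSize member, and the 64-byte size are illustrative assumptions, not CK code):

#include <cstddef>
#include <vector>

struct BaseArgument
{
    virtual ~BaseArgument() {}
    void* p_workspace_ = nullptr;
};

struct BaseOperator
{
    virtual ~BaseOperator() {}
    virtual std::size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
    {
        p_arg->p_workspace_ = p_workspace;
    }
};

struct MyArgument : BaseArgument
{
};

struct MyOperator : BaseOperator
{
    std::size_t GetWorkSpaceSize(const BaseArgument*) const override { return 64; }
};

int main()
{
    MyOperator op;
    MyArgument arg;
    std::vector<unsigned char> workspace(op.GetWorkSpaceSize(&arg)); // allocate what the op requests
    op.SetWorkSpacePointer(&arg, workspace.data());                  // argument now carries the buffer
}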
@@ -38,25 +38,25 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_d_xdl_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
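All of the kernel hunks in this commit share the same shape: an optional __launch_bounds__ attribute guarded by CK_USE_LAUNCH_BOUNDS, and a body compiled only for supported gfx targets. A minimal, hedged HIP sketch of that pattern (the copy kernel itself is a stand-in, and CK_USE_LAUNCH_BOUNDS / CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU are assumed to come from CK's config headers as in the diff):

#include <hip/hip_runtime.h>

__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        copy_kernel(float* __restrict__ p_out, const float* __restrict__ p_in, int n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p_out[i] = p_in[i];
#else
    // body compiled out on unsupported targets, so the symbol still links
    (void)p_out;
    (void)p_in;
    (void)n;
#endif
}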
@@ -60,21 +60,21 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
                                          const ABDataType* __restrict__ p_b_grid,
                                          EDataType* __restrict__ p_e_grid,
                                          const index_t batch_count,
                                          const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
                                          const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
                                          const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                              e_grid_desc_mblock_mperblock_nblock_nperblock,
                                          const AElementwiseOperation a_element_op,
                                          const BElementwiseOperation b_element_op,
                                          const CDEElementwiseOperation cde_element_op,
                                          const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                                          const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -41,25 +41,26 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -63,24 +63,24 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
                                const ABDataType* __restrict__ p_b_grid,
                                DsPointer p_ds_grid,
                                EDataType* __restrict__ p_e_grid,
                                const index_t batch_count,
                                const AElementwiseOperation a_element_op,
                                const BElementwiseOperation b_element_op,
                                const CDEElementwiseOperation cde_element_op,
                                const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
                                const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
                                const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                    e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                                const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
...
@@ -52,23 +52,23 @@ template <typename GridwiseGemm,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dl_multiple_d(
            const ABDataType* __restrict__ p_a_grid,
            const ABDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
            const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
            const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
            const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
...
@@ -41,32 +41,32 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_gemm_xdl_cshuffle_v1(
            const A0B0B1DataType* __restrict__ p_a0_grid,
            const A0B0B1DataType* __restrict__ p_b0_grid,
            D0sPointer p_d0s_grid,
            const A0B0B1DataType* __restrict__ p_b1_grid,
            D1sPointer p_d1s_grid,
            E1DataType* __restrict__ p_e1_grid,
            const A0ElementwiseOperation a0_element_op,
            const B0ElementwiseOperation b0_element_op,
            const CDE0ElementwiseOperation cde0_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CDE1ElementwiseOperation cde1_element_op,
            const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
            const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                d1s_grid_desc_mblock_mperblock_nblock_nperblock,
            const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e1_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2E1TileMap block_2_e1tile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -38,26 +38,26 @@ template <typename GridwiseGemm,
          bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_reduce_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            ReducePtrsGlobal p_reduces_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const ReduceInElementwiseOperations reduce_in_element_ops,
            const ReduceAccElementwiseOperations reduce_out_element_ops,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
            const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -42,30 +42,30 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            D0sPointer p_d0s_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const C0DEElementwiseOperation c0de_element_op,
            const B1ElementwiseOperation b1_element_op,
            const C1DEElementwiseOperation c1de_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c1_grid_desc_mblock_mperblock_nblock_nperblock,
            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
            const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -40,27 +40,27 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
            const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -611,7 +611,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return true;
    }

    static constexpr bool
    IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
@@ -842,7 +843,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    struct Descriptor
    {
        template <class AGridDescriptor>
        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
        {
            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
@@ -852,14 +853,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            const auto AK0 = K / AK1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
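The unmerge transform above re-reads the K axis of an (M, K) descriptor as two axes (AK0, AK1) with K = AK0 * AK1, so the result is ordered (AK0, M, AK1). The underlying index arithmetic, written out as a small standalone check (plain integers for illustration, not CK's descriptor machinery):

// Splitting a flat k into (ak0, ak1) with k = ak0 * AK1 + ak1.
constexpr int AK1 = 8;
constexpr int k   = 35;
constexpr int ak0 = k / AK1; // 4
constexpr int ak1 = k % AK1; // 3
static_assert(ak0 * AK1 + ak1 == k, "unmerge of K into (AK0, AK1) is lossless");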
        template <class BGridDescriptor>
        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
        {
            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
@@ -869,14 +871,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            const auto BK0 = K / BK1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class B1GridDescriptor>
        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
        {
            const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
@@ -889,26 +892,24 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            return transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class CGridDescriptor>
        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
        {
            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
        }

        using AGridDesc_AK0_M_AK1 =
            remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
        using BGridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
        using B1GridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
        using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

        // GridwiseGemm
        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
@@ -979,8 +980,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            CGridDesc_M_N c_grid_desc_m_n;
            C0MatrixMask c0_matrix_mask;
            typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
            typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_descriptor_mblock_mperblock_nblock_nperblock;

            // element-wise op
            AElementwiseOperation a_element_op;
            BElementwiseOperation b_element_op;
@@ -1002,10 +1004,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              c0_matrix_mask{c.GetLength(I1)},
@@ -1013,23 +1015,20 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_m_n,
                                                   block_2_ctile_map) and
                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
        {
        }

        constexpr bool IsValid() const { return is_valid; }
    };
    template <class ADesc, class BDesc, class B1Desc, class CDesc>
@@ -1038,10 +1037,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                    BDesc b,
                                    B1Desc b1,
                                    CDesc c,
                                    AElementwiseOperation a_element_op = AElementwiseOperation{},
                                    BElementwiseOperation b_element_op = BElementwiseOperation{},
                                    B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                                    CElementwiseOperation c_element_op = CElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
            a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
@@ -1061,41 +1060,43 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        if(desc.has_main_k_block_loop)
        {
            Desc::GridwiseGemm::template Run<true>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
        else
        {
            Desc::GridwiseGemm::template Run<false>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
    }
};
...
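The Run body above uses a common dispatch idiom: a runtime flag (has_main_k_block_loop) selects between two instantiations of a function templated on a compile-time bool. A minimal standalone sketch of that idiom (the names and the iteration heuristic are assumptions for illustration):

#include <cstdio>

template <bool HasMainKBlockLoop>
void run_gemm(int k_block_iters)
{
    if constexpr(HasMainKBlockLoop)
        std::printf("main K block loop path, %d iterations\n", k_block_iters);
    else
        std::printf("tail-only path\n");
}

void dispatch(int k, int k_per_block)
{
    const bool has_main_k_block_loop = (k / k_per_block) > 1; // illustrative heuristic only
    if(has_main_k_block_loop)
        run_gemm<true>(k / k_per_block);
    else
        run_gemm<false>(k / k_per_block);
}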
@@ -48,9 +48,9 @@ namespace device {
template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -34,23 +34,23 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_d_xdl_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -404,7 +404,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            false, // BThreadTransferSrcResetCoordinateAfterRun,
            BBlockLdsAddExtraN,
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
            7, // CThreadTransferSrcDstVectorDim,
...
@@ -436,7 +436,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
        BlockSize,
        ABDataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType, // TODO: Add ShuffleType for DeviceConv2d
        CDataType,
        InMemoryDataOperationEnum::Set,
        AGridDesc_K0_M_K1,
...
@@ -354,7 +354,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
        2, // BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
        Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
        7, // CThreadTransferSrcDstVectorDim,
...
@@ -37,23 +37,23 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v2r3_for_conv3d(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const index_t num_batches,
            const index_t a_batch_stride,
            const index_t b_batch_stride,
            const index_t c_batch_stride,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
            const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
@@ -1005,7 +1005,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            false, // BThreadTransferSrcResetCoordinateAfterRun,
            BBlockLdsAddExtraN,
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
            7, // CThreadTransferSrcDstVectorDim,
...