Commit 83be9a70 authored by Bartlomiej Kocot, committed by Bartłomiej Kocot

Support multi AB for grouped conv fwd xdl

parent 98fd41f5
@@ -6,18 +6,42 @@

#include <array>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/utility/is_detected.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

/**
* \brief Grouped Convolution Forward
*
* \details
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
* output : output image E[G, N, K, Ho, Wo]
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass a tuple if there are multiple A tensors.
* \tparam BDataType Weight data type. Pass a tuple if there are multiple B tensors.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
@@ -30,18 +54,60 @@ template <index_t NDimSpatial,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputeType =
    decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                            Number<0>,
                            ADataType>())> // ComputeType is InputType by default (first
                                           // in tuple for MultiAB), unpack if tuple was
                                           // passed
struct DeviceGroupedConvFwdMultipleD : public BaseOperator
{
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();

static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
// If the data type is a tuple, the user has to pass a std::array of pointers.
using APointers =
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
using BPointers =
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
/**
* \brief Make argument pointer for grouped conv fwd.
*
* \param p_a A pointer to the input (a std::array<const void*, NumA> of
*        pointers when there are multiple A tensors).
* \param p_b A pointer to the weight (a std::array<const void*, NumB> of
*        pointers when there are multiple B tensors).
* \param p_ds Pointers to the Ds.
* \param p_e A pointer to the output.
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Input left paddings.
* \param input_right_pads Input right paddings.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
APointers p_a,
BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
...
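To make the new pointer parameters concrete, here is a minimal sketch, assuming plain std:: types and hypothetical names rather than the real CK device op: with a single A tensor the caller still passes one const void*, while a multi-A instantiation takes a std::array of device pointers, one per A tensor.

#include <array>
#include <cstdio>
#include <type_traits>

// Hypothetical, simplified device op: only the pointer-parameter switch is modelled.
template <bool isMultiA, int NumATensor>
struct DeviceOpSketch
{
    // Mirrors the commit: std::array of pointers for multi A, plain pointer otherwise.
    using APointers =
        std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;

    static void MakeArgument(APointers p_a)
    {
        if constexpr(isMultiA)
        {
            std::printf("multi-A: %zu pointers\n", p_a.size());
        }
        else
        {
            std::printf("single A pointer: %s\n", p_a != nullptr ? "set" : "null");
        }
    }
};

int main()
{
    float a0[4] = {};
    float a1[4] = {};

    // Single-A path: exactly the old interface, one type-erased pointer.
    DeviceOpSketch<false, 1>::MakeArgument(a0);

    // Multi-A path (e.g. ADataType being a tuple of two types): one pointer per A tensor.
    std::array<const void*, 2> p_as{a0, a1};
    DeviceOpSketch<true, 2>::MakeArgument(p_as);
    return 0;
}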
@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

// element-wise op
AElementwiseOp a_element_op_;
@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>;

return launch_and_time_kernel(
...
@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::vector<Block2ETileMap> block_2_etile_map_container_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

// element-wise op
AElementwiseOp a_element_op_;
@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>;

return launch_and_time_kernel(
...
@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap block_2_ctile_map_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;

// element-wise op
OutElementwiseOperation a_element_op_;
@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
has_main_loop,
has_double_loop>;
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
Block2CTileMap block_2_ctile_map_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;

OutElementwiseOperation a_element_op_;
InElementwiseOperation b_element_op_;
@@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
has_main_loop>;

using EmptyTuple = Tuple<>;
...
@@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;

index_t M01_;
index_t N01_;
@@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
has_main_loop>;

return launch_and_time_kernel(stream_config,
...
@@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DefaultBlock2CTileMap block_2_ctile_map_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

// element-wise op
AElementwiseOperation a_element_op_;
@@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11,
DefaultBlock2CTileMap,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop,
has_double_loop>;
...
@@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;

// element-wise op
AElementwiseOperation a_element_op_;
@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>;

return launch_and_time_kernel(stream_config,
...
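The repeated one-line change in the hunks above swaps the single template argument of ComputePtrOffsetOfStridedBatch for a (NumATensor, NumBTensor, NumDTensor) triple with defaults of 1, 1 and 0. A small sketch of how the old spellings map onto the new ones, using plain int parameters instead of CK's index_t/Number wrappers (names are hypothetical):

#include <cstdio>

// Simplified stand-in for the new defaulted template parameters.
template <int NumATensor = 1, int NumBTensor = 1, int NumDTensor = 0>
struct ComputePtrOffsetSketch
{
    static void Describe()
    {
        std::printf("A=%d B=%d D=%d\n", NumATensor, NumBTensor, NumDTensor);
    }
};

int main()
{
    // Old spelling ComputePtrOffsetOfStridedBatch<I0> (no D tensors) becomes <>:
    ComputePtrOffsetSketch<>::Describe(); // A=1 B=1 D=0

    // Old spelling <NumDTensor> becomes <1, 1, NumDTensor> (I1, I1, NumDTensor in CK):
    ComputePtrOffsetSketch<1, 1, 3>::Describe(); // A=1 B=1 D=3

    // Multi-AB device ops can now also state several A or B tensors:
    ComputePtrOffsetSketch<2, 1, 0>::Describe(); // A=2 B=1 D=0
    return 0;
}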
@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -56,7 +57,8 @@ namespace {
*
*/
template <typename GridwiseGemm,
typename AsPointer, // tuples if multi AB, pointers otherwise
typename BsPointer,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
@@ -68,14 +70,16 @@ template <typename GridwiseGemm,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle(
AsPointer p_as_grid,
BsPointer p_bs_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
@@ -98,13 +102,8 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -117,8 +116,26 @@ __global__ void
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
if constexpr(isMultiA || isMultiB)
{
AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp;
const auto as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
const auto bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}(
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid_grp,
p_bs_grid_grp,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
@@ -130,9 +147,32 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
}
else
{
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_batch_offset,
p_bs_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map);
}
#else
ignore = p_as_grid;
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
@@ -150,6 +190,9 @@ __global__ void
} // namespace
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
//
// @brief Device Convolution operation.
//
@@ -211,7 +254,12 @@ template <index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeDataType =
    decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                            Number<0>,
                            ADataType>()), // ComputeType is InputType by default (first
                                           // in tuple for MultiAB), unpack if tuple was
                                           // passed
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
    : public DeviceGroupedConvFwdMultipleD<NDimSpatial,
@@ -230,6 +278,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
{
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle;
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size();

static constexpr auto I0 = Number<0>{};
@@ -325,51 +378,43 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

// If we are using multi AB and one of the template data type parameters is not a tuple,
// wrap it in a tuple.
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;

#define GridwiseGemmTemplateParameters \
    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
    EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
    InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
    KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
    ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
    ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
    ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
    ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
    BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
    BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
    BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
    CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
    CDEBlockTransferScalarPerVector_NPerBlock, LoopSched

// Use the appropriate gridwise gemm
using GridwiseGemm =
    std::conditional_t<isMultiA || isMultiB,
                       GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
                       GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;

// If ADataType or BDataType is a tuple, the user has to pass a std::array of pointers.
using APointers =
    std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
using BPointers =
    std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
// Use Tuple in both cases for the grid pointers so that they can be initialized in the
// Argument constructor body (a single const pointer would have to be initialized in the
// initializer list).
using AGridPointer = remove_cvref_t<
    decltype(GetAGridPointer<isMultiA || isMultiB, GridwiseGemm, ADataType>())>;
using BGridPointer = remove_cvref_t<
    decltype(GetBGridPointer<isMultiA || isMultiB, GridwiseGemm, BDataType>())>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
@@ -392,8 +437,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
// Argument
struct Argument : public BaseArgument
{
Argument(APointers p_as,
         BPointers p_bs,
         const std::array<const void*, NumDTensor>& p_ds,
         void* p_e,
         const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
@@ -413,8 +458,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
         const CDEElementwiseOperation& cde_element_op)
    : p_as_grid_{},
      p_bs_grid_{},
      p_ds_grid_{},
      p_e_grid_{static_cast<EDataType*>(p_e)},
      num_group_{a_g_n_c_wis_lengths[0]},
@@ -458,9 +503,58 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
      input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
if constexpr(isMultiA || isMultiB)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
// It is possible that one of A/B is a single pointer while the other is a tuple.
// We still take the multi-AB path then, but have to cast the single pointer
// instead of indexing into a tuple of pointers.
if constexpr(isMultiA)
{
// p_as is tuple
p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_(i) = static_cast<const DataType*>(p_as);
}
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of A/B is a single pointer while the other is a tuple.
// We still take the multi-AB path then, but have to cast the single pointer
// instead of indexing into a tuple of pointers.
if constexpr(isMultiB)
{
// p_bs is tuple
p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
}
else
{
// if MultiA and not MultiB then p_bs is single pointer
p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
}
});
}
else
{
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
// p_as and p_bs are pointers
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
}
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -477,8 +571,33 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); });
}); });
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
// populate desc for Ds/E
if constexpr(isMultiA || isMultiB)
{
const auto as_grid_desc_ak0_m_ak1 =
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 =
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
}
}
else
{
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
@@ -494,6 +613,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
ds_grid_desc_m_n_);
}
}
}
void Print() const
{
@@ -505,9 +625,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
}

// private:
// pointers (tuple if multi AB, pointer otherwise)
AGridPointer p_as_grid_;
BGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
@@ -529,7 +649,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
Block2ETileMap block_2_etile_map_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
    compute_ptr_offset_of_batch_;

// element-wise op
AElementwiseOperation a_element_op_;
@@ -563,16 +684,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
@@ -582,9 +693,60 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if constexpr(isMultiA || isMultiB)
{
// Generate tuples with grid descriptors for each A and B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number<NumBTensor>{});
const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
GridwiseGemm,
AGridPointer,
BGridPointer,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_as_grid_,
arg.p_bs_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
}
else
{
const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
const BDataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
@@ -595,16 +757,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_as_grid_.At(I0), // Pass just the A pointer instead of the tuple
arg.p_bs_grid_.At(I0), // Pass just the B pointer instead of the tuple
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
@@ -617,6 +782,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
}
};

if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
@@ -791,12 +957,28 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
}

// check Gridwise GEMM
if constexpr(isMultiA || isMultiB)
{
// Generate tuples with the same descriptors
const auto as_grid_desc_ak0_m_ak1 =
generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 =
generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
else
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
@@ -804,9 +986,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
}
static auto MakeArgument(
APointers p_as,
BPointers p_bs,
std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
@@ -824,8 +1006,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_as,
                p_bs,
                p_ds,
                p_e,
                a_g_n_c_wis_lengths,
@@ -848,8 +1030,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
static auto MakeInvoker() { return Invoker{}; }

std::unique_ptr<BaseArgument> MakeArgumentPointer(
APointers p_a,
BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
...
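The selection between GridwiseGemmMultipleABD_xdl_cshuffle and GridwiseGemmMultipleD_xdl_cshuffle above boils down to a compile-time dispatch on whether A or B is a tuple, with the non-tuple side wrapped so the multi-AB kernel always sees tuples. A self-contained sketch of that pattern, using std::tuple and made-up kernel structs instead of the CK types:

#include <cstdio>
#include <tuple>
#include <type_traits>

// Minimal "is this a std::tuple?" check, standing in for CK's is_detected/IsTuple idiom.
template <typename T>
struct is_tuple : std::false_type {};
template <typename... Ts>
struct is_tuple<std::tuple<Ts...>> : std::true_type {};

// Hypothetical stand-ins for the two gridwise GEMM implementations.
struct GridwiseGemmMultipleD   { static void Run() { std::puts("single A/B path"); } };
struct GridwiseGemmMultipleABD { static void Run() { std::puts("multi A/B path"); } };

template <typename ADataType, typename BDataType>
struct DeviceOpSketch
{
    static constexpr bool isMultiA = is_tuple<ADataType>::value;
    static constexpr bool isMultiB = is_tuple<BDataType>::value;

    // If only one of A/B is a tuple, wrap the other so both look the same to the
    // multi-AB kernel (mirrors GemmADataType/GemmBDataType in the commit).
    using GemmADataType =
        std::conditional_t<!isMultiA && isMultiB, std::tuple<ADataType>, ADataType>;
    using GemmBDataType =
        std::conditional_t<!isMultiB && isMultiA, std::tuple<BDataType>, BDataType>;

    // Pick the implementation at compile time.
    using GridwiseGemm =
        std::conditional_t<isMultiA || isMultiB, GridwiseGemmMultipleABD, GridwiseGemmMultipleD>;
};

int main()
{
    DeviceOpSketch<float, float>::GridwiseGemm::Run();                    // single A/B path
    DeviceOpSketch<std::tuple<float, float>, float>::GridwiseGemm::Run(); // multi A/B path
    return 0;
}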
@@ -9,31 +9,112 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct ComputePtrOffsetOfStridedBatch
{
static constexpr bool isMultiAB = NumATensor > 1 || NumBTensor > 1;
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                               index_t BatchStrideB,
                               Array<ck::index_t, NumDTensor> BatchStrideDs,
                               index_t BatchStrideE)
    : BatchStrideA_(),
      BatchStrideB_(),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
if constexpr(!isMultiAB)
{
BatchStrideA_ = BatchStrideA;
BatchStrideB_ = BatchStrideB;
}
else
{
static_assert("Invalid constructor for multiple A or B");
}
}
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor> BatchStrideAs,
Array<ck::index_t, NumBTensor> BatchStrideBs,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(),
BatchStrideB_(),
      BatchStrideDs_(BatchStrideDs),
      BatchStrideE_(BatchStrideE)
{
if constexpr(isMultiAB)
{
BatchStrideA_ = BatchStrideAs;
BatchStrideB_ = BatchStrideBs;
}
else
{
static_assert("Invalid constructor for single A and B");
}
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
if constexpr(!isMultiAB)
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
else
{
static_assert("Invalid function for multiple A or B");
return 0;
}
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
if constexpr(!isMultiAB)
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
else
{
static_assert("Invalid function for multiple A or B");
return 0;
}
}
__host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
{
if constexpr(isMultiAB)
{
Array<long_index_t, NumATensor> as_offset;
static_for<0, NumATensor, 1>{}([&](auto i) {
as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]);
});
return as_offset;
}
else
{
static_assert("Invalid function for single A and B");
return BatchStrideA_;
}
}
__host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
{
if constexpr(isMultiAB)
{
Array<long_index_t, NumBTensor> bs_offset;
static_for<0, NumBTensor, 1>{}([&](auto i) {
bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]);
});
return bs_offset;
}
else
{
static_assert("Invalid function for single A and B");
return BatchStrideB_;
}
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
@@ -54,13 +135,73 @@ struct ComputePtrOffsetOfStridedBatch
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
// If multiAB use Array
using BatchStrideAType =
std::conditional_t<isMultiAB, Array<ck::index_t, NumATensor>, ck::index_t>;
using BatchStrideBType =
std::conditional_t<isMultiAB, Array<ck::index_t, NumBTensor>, ck::index_t>;
BatchStrideAType BatchStrideA_;
BatchStrideBType BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
template <bool isTuple, typename Tensors>
constexpr static auto GetNumABTensors()
{
if constexpr(isTuple)
{
return Number<Tensors::Size()>{};
}
else
{
return Number<1>{};
}
}
template <bool isTuple, typename GridwiseGemm, typename DataType>
constexpr static auto GetAGridPointer()
{
if constexpr(isTuple)
{
return typename GridwiseGemm::AsGridPointer{};
}
else
{
return Tuple<const DataType*>{};
}
}
template <bool isTuple, typename GridwiseGemm, typename DataType>
constexpr static auto GetBGridPointer()
{
if constexpr(isTuple)
{
return typename GridwiseGemm::BsGridPointer{};
}
else
{
return Tuple<const DataType*>{};
}
}
template <bool isTuple, typename Id, typename Type>
constexpr static auto UnpackDataType()
{
if constexpr(isTuple)
{
// unpack if tuple
return tuple_element_t<Id{}, Type>{};
}
else
{
// if not a tuple, return Type unchanged
return Type{};
}
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
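The batch-offset helpers above all reduce to offset = g_idx * BatchStride, returned either as a single value or as one value per A/B tensor. A freestanding sketch of the same arithmetic, with hypothetical strides and plain std:: containers in place of ck::Array:

#include <array>
#include <cstdint>
#include <cstdio>

using long_index_t = std::int64_t;

// Single-tensor flavour (GetAPtrOffset / GetBPtrOffset): one stride, one offset.
long_index_t GetPtrOffset(int g_idx, int batch_stride)
{
    return static_cast<long_index_t>(g_idx) * batch_stride;
}

// Multi-tensor flavour (GetAsPtrOffset / GetBsPtrOffset): one offset per tensor.
template <std::size_t NumTensor>
std::array<long_index_t, NumTensor> GetPtrOffsets(int g_idx,
                                                  const std::array<int, NumTensor>& strides)
{
    std::array<long_index_t, NumTensor> offsets{};
    for(std::size_t i = 0; i < NumTensor; ++i)
    {
        offsets[i] = static_cast<long_index_t>(g_idx) * strides[i];
    }
    return offsets;
}

int main()
{
    // Hypothetical per-group strides (e.g. N * C * Hi * Wi elements for an input tensor).
    std::printf("single A, group 3: %lld\n", static_cast<long long>(GetPtrOffset(3, 1024)));

    const std::array<int, 2> strides{1024, 2048};
    const auto offsets = GetPtrOffsets(3, strides);
    std::printf("multi A,  group 3: %lld %lld\n",
                static_cast<long long>(offsets[0]),
                static_cast<long long>(offsets[1]));
    return 0;
}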
@@ -85,10 +85,13 @@ struct Add
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}

template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
}
template <>
__host__ __device__ void
...
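The ScaleAdd hunk turns the previously declared operator() into a definition computing y = scale * x0 + x1, with a new default scale of 1. A simplified, float-only stand-in showing the same arithmetic (the CK version additionally routes the operands through type_convert):

#include <cstdio>

// Float-only stand-in for the element-wise ScaleAdd functor.
struct ScaleAdd
{
    ScaleAdd(float scale = 1.f) : scale_(scale) {}

    // y = scale * x0 + x1
    void operator()(float& y, float x0, float x1) const { y = scale_ * x0 + x1; }

    float scale_;
};

int main()
{
    float y = 0.f;

    ScaleAdd{0.5f}(y, 4.f, 1.f); // 0.5 * 4 + 1 = 3
    std::printf("scale 0.5: %f\n", y);

    ScaleAdd{}(y, 4.f, 1.f); // default scale 1, so 4 + 1 = 5
    std::printf("default:   %f\n", y);
    return 0;
}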
@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
@@ -219,7 +219,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename AsGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
{
return generate_tuple(
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
@@ -229,7 +229,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
@@ -245,7 +245,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
{
return generate_tuple(
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
...
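Several of the changes hinge on the is_tuple detection alias, decltype(std::declval<T&>().IsTuple()), combined with is_detected. The commit does not show ck::is_detected itself; the sketch below uses a hand-rolled detector in the spirit of std::experimental::is_detected and a made-up FakeTuple type, purely to illustrate how the multi-AB flags are derived:

#include <type_traits>
#include <utility>

namespace detail {
// Primary template: the expression Op<Args...> is ill-formed, so detection fails.
template <typename AlwaysVoid, template <typename...> class Op, typename... Args>
struct detector : std::false_type
{
};
// Partial specialization chosen when Op<Args...> is well-formed.
template <template <typename...> class Op, typename... Args>
struct detector<std::void_t<Op<Args...>>, Op, Args...> : std::true_type
{
};
} // namespace detail

template <template <typename...> class Op, typename... Args>
using is_detected = detail::detector<void, Op, Args...>;

// Same shape as the alias added by the commit: well-formed only for types with IsTuple().
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

struct FakeTuple // stands in for ck::Tuple
{
    constexpr bool IsTuple() const { return true; }
};
struct NotATuple
{
};

static_assert(is_detected<is_tuple, FakeTuple>::value, "tuple-like: multi A/B path");
static_assert(!is_detected<is_tuple, NotATuple>::value, "plain type: single A/B path");

int main() { return 0; }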