Commit 83be9a70 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot Committed by Bartłomiej Kocot
Browse files

Support multi AB for grouped conv fwd xdl

parent 98fd41f5
...@@ -6,18 +6,42 @@ ...@@ -6,18 +6,42 @@
#include <array> #include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Convolution Forward: template <typename T>
// input : input image A[G, N, C, Hi, Wi], using is_tuple = decltype(std::declval<T&>().IsTuple());
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... /**
// output : output image E[G, N, K, Ho, Wo] * \brief Grouped Convolution Forward
// C = a_op(A) * b_op(B) *
// E = cde_op(C, D0, D1, ...) * \details
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
* output : output image E[G, N, K, Ho, Wo]
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ALayout Input layout (also for a1, a2...).
* \tparam BLayout Weight layout (also for b1, b2...).
* \tparam DsLayout Ds layouts.
* \tparam ELayout Output layout.
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
* \tparam DsDataType D data types.
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation CDE elementwise operation.
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
*/
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
...@@ -30,18 +54,60 @@ template <index_t NDimSpatial, ...@@ -30,18 +54,60 @@ template <index_t NDimSpatial,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
typename ComputeType = ADataType> typename ComputeType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>())> // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
struct DeviceGroupedConvFwdMultipleD : public BaseOperator struct DeviceGroupedConvFwdMultipleD : public BaseOperator
{ {
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
// If DataType is tuple, user has to pass std::array with pointers.
using APointers =
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
using BPointers =
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
/**
* \brief Make argument pointer for grouped conv fwd.
*
* \param p_a A pointer to the input (std::array<const void*, NumA> with
pointers for multiple A).
* \param p_b A pointer to the weight (std::array<const void*, NumA> with
pointers for multiple B).
* \param p_ds A pointers to the Ds.
* \param p_e A pointers to the output.
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
* \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Input left paddings.
* \param input_right_pads Input right paddings.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, // input image APointers p_a,
const void* p_b, // weight BPointers p_b,
const std::array<const void*, NumDTensor>& p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, // output image void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
......
...@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_; std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
AElementwiseOp a_element_op_; AElementwiseOp a_element_op_;
...@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<NumDTensor>, ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel( return launch_and_time_kernel(
......
...@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::vector<Block2ETileMap> block_2_etile_map_container_; std::vector<Block2ETileMap> block_2_etile_map_container_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
AElementwiseOp a_element_op_; AElementwiseOp a_element_op_;
...@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap, Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor>, ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel( return launch_and_time_kernel(
......
...@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
OutElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
...@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>, remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::Block2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<I0>, ComputePtrOffsetOfStridedBatch<>,
has_main_loop, has_main_loop,
has_double_loop>; has_double_loop>;
......
...@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
OutElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
InElementwiseOperation b_element_op_; InElementwiseOperation b_element_op_;
...@@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<I0>, ComputePtrOffsetOfStridedBatch<>,
has_main_loop>; has_main_loop>;
using EmptyTuple = Tuple<>; using EmptyTuple = Tuple<>;
......
...@@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<I0>, ComputePtrOffsetOfStridedBatch<>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
......
...@@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DefaultBlock2CTileMap block_2_ctile_map_; DefaultBlock2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11, DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11, DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11,
DefaultBlock2CTileMap, DefaultBlock2CTileMap,
ComputePtrOffsetOfStridedBatch<NumDTensor>, ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop, has_main_loop,
has_double_loop>; has_double_loop>;
......
...@@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_; typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
ComputePtrOffsetOfStridedBatch<NumDTensor>, ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
......
...@@ -9,30 +9,111 @@ namespace ck { ...@@ -9,30 +9,111 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t NumDTensor> template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct ComputePtrOffsetOfStridedBatch struct ComputePtrOffsetOfStridedBatch
{ {
static constexpr bool isMultiAB = NumATensor > 1 || NumBTensor > 1;
ComputePtrOffsetOfStridedBatch() = default; ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB, index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs, Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE) index_t BatchStrideE)
: BatchStrideA_(BatchStrideA), : BatchStrideA_(),
BatchStrideB_(BatchStrideB), BatchStrideB_(),
BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE)
{
if constexpr(!isMultiAB)
{
BatchStrideA_ = BatchStrideA;
BatchStrideB_ = BatchStrideB;
}
else
{
static_assert("Invalid constructor for multiple A or B");
}
}
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor> BatchStrideAs,
Array<ck::index_t, NumBTensor> BatchStrideBs,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
: BatchStrideA_(),
BatchStrideB_(),
BatchStrideDs_(BatchStrideDs), BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE) BatchStrideE_(BatchStrideE)
{ {
if constexpr(isMultiAB)
{
BatchStrideA_ = BatchStrideAs;
BatchStrideB_ = BatchStrideBs;
}
else
{
static_assert("Invalid constructor for single A and B");
}
} }
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); if constexpr(!isMultiAB)
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
else
{
static_assert("Invalid function for multiple A or B");
return 0;
}
} }
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); if constexpr(!isMultiAB)
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
else
{
static_assert("Invalid function for multiple A or B");
return 0;
}
}
__host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
{
if constexpr(isMultiAB)
{
Array<long_index_t, NumATensor> as_offset;
static_for<0, NumATensor, 1>{}([&](auto i) {
as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]);
});
return as_offset;
}
else
{
static_assert("Invalid function for single A and B");
return BatchStrideA_;
}
}
__host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
{
if constexpr(isMultiAB)
{
Array<long_index_t, NumBTensor> bs_offset;
static_for<0, NumBTensor, 1>{}([&](auto i) {
bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]);
});
return bs_offset;
}
else
{
static_assert("Invalid function for single A and B");
return BatchStrideB_;
}
} }
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
...@@ -54,13 +135,73 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -54,13 +135,73 @@ struct ComputePtrOffsetOfStridedBatch
return g_idx * static_cast<long_index_t>(BatchStrideE_); return g_idx * static_cast<long_index_t>(BatchStrideE_);
} }
index_t BatchStrideA_; // If multiAB use Array
index_t BatchStrideB_; using BatchStrideAType =
std::conditional_t<isMultiAB, Array<ck::index_t, NumATensor>, ck::index_t>;
using BatchStrideBType =
std::conditional_t<isMultiAB, Array<ck::index_t, NumBTensor>, ck::index_t>;
BatchStrideAType BatchStrideA_;
BatchStrideBType BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_; Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_; index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
}; };
template <bool isTuple, typename Tensors>
constexpr static auto GetNumABTensors()
{
if constexpr(isTuple)
{
return Number<Tensors::Size()>{};
}
else
{
return Number<1>{};
}
}
template <bool isTuple, typename GridwiseGemm, typename DataType>
constexpr static auto GetAGridPointer()
{
if constexpr(isTuple)
{
return typename GridwiseGemm::AsGridPointer{};
}
else
{
return Tuple<const DataType*>{};
}
}
template <bool isTuple, typename GridwiseGemm, typename DataType>
constexpr static auto GetBGridPointer()
{
if constexpr(isTuple)
{
return typename GridwiseGemm::BsGridPointer{};
}
else
{
return Tuple<const DataType*>{};
}
}
template <bool isTuple, typename Id, typename Type>
constexpr static auto UnpackDataType()
{
if constexpr(isTuple)
{
// unpack if tuple
return tuple_element_t<Id{}, Type>{};
}
else
{
// if no, return Type
return Type{};
}
}
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -85,10 +85,13 @@ struct Add ...@@ -85,10 +85,13 @@ struct Add
struct ScaleAdd struct ScaleAdd
{ {
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {} __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X0, typename X1> template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
}
template <> template <>
__host__ __device__ void __host__ __device__ void
......
...@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// A desc for source in blockwise copy // A desc for source in blockwise copy
template <typename AGridDesc_M_K> template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{ {
const auto M = a_grid_desc_m_k.GetLength(I0); const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1); const auto K = a_grid_desc_m_k.GetLength(I1);
...@@ -219,7 +219,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -219,7 +219,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename AsGridDesc_M_K> template <typename AsGridDesc_M_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, [&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
...@@ -229,7 +229,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -229,7 +229,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// B desc for source in blockwise copy // B desc for source in blockwise copy
template <typename BGridDesc_N_K> template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{ {
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1); const auto K = b_grid_desc_n_k.GetLength(I1);
...@@ -245,7 +245,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -245,7 +245,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsGridDesc_N_K> template <typename BsGridDesc_N_K>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, [&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
...@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// return block_id to E matrix tile idx (m0, n0) mapping // return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N> template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n); e_grid_desc_m_n);
...@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{}); Number<NumATensor>{});
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock, ThisThreadBlock,
AsDataType, AsDataType,
...@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{}); Number<NumBTensor>{});
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock, ThisThreadBlock,
BsDataType, BsDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment