Commit b93575ca authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents 54df59bf c8a8385f
......@@ -114,6 +114,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
// Current implementation does not support multiple D fusions.
enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
......@@ -142,7 +143,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
AccDataType,
EDataType,
ALayout,
......@@ -182,7 +184,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
CDEBlockTransferScalarPerVector_NPerBlock,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVersion::v1>;
PipelineVer>;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
......@@ -421,8 +423,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
const auto& karg = trans_arg.karg_;
hip_check_error(
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
hip_check_error(hipMemsetAsync(karg.p_c_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
}
......
......@@ -3,16 +3,7 @@
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
namespace ck {
namespace tensor_operation {
......@@ -30,254 +21,31 @@ template <typename InDataType,
ck::index_t ReduceMThreadSliceSize,
ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
: public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
static constexpr ck::index_t ReduceM_BlockTileSize =
ReduceMThreadClusterSize * ReduceMThreadSliceSize;
static constexpr ck::index_t ReduceK_BlockTileSize =
ReduceKThreadClusterSize * ReduceKThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = window_spatial_lengths[0];
const index_t X = window_spatial_lengths[1];
const index_t ConvStrideH = window_strides[0];
const index_t ConvStrideW = window_strides[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ReduceMRaw = N * Ho * Wo * C;
const index_t ReduceMPad =
math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw;
const index_t ReduceKRaw = Y * X;
const index_t ReduceKPad =
math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw;
// A[ReduceM, ReduceK]
const auto in_grid_desc_n_hi_wi_c =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_hi_wi_c,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_hip_wip_c,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_grid_desc_reducemraw_reducekraw =
transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
make_merge_transform(make_tuple(Y, X))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
in_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad),
make_right_pad_transform(ReduceKRaw, ReduceKPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM]
const auto out_grid_desc_reducemraw =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C));
const auto out_grid_desc_reducem = transform_tensor_descriptor(
out_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
}
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
// TODO
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev,
ck::index_t N,
ck::index_t C,
std::vector<ck::index_t>& input_spatial_lengths,
std::vector<ck::index_t>& window_spatial_lengths,
std::vector<ck::index_t>& output_spatial_lengths,
std::vector<ck::index_t>& window_strides,
std::vector<ck::index_t>& input_left_pads,
std::vector<ck::index_t>& input_right_pads)
: p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{},
b_grid_desc_m_{}
{
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
C,
input_spatial_lengths,
window_spatial_lengths,
output_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1];
invariant_lowest_length_ = C;
reduce_lowest_length_ = window_spatial_lengths[1];
int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
}
const InDataType* p_in_dev_;
OutDataType* p_out_dev_;
IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_;
// for checking vector load/store
ck::index_t invariant_lowest_length_;
ck::index_t reduce_lowest_length_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC<InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
ComputeDataType,
ReduceOpId,
OutputIndex,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
const auto kernel =
kernel_reduce_threadwise<gridwise_reduce,
OutputIndex,
true, // pooling need to return global index
false, // don't have index input
InDataType,
InSrcOutDstVectorSize>
{
using DevicePool3D = DevicePool3dFwd_NDHWC_NDHWC<InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.a_grid_desc_m_k_,
arg.b_grid_desc_m_,
arg.in_element_op_,
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
nullptr,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
{
return (false);
}
return (true);
}
ComputeDataType,
ReduceOpId,
OutputIndex,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorSize>;
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_dev,
......@@ -286,62 +54,57 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t> input_stride,
std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> pooling_dims) override
{
static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2;
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank)
window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
input_right_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3})
throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");
index_t N = input_lengths[0];
index_t C = input_lengths[1];
index_t Hi = input_lengths[2];
index_t Wi = input_lengths[3];
index_t Ho = output_lengths[2];
index_t Wo = output_lengths[3];
std::vector<ck::index_t> input_spatial_lengths = {Hi, Wi};
std::vector<ck::index_t> output_spatial_lengths = {Ho, Wo};
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev),
static_cast<IndexDataType*>(p_out_indices_dev),
N,
C,
input_spatial_lengths,
// NCHW to NCDHW
input_lengths.insert(input_lengths.begin() + 2, 1);
output_lengths.insert(output_lengths.begin() + 2, 1);
input_stride.insert(input_stride.begin() + 2, 0);
output_stride.insert(output_stride.begin() + 2, 0);
indices_stride.insert(indices_stride.begin() + 2, 0);
// YX to ZYX
window_lengths.insert(window_lengths.begin(), 1);
window_strides.insert(window_strides.begin(), 0);
window_dilations.insert(window_dilations.begin(), 0);
input_left_pads.insert(input_left_pads.begin(), 0);
input_right_pads.insert(input_right_pads.begin(), 0);
pooling_dims = {2, 3, 4};
return DevicePool3D::MakeArgumentPointer(p_in_dev,
p_out_dev,
p_out_indices_dev,
input_lengths,
window_lengths,
output_spatial_lengths,
output_lengths,
input_stride,
output_stride,
indices_stride,
window_strides,
window_dilations,
input_left_pads,
input_right_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ",";
str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ",";
str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
// clang-format on
return str.str();
input_right_pads,
pooling_dims);
}
};
......
......@@ -8,8 +8,10 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -30,8 +32,15 @@ template <typename InDataType,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
: public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex>
struct DevicePool3dFwd_NDHWC_NDHWC : public DevicePoolFwd<5,
3,
InDataType,
OutDataType,
IndexDataType,
tensor_layout::convolution::NDHWC,
tensor_layout::convolution::NDHWC,
ReduceOpId,
OutputIndex>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -51,45 +60,48 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced.
static constexpr index_t InSrcOutDstVectorDim = 0;
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_ncdhw_lengths,
std::vector<ck::index_t> output_ncdhw_lengths,
std::vector<ck::index_t> input_ncdhw_stride,
std::vector<ck::index_t> output_ncdhw_stride,
std::vector<ck::index_t> window_spatial_zyx_lengths,
std::vector<ck::index_t> window_zyx_strides,
std::vector<ck::index_t> window_zyx_dilations,
std::vector<ck::index_t> input_left_dhw_pads,
std::vector<ck::index_t> input_right_dhw_pads)
{
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t N = input_ncdhw_lengths[0];
const index_t C = input_ncdhw_lengths[1];
const index_t Di = input_ncdhw_lengths[2];
const index_t Hi = input_ncdhw_lengths[3];
const index_t Wi = input_ncdhw_lengths[4];
const index_t Do = output_ncdhw_lengths[2];
const index_t Ho = output_ncdhw_lengths[3];
const index_t Wo = output_ncdhw_lengths[4];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Z = window_spatial_zyx_lengths[0];
const index_t Y = window_spatial_zyx_lengths[1];
const index_t X = window_spatial_zyx_lengths[2];
const index_t Z = window_spatial_lengths[0];
const index_t Y = window_spatial_lengths[1];
const index_t X = window_spatial_lengths[2];
const index_t WindowStrideD = window_zyx_strides[0];
const index_t WindowStrideH = window_zyx_strides[1];
const index_t WindowStrideW = window_zyx_strides[2];
const index_t ConvStrideD = window_strides[0];
const index_t ConvStrideH = window_strides[1];
const index_t ConvStrideW = window_strides[2];
const index_t WindowDilationD = window_zyx_dilations[0];
const index_t WindowDilationH = window_zyx_dilations[1];
const index_t WindowDilationW = window_zyx_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InLeftPadD = input_left_dhw_pads[0];
const index_t InLeftPadH = input_left_dhw_pads[1];
const index_t InLeftPadW = input_left_dhw_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t InRightPadD = input_right_dhw_pads[0];
const index_t InRightPadH = input_right_dhw_pads[1];
const index_t InRightPadW = input_right_dhw_pads[2];
const index_t MRaw = N * Do * Ho * Wo * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
......@@ -98,8 +110,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
// A[ReduceM, ReduceK]
const auto in_grid_desc_n_di_hi_wi_c =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const index_t Ni_stride = input_ncdhw_stride[0];
const index_t Ci_stride = input_ncdhw_stride[1];
const index_t Di_stride = input_ncdhw_stride[2];
const index_t Hi_stride = input_ncdhw_stride[3];
const index_t Wi_stride = input_ncdhw_stride[4];
const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride));
const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_di_hi_wi_c,
......@@ -113,10 +132,11 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_dip_hip_wip_c,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
......@@ -139,8 +159,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM]
const auto out_grid_desc_reducemraw =
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo * C));
const index_t No_stride = output_ncdhw_stride[0];
const index_t Co_stride = output_ncdhw_stride[1];
const index_t Do_stride = output_ncdhw_stride[2];
const index_t Ho_stride = output_ncdhw_stride[3];
const index_t Wo_stride = output_ncdhw_stride[4];
const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride));
const auto out_grid_desc_reducemraw = transform_tensor_descriptor(
out_grid_desc_n_do_ho_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto out_grid_desc_reducem =
transform_tensor_descriptor(out_grid_desc_reducemraw,
......@@ -151,7 +184,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
}
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
using ABGridDescs =
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
......@@ -160,36 +195,41 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev,
ck::index_t N,
ck::index_t C,
std::vector<ck::index_t>& input_spatial_lengths,
std::vector<ck::index_t>& window_spatial_lengths,
std::vector<ck::index_t>& output_spatial_lengths,
std::vector<ck::index_t>& window_strides,
std::vector<ck::index_t>& input_left_pads,
std::vector<ck::index_t>& input_right_pads)
std::vector<ck::index_t>& input_ncdhw_lengths,
std::vector<ck::index_t>& output_ncdhw_lengths,
std::vector<ck::index_t>& input_ncdhw_stride,
std::vector<ck::index_t>& output_ncdhw_stride,
std::vector<ck::index_t>&, // indices_ncdhw_stride
std::vector<ck::index_t>& window_spatial_zyx_lengths,
std::vector<ck::index_t>& window_zyx_strides,
std::vector<ck::index_t>& window_zyx_dilations,
std::vector<ck::index_t>& input_left_dhw_pads,
std::vector<ck::index_t>& input_right_dhw_pads)
: p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{},
b_grid_desc_m_{}
b_grid_desc_m_{},
input_ncdhw_lengths_{input_ncdhw_lengths},
output_ncdhw_lengths_{output_ncdhw_lengths},
input_ncdhw_stride_{input_ncdhw_stride},
output_ncdhw_stride_{output_ncdhw_stride}
{
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
C,
input_spatial_lengths,
window_spatial_lengths,
output_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths,
output_ncdhw_lengths,
input_ncdhw_stride,
output_ncdhw_stride,
window_spatial_zyx_lengths,
window_zyx_strides,
window_zyx_dilations,
input_left_dhw_pads,
input_right_dhw_pads);
a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1];
invariant_lowest_length_ = C;
int32_t reduceLength =
window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2];
int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] *
window_spatial_zyx_lengths[2];
std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
......@@ -200,17 +240,25 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_;
// for checking vector load/store
ck::index_t invariant_lowest_length_;
std::vector<ck::index_t> input_ncdhw_lengths_;
std::vector<ck::index_t> output_ncdhw_lengths_;
std::vector<ck::index_t> input_ncdhw_stride_;
std::vector<ck::index_t> output_ncdhw_stride_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// for NDHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
......@@ -276,60 +324,66 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
// C should be fastest dimension
if(pArg->input_ncdhw_stride_[1] != 1)
return false;
for(int i = 0; i < InOutRank; ++i)
{
if(pArg->input_ncdhw_stride_[i] == 1 &&
pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
if(pArg->output_ncdhw_stride_[i] == 1 &&
pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
return false;
}
return true;
}
std::unique_ptr<BaseArgument>
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev,
void* p_out_indices_dev,
std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> input_ncdhw_lengths,
std::vector<ck::index_t> window_zyx_lengths,
std::vector<ck::index_t> output_ncdhw_lengths,
std::vector<ck::index_t> input_ncdhw_stride,
std::vector<ck::index_t> output_ncdhw_stride,
std::vector<ck::index_t> indices_ncdhw_stride,
std::vector<ck::index_t> window_zyx_strides,
std::vector<ck::index_t> window_zyx_dilations,
std::vector<ck::index_t> input_left_dhw_pads,
std::vector<ck::index_t> input_right_dhw_pads,
std::vector<ck::index_t> pooling_dims) override
{
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank)
if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank ||
input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank ||
window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank ||
input_right_dhw_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
index_t N = input_lengths[0];
index_t C = input_lengths[1];
index_t Di = input_lengths[2];
index_t Hi = input_lengths[3];
index_t Wi = input_lengths[4];
index_t Do = output_lengths[2];
index_t Ho = output_lengths[3];
index_t Wo = output_lengths[4];
std::vector<ck::index_t> input_spatial_lengths = {Di, Hi, Wi};
std::vector<ck::index_t> output_spatial_lengths = {Do, Ho, Wo};
if(output_ncdhw_stride != indices_ncdhw_stride)
throw std::runtime_error(
"output_ncdhw_stride need to be equal to indices_ncdhw_stride for now");
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev),
static_cast<IndexDataType*>(p_out_indices_dev),
N,
C,
input_spatial_lengths,
window_lengths,
output_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
input_ncdhw_lengths,
output_ncdhw_lengths,
input_ncdhw_stride,
output_ncdhw_stride,
indices_ncdhw_stride,
window_zyx_lengths,
window_zyx_strides,
window_zyx_dilations,
input_left_dhw_pads,
input_right_dhw_pads);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......@@ -342,7 +396,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
auto str = std::stringstream();
// clang-format off
str << "DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<" << BlockSize << ",";
str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
......
......@@ -617,10 +617,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
using AGridDesc_AKB_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(
AGridDesc_M_K{}, 1))>;
using BGridDesc_BKB_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(
BGridDesc_N_K{}, 1))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
......@@ -886,11 +888,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
has_main_loop>;
hipGetErrorString(hipMemset(
hipGetErrorString(hipMemsetAsync(
arg.p_e_grid_,
0,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(EDataType)));
sizeof(EDataType),
stream_config.stream_id_));
return launch_and_time_kernel(stream_config,
kernel,
......
......@@ -136,8 +136,8 @@ struct GridwiseMultiblockBatchNormForward
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadwiseWelford1 =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
......
......@@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
......
......@@ -121,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
......
......@@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
......
......@@ -101,8 +101,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -346,14 +346,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock =
remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
MeanVarGridDesc_M_NBlock{}))>;
using CountGridDescriptor_MBlock_MPerBlock_NBlock =
remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
......
......@@ -102,8 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -286,8 +286,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
......
......@@ -446,14 +446,17 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
e1_grid_desc_m_n);
}
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(E1GridDesc_M_N{}))>;
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
E1GridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 =
remove_cvref_t<decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
D0sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(D1sGridDesc_M_N{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
D1sGridDesc_M_N{}))>;
using DefaultBlock2E1TileMap =
remove_cvref_t<decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
......
......@@ -114,8 +114,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -369,11 +369,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
}
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 =
remove_cvref_t<decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
D0sGridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
C1GridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
......
......@@ -113,8 +113,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -300,8 +300,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
......
......@@ -191,8 +191,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -346,14 +346,17 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
C0GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
C1GridDesc_M_N{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
......
......@@ -7,9 +7,11 @@
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
......@@ -17,6 +19,8 @@
namespace ck {
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -25,7 +29,8 @@ template <typename GridwiseGemm,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
bool HasDoubleTailKBlockLoop,
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -38,6 +43,13 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
// DPP8 is currently only supported on gfx1030
#if !defined(__gfx1030__)
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return;
}
#endif
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -88,7 +100,8 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
index_t CThreadTransferDstScalarPerVector,
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
......@@ -244,6 +257,45 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
c_grid_desc_m_n);
}
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
__host__ __device__ static constexpr auto GetBlockwiseGemm()
{
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
else
{
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
}
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
......@@ -274,7 +326,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const auto c_m0_n0_block_cluster_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
// HACK: this forces index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
......@@ -372,20 +424,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>();
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
......@@ -472,7 +511,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
......@@ -992,7 +1031,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
......
......@@ -92,8 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -300,8 +300,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// Support 2 dimension in the future. Not only M
using RGridDescriptor_MBlock_MPerBlock =
......
......@@ -346,8 +346,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......@@ -565,10 +565,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
e_grid_desc_m_n);
}
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))>;
using DsGridPointer = decltype(MakeDsGridPointer());
......
......@@ -94,8 +94,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
// denorm test fix, required to work around fp16 mfma issue
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
......@@ -273,10 +273,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
......@@ -294,13 +295,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
......
......@@ -164,8 +164,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -318,8 +318,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
......
......@@ -375,10 +375,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1))>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment