Commit 39002e9e authored by Jun Liu

Merge branch 'develop' into amd-develop

parents b26bdd61 d52ec016
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -147,7 +147,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
        if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
                     StreamKReductionStrategy::Atomic)
        {
            hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
            hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
                                             0,
                                             karg.M * karg.N * sizeof(CDataType),
                                             stream_config.stream_id_));

            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,
                                              grid_dims,
...
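Side note (not part of the patch): hipMemset issues the fill on the default stream, while hipMemsetAsync takes an explicit stream, so the zero-fill above is ordered on the same stream_config.stream_id_ used by the timed kernel launch. A minimal stand-alone sketch, with dst_ptr and num_bytes as hypothetical placeholders:

    hipStream_t stream;
    hip_check_error(hipStreamCreate(&stream));
    // The fill is ordered with later work submitted to `stream` and does not
    // synchronize the default stream.
    hip_check_error(hipMemsetAsync(dst_ptr, 0, num_bytes, stream));
    hip_check_error(hipStreamSynchronize(stream));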
@@ -378,6 +378,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
        const index_t GemmM = K;
        const index_t GemmN = C * X;

        const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
        const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
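For reference, (MPerBlock - GemmM % MPerBlock) % MPerBlock is the usual round-up-to-multiple padding: with illustrative values GemmM = 100 and MPerBlock = 64, PadGemmM = (64 - 100 % 64) % 64 = 28, so the padded GemmM used by the descriptors below is 128; the outer modulo keeps the pad at 0 when GemmM is already a multiple of MPerBlock.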
@@ -496,9 +499,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
            const auto wei_gemmm_gemmn_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(K, X * C));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                              wei_gemmm_gemmn_grid_desc);

            // Pad
            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
                transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
        }
    }
@@ -546,6 +577,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
        const index_t GemmM = K;
        const index_t GemmN = C * X * Y;

        const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
        const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
@@ -651,9 +685,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                              wei_grid_desc);

            // Pad
            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
                transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
        }
    }
@@ -708,6 +770,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
        const index_t GemmM = K;
        const index_t GemmN = C * Z * X * Y;

        const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
        const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
@@ -822,9 +887,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                              wei_grid_desc);

            // Pad
            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
                transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
        }
    } // function end
...
@@ -421,8 +421,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
            for(const auto& trans_arg : arg.gemm_kernel_args_)
            {
                const auto& karg = trans_arg.karg_;
                hip_check_error(
                    hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
                hip_check_error(hipMemsetAsync(karg.p_c_grid,
                                               0,
                                               karg.M * karg.N * sizeof(EDataType),
                                               stream_config.stream_id_));
            }
        }
...
@@ -3,16 +3,7 @@
#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
namespace ck {
namespace tensor_operation {
@@ -30,255 +21,32 @@ template <typename InDataType,
          ck::index_t ReduceMThreadSliceSize,
          ck::index_t ReduceKThreadSliceSize,
          ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
    : public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex>
struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC<InDataType,
                                                                      OutDataType,
IndexDataType,
ComputeDataType,
ReduceOpId,
OutputIndex,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorSize>
{
    using DevicePool3D = DevicePool3dFwd_NDHWC_NDHWC<InDataType,
                                                     OutDataType,
                                                     IndexDataType,
                                                     ComputeDataType,
                                                     ReduceOpId,
                                                     OutputIndex,
                                                     BlockSize,
                                                     ReduceMThreadClusterSize,
                                                     ReduceKThreadClusterSize,
                                                     ReduceMThreadSliceSize,
                                                     ReduceKThreadSliceSize,
                                                     InSrcOutDstVectorSize>;

    static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
static constexpr ck::index_t ReduceM_BlockTileSize =
ReduceMThreadClusterSize * ReduceMThreadSliceSize;
static constexpr ck::index_t ReduceK_BlockTileSize =
ReduceKThreadClusterSize * ReduceKThreadSliceSize;
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = window_spatial_lengths[0];
const index_t X = window_spatial_lengths[1];
const index_t ConvStrideH = window_strides[0];
const index_t ConvStrideW = window_strides[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ReduceMRaw = N * Ho * Wo * C;
const index_t ReduceMPad =
math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw;
const index_t ReduceKRaw = Y * X;
const index_t ReduceKPad =
math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw;
// A[ReduceM, ReduceK]
const auto in_grid_desc_n_hi_wi_c =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
in_grid_desc_n_hi_wi_c,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_hip_wip_c,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_grid_desc_reducemraw_reducekraw =
transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
make_merge_transform(make_tuple(Y, X))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
in_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad),
make_right_pad_transform(ReduceKRaw, ReduceKPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// B[ReduceM]
const auto out_grid_desc_reducemraw =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C));
const auto out_grid_desc_reducem = transform_tensor_descriptor(
out_grid_desc_reducemraw,
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
}
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
// TODO
struct Argument : public BaseArgument
{
Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev,
IndexDataType* p_out_indices_dev,
ck::index_t N,
ck::index_t C,
std::vector<ck::index_t>& input_spatial_lengths,
std::vector<ck::index_t>& window_spatial_lengths,
std::vector<ck::index_t>& output_spatial_lengths,
std::vector<ck::index_t>& window_strides,
std::vector<ck::index_t>& input_left_pads,
std::vector<ck::index_t>& input_right_pads)
: p_in_dev_{p_in_dev},
p_out_dev_{p_out_dev},
p_out_indices_dev_{p_out_indices_dev},
a_grid_desc_m_k_{},
b_grid_desc_m_{}
{
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
C,
input_spatial_lengths,
window_spatial_lengths,
output_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
a_grid_desc_m_k_ = descs[I0];
b_grid_desc_m_ = descs[I1];
invariant_lowest_length_ = C;
reduce_lowest_length_ = window_spatial_lengths[1];
int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
}
const InDataType* p_in_dev_;
OutDataType* p_out_dev_;
IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_;
AccElementwiseOperation acc_element_op_;
// for checking vector load/store
ck::index_t invariant_lowest_length_;
ck::index_t reduce_lowest_length_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
                GridwiseReduction_mk_to_m_threadwise<InDataType,
                                                     OutDataType,
                                                     ComputeDataType,
                                                     IndexDataType,
                                                     AGridDesc_M_K,
                                                     BGridDesc_M,
                                                     ReduceOperation,
                                                     InElementwiseOperation,
                                                     AccElementwiseOperation,
                                                     InMemoryDataOperationEnum::Set,
                                                     false, // propagate_nan
                                                     BlockSize,
                                                     ReduceMThreadSliceSize,
                                                     ReduceKThreadSliceSize,
                                                     InSrcOutDstVectorDim,
                                                     InSrcOutDstVectorSize,
                                                     InSrcOutDstVectorSize>;
const auto kernel =
kernel_reduce_threadwise<gridwise_reduce,
OutputIndex,
true, // pooling need to return global index
false, // don't have index input
InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.a_grid_desc_m_k_,
arg.b_grid_desc_m_,
arg.in_element_op_,
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
nullptr,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
{
return (false);
}
return (true);
}
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_dev,
                        void* p_out_dev,
@@ -286,62 +54,57 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
                        std::vector<ck::index_t> input_lengths,
                        std::vector<ck::index_t> window_lengths,
                        std::vector<ck::index_t> output_lengths,
                        std::vector<ck::index_t>, // Suppose tensor layout = NHWC
                        std::vector<ck::index_t>, // Suppose tensor layout = NHWC
                        std::vector<ck::index_t>, // Suppose tensor layout = NHWC
                        std::vector<ck::index_t> input_stride,
                        std::vector<ck::index_t> output_stride,
                        std::vector<ck::index_t> indices_stride,
                        std::vector<ck::index_t> window_strides,
                        std::vector<ck::index_t> window_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        std::vector<ck::index_t> pooling_dims) override
    {
        static constexpr index_t InOutRank  = 4;
        static constexpr index_t WindowRank = 2;

        if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
           input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
           input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank)
           window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
           input_right_pads.size() != WindowRank)
            throw std::runtime_error("dimension is incorrect");

        if(pooling_dims != std::vector<ck::index_t>{2, 3})
            throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");

        index_t N  = input_lengths[0];
        index_t C  = input_lengths[1];
        index_t Hi = input_lengths[2];
        index_t Wi = input_lengths[3];
        index_t Ho = output_lengths[2];
        index_t Wo = output_lengths[3];
        std::vector<ck::index_t> input_spatial_lengths  = {Hi, Wi};
        std::vector<ck::index_t> output_spatial_lengths = {Ho, Wo};

        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
                                          static_cast<OutDataType*>(p_out_dev),
                                          static_cast<IndexDataType*>(p_out_indices_dev),
                                          N,
                                          C,
                                          input_spatial_lengths,
                                          window_lengths,
                                          output_spatial_lengths,
                                          window_strides,
                                          input_left_pads,
                                          input_right_pads);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ",";
        str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ",";
        str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ",";
        str << "InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
        // clang-format on
        return str.str();
    }

        // NCHW to NCDHW
        input_lengths.insert(input_lengths.begin() + 2, 1);
        output_lengths.insert(output_lengths.begin() + 2, 1);
        input_stride.insert(input_stride.begin() + 2, 0);
        output_stride.insert(output_stride.begin() + 2, 0);
        indices_stride.insert(indices_stride.begin() + 2, 0);

        // YX to ZYX
        window_lengths.insert(window_lengths.begin(), 1);
        window_strides.insert(window_strides.begin(), 0);
        window_dilations.insert(window_dilations.begin(), 0);
        input_left_pads.insert(input_left_pads.begin(), 0);
        input_right_pads.insert(input_right_pads.begin(), 0);

        pooling_dims = {2, 3, 4};

        return DevicePool3D::MakeArgumentPointer(p_in_dev,
                                                 p_out_dev,
                                                 p_out_indices_dev,
                                                 input_lengths,
                                                 window_lengths,
                                                 output_lengths,
                                                 input_stride,
                                                 output_stride,
                                                 indices_stride,
                                                 window_strides,
                                                 window_dilations,
                                                 input_left_pads,
                                                 input_right_pads,
                                                 pooling_dims);
    }
};
...
@@ -886,11 +886,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
                typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
                has_main_loop>;

            hipGetErrorString(hipMemset(
                arg.p_e_grid_,
                0,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
                    sizeof(EDataType)));
            hipGetErrorString(hipMemsetAsync(
                arg.p_e_grid_,
                0,
                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
                    sizeof(EDataType),
                stream_config.stream_id_));

            return launch_and_time_kernel(stream_config,
                                          kernel,
...
@@ -7,9 +7,11 @@
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
@@ -17,6 +19,8 @@
namespace ck {
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
@@ -25,7 +29,8 @@ template <typename GridwiseGemm,
          typename CGridDesc_M0_M10_M11_N0_N10_N11,
          typename Block2CTileMap,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
          bool HasDoubleTailKBlockLoop,
          GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -38,6 +43,13 @@ __global__ void
        const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
        const Block2CTileMap block_2_ctile_map)
{
// DPP8 is currently only supported on gfx1030
#if !defined(__gfx1030__)
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return;
}
#endif
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -88,7 +100,8 @@ template <index_t BlockSize,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
          index_t CThreadTransferDstScalarPerVector,
          GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
    static constexpr auto I0 = Number<0>{};
@@ -244,6 +257,45 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
            c_grid_desc_m_n);
    }
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
__host__ __device__ static constexpr auto GetBlockwiseGemm()
{
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
{
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
else
{
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
ABlockDesc_BK0_BM_BK1,
BBlockDesc_BK0_BN_BK1,
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
}
}
    using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
@@ -274,7 +326,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this forces index data into SGPR
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
@@ -372,20 +424,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        const auto blockwise_gemm =
            BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
            GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>();

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
@@ -472,7 +511,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                b_block_slice_copy_step);

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
@@ -992,7 +1031,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
                                                b_block_slice_copy_step);

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
...
@@ -35,13 +35,17 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}

template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop>
template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
    kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
                                const FloatB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                typename GridwiseGemm::Problem problem)
{
@@ -61,7 +65,8 @@ __global__ void
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename FloatAB,
          typename FloatA,
          typename FloatB,
          typename FloatGemmAcc,
          typename FloatCShuffle,
          typename FloatC,
@@ -102,7 +107,8 @@ template <typename ALayout,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1>
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename ComputeType = FloatC>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
    static constexpr auto I0 = Number<0>{};
@@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const FloatAB* p_a_grid_,
                          const FloatAB* p_b_grid_,
        __host__ Argument(const FloatA* p_a_grid_,
                          const FloatB* p_b_grid_,
                          FloatC* p_c_grid_,
                          index_t M_,
                          index_t N_,
@@ -479,8 +485,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        {
        }

        const FloatAB* p_a_grid;
        const FloatAB* p_b_grid;
        const FloatA* p_a_grid;
        const FloatB* p_b_grid;
        FloatC* p_c_grid;
    };
@@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size * sizeof(FloatCShuffle));
        return math::max((a_block_space_size_aligned * sizeof(ComputeType) +
                          b_block_space_size_aligned * sizeof(ComputeType)),
                         c_block_size * sizeof(FloatCShuffle));
    }
@@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const Problem& problem)
@@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            FloatA,
            ComputeType,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
@@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            Sequence<BK0Number, NPerBlock, BK1Number>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            FloatB,
            ComputeType,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
@@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1Number, BK1Number),
                      MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
            FloatAB,
            ComputeType,
            FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
@@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
            static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
...
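A worked example of the revised LDS sizing (illustrative numbers, not from the patch): with a_block_space_size_aligned = b_block_space_size_aligned = 2048 elements, ComputeType = half (2 bytes each) and c_block_size = 4096 FloatCShuffle = float elements (4 bytes each), GetSharedMemoryNumberOfByte() returns max(2048 * 2 + 2048 * 2, 4096 * 4) = max(8192, 16384) = 16384 bytes. Since A and B can now carry different grid types (FloatA, FloatB), both LDS tiles are sized and accessed in the common ComputeType that feeds the MFMA.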
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

    __device__ static int
    GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
    GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
    {
        bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;

        if(is_rightmost_block)
        {
            int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
            int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
            int kPerThread =
                kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
            int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
            int left_kPerBlock   = math::integer_divide_ceil(k, kGridSize);
            int kRightmostBlock  = kRaw - left_kPerBlock * (kGridSize - 1);
            int kPerThread       = kRightmostBlock < K_BlockTileSize
                                       ? 0
                                       : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
            int kPerBlockTail    = kRightmostBlock - kPerThread * KThreadClusterSize;

            if(kPerBlockTail > 0)
            {
@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
        }
        else
        {
            int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
            int kPerBlock = math::integer_divide_ceil(k, kGridSize);
            return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        }
    }
@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st
        auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto threadwise_welford = ThreadwiseWelford();
        int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        threadwise_welford.max_count_ =
            GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
                                                      kRaw,
                                                      k_grid_size,
                                                      block_k_cluster_id,
                                                      thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
...
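Worked example for the revised GetKPerThread (illustrative values; assumes K_BlockTileSize = KThreadClusterSize * KThreadSliceSize = 8 * 4 = 32): with padded length k = 128, raw length kRaw = 100 and kGridSize = 4, each of the first three blocks covers left_kPerBlock = ceil(128 / 4) = 32 elements, so the rightmost block owns kRightmostBlock = 100 - 32 * 3 = 4 raw elements; since 4 < 32, kPerThread is 0 and the kPerBlockTail = 4 branch distributes the remainder. Passing both the padded k and the raw kRaw is what lets the rightmost block size be computed correctly.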
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace ck {
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
* the size of the lane group (`dpp8::lane_group_size`).
*/
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
bool ShareA,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t TK0 = TKLengths{}[I0];
static constexpr index_t TK1 = TKLengths{}[I1];
static constexpr index_t TM0 = TMLengths{}[I0];
static constexpr index_t TM1 = TMLengths{}[I1];
static constexpr index_t TN0 = TNLengths{}[I0];
static constexpr index_t TN1 = TNLengths{}[I1];
static_assert(TM0 == 1 && TN0 == 1);
static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
(!ShareA && TN1 % dpp8::lane_group_size == 0));
static constexpr index_t shared_elems_per_lane =
ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;
__device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
            is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
"wrong! inconsistent type");
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));
constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(0, tm1, 0, tn1));
constexpr int src_lane =
ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
: (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;
dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
}
};
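// Added note (illustrative, not in the original source): with ShareA = true, TM1 = 16 and
// dpp8::lane_group_size = 8, shared_elems_per_lane = 2, so element tm1 is stored locally at
// tm1 % 2 and fetched via DPP8 from src_lane = (tm1 / 2) % 8 inside the lane group; each lane
// keeps two A elements and broadcasts them to the other seven lanes of its group.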
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_gemm_dpp.hpp"
namespace ck {
namespace dpp8 {
/// Number of lanes that can share data using DPP8 modifiers.
constexpr index_t lane_group_size = 8;
__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; }
__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; }
} // namespace dpp8
} // namespace ck
@@ -94,8 +94,8 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
    const vector_type<half_t, 2> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
        c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
             type_convert<int32_t>(b_vector.AsType<half_t>()[i]);
        c += type_convert<float>(a_vector.AsType<half_t>()[i]) *
             type_convert<float>(b_vector.AsType<half_t>()[i]);
    });
#endif
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "amd_gemm_dpp.hpp"
#include "data_type.hpp"
#include "type_convert.hpp"
namespace ck {
namespace dpp8 {
template <int SrcLaneIdx>
__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
// clang-format off
template <>
__device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
// clang-format on
/**
* Dot product of two vectors using `v_dot` instruction with DPP8 submitted as inline assembly.
*/
template <int SrcLaneIdx, bool ShareA>
__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c)
{
static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
"DPP8 src broadcast lane out of range <0, 7>.");
if constexpr(ShareA)
{
inline_v_dot2c_dpp8_instr<SrcLaneIdx>(a, b, c);
}
else
{
inline_v_dot2c_dpp8_instr<SrcLaneIdx>(b, a, c);
}
}
/**
 * The DPP8 intrinsic expects an integer mask; the masks for the specific broadcast
 * patterns are hardcoded below.
 */
constexpr std::array<int, dpp8::lane_group_size> IntrinsicMaskDpp8 = {
0, // 0, 0, 0, 0, 0, 0, 0, 0
2396745, // 1, 1, 1, 1, 1, 1, 1, 1
4793490, // 2, 2, 2, 2, 2, 2, 2, 2
7190235, // 3, 3, 3, 3, 3, 3, 3, 3
9586980, // 4, 4, 4, 4, 4, 4, 4, 4
11983725, // 5, 5, 5, 5, 5, 5, 5, 5
14380470, // 6, 6, 6, 6, 6, 6, 6, 6
16777215, // 7, 7, 7, 7, 7, 7, 7, 7
};
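// Added derivation (not in the original source): the DPP8 selector packs eight 3-bit
// source-lane indices, lane 0 in the lowest bits, so broadcasting lane k repeats k in every
// 3-bit slot: mask(k) = k * 0b001001001001001001001001 = k * 2396745. For example,
// mask(1) = 2396745 and mask(7) = 7 * 2396745 = 16777215 (0xFFFFFF), matching the table above.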
/**
* Returns DPP8 sel modifier as an integer required for the intrinsic instruction.
*/
template <int SrcLaneIdx>
constexpr int get_dpp_sel_mask_broadcast()
{
static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
"DPP8 src broadcast lane out of range <0, 7>.");
return IntrinsicMaskDpp8[SrcLaneIdx];
}
template <int SrcLaneIdx>
__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c)
{
constexpr int sel_mask = get_dpp_sel_mask_broadcast<SrcLaneIdx>();
const half2_t val_from_other_lane =
bit_cast<half2_t>(__builtin_amdgcn_mov_dpp8(bit_cast<int>(a), sel_mask));
c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false);
}
/**
* Dot product of two vectors using `v_dot` instruction with DPP8 submitted using intrinsics.
*/
template <int SrcLaneIdx, bool ShareA>
__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c)
{
if constexpr(ShareA)
{
intrinsic_fdot2_impl<SrcLaneIdx>(a, b, c);
}
else
{
intrinsic_fdot2_impl<SrcLaneIdx>(b, a, c);
}
}
/**
* Dot product of two input vectors `a`, `b` using `v_dot` instructions with DPP modifier.
*
* DPP modifier allows us to share one of the vectors between lanes in a lane group.
* When `ShareA` is set, instruction uses vector `a` from lane `SrcLaneIdx` from the same
* lane group (8 lanes per lane group in DPP8). When `ShareA` is not set, vector `b` is shared.
 * Note that all the threads in a lane group use the same vector (a broadcast pattern).
*
* `SrcLaneIdx` must be in range from 0 to 7.
*/
template <typename TA, typename TB, typename TC, int SrcLaneIdx, bool ShareA>
__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c)
{
#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM
inline_v_dot2c_dpp8<SrcLaneIdx, ShareA>(a, b, c);
#else
intrinsic_fdot2<SrcLaneIdx, ShareA>(a, b, c);
#endif
}
} // namespace dpp8
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// dinput descriptor in [N, C, Di, Hi, Wi] order
// doutput descriptor in [N, C, Do, Ho, Wo] order
// physical layout is irrelevant
template <ck::index_t NDimSpatial,
typename DInDataType,
typename DOutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceAvgPoolBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
: dinput_{dinput},
doutput_{doutput},
window_spatial_lengths_{window_spatial_lengths},
window_strides_{window_strides},
window_dilations_{window_dilations},
in_left_pads_{dinput_left_pads},
in_right_pads_{dinput_right_pads}
{
}
Tensor<DInDataType>& dinput_;
const Tensor<DOutDataType>& doutput_;
std::vector<ck::index_t> window_spatial_lengths_;
std::vector<index_t> window_strides_;
std::vector<index_t> window_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceAvgPoolBwd::Argument;
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 1, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
            // Let input = x, output = y
            // shape of x = [10], y = [6]
            // window_size = 5, pad = 0, stride = 1, dilation = 1

            // Forward:
            // y0 = 1/5 * (x0 + x1 + x2 + x3 + x4)
            // y1 = 1/5 * (x1 + x2 + x3 + x4 + x5)
            // ...
            // y5 = 1/5 * (x5 + x6 + x7 + x8 + x9)

            // Backward:
            // shape of dy = [6], dx = [10]
            // dx0 = 1/5 * dy0
            // dx1 = 1/5 * (dy0 + dy1)
            // dx2 = 1/5 * (dy0 + dy1 + dy2)
            // ...
            // dx4 = 1/5 * (dy0 + dy1 + dy2 + dy3 + dy4)
            // dx5 = 1/5 * (dy1 + dy2 + dy3 + dy4 + dy5)
            // dx6 = 1/5 * (dy2 + dy3 + dy4 + dy5)
            // ...
            // dx9 = 1/5 * dy5
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.doutput_.GetLengths()[2];
float v_acc = 0;
for(std::size_t x = 0; x < X; ++x)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
// Check the input pixel validity (in perspective of being affected by some
// doutput pixel)
if(w_tmp % arg.window_strides_[0] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
}
}
}
v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
}
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 2, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[3];
float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
// Check the input pixel's validity (i.e. whether it is affected by some doutput pixel)
if(h_tmp % arg.window_strides_[0] == 0)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc +=
ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_nchw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 3, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
// Check the input pixel's validity (i.e. whether it is affected by some doutput pixel)
if(d_tmp % arg.window_strides_[0] == 0)
{
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp =
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.window_dilations_[2]);
if(w_tmp % arg.window_strides_[2] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(
arg.doutput_(n, c, do_, ho, wo));
}
}
}
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncdhw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3],
arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
return RunAvgPoolBwd<NDimSpatial>(arg);
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
{
if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial ||
dinput_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return Argument{dinput,
doutput,
window_spatial_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceAvgPoolBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
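// A minimal usage sketch of this reference op (illustrative only; it assumes 1D NCW tensors
// already allocated with Wi = 10 and Wo = 6, matching the example above — Tensor
// construction itself is not shown):
//
//   using RefAvgPoolBwd1D = ReferenceAvgPoolBwd<1, float, float>;
//   auto argument = RefAvgPoolBwd1D::MakeArgument(dinput,
//                                                 doutput,
//                                                 {5},  // window_spatial_lengths
//                                                 {1},  // window_strides
//                                                 {1},  // window_dilations
//                                                 {0},  // dinput_left_pads
//                                                 {0}); // dinput_right_pads
//   auto invoker  = RefAvgPoolBwd1D::MakeInvoker();
//   invoker.Run(argument);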
} // namespace host
} // namespace tensor_operation
} // namespace ck
@@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
 arg.in_element_op_(v_in, v_acc);
-arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
+arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
 };
 make_ParallelTensorFunctor(f_ncw,
@@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
 arg.in_element_op_(v_in, v_acc);
-arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
+arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
 };
 make_ParallelTensorFunctor(f_nchw,
@@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
 arg.in_element_op_(v_in, v_acc);
-arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
+arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
 };
 make_ParallelTensorFunctor(f_ncdhw,
...
@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
 Tensor<IndexDataType>& out_indices,
 const std::vector<ck::index_t>& window_spatial_lengths,
 const std::vector<ck::index_t>& window_strides,
+const std::vector<ck::index_t>& window_dilations,
 const std::vector<ck::index_t>& in_left_pads,
 const std::vector<ck::index_t>& /*in_right_pads*/)
 : in_(in),
@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
 out_indices_(out_indices),
 window_spatial_lengths_(window_spatial_lengths),
 window_strides_(window_strides),
+window_dilations_(window_dilations),
 in_left_pads_(in_left_pads),
 reduceLength_(1)
 {
@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
 Tensor<IndexDataType>& out_indices_;
 const std::vector<ck::index_t>& window_spatial_lengths_;
 const std::vector<ck::index_t>& window_strides_;
+const std::vector<ck::index_t>& window_dilations_;
 const std::vector<ck::index_t>& in_left_pads_;
 int reduceLength_;
 };
@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
 for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
 {
-ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
+ck::index_t di = do_ * arg.window_strides_[0] +
+z * arg.window_dilations_[0] - arg.in_left_pads_[0];
 for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
 {
-ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
+ck::index_t hi = ho * arg.window_strides_[1] +
+y * arg.window_dilations_[1] - arg.in_left_pads_[1];
 for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
 {
-ck::index_t wi =
-wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
+ck::index_t wi = wo * arg.window_strides_[2] +
+x * arg.window_dilations_[2] -
+arg.in_left_pads_[2];
 if(di >= 0 &&
 di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
 hi >= 0 &&
@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
 for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
 {
-ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
+ck::index_t di = do_ * arg.window_strides_[0] +
+z * arg.window_dilations_[0] - arg.in_left_pads_[0];
 for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
 {
-ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
+ck::index_t hi = ho * arg.window_strides_[1] +
+y * arg.window_dilations_[1] - arg.in_left_pads_[1];
 for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
 {
-ck::index_t wi =
-wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
+ck::index_t wi = wo * arg.window_strides_[2] +
+x * arg.window_dilations_[2] -
+arg.in_left_pads_[2];
 if(di >= 0 &&
 di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
 hi >= 0 &&
@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
 for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
 {
-ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
+ck::index_t hi = ho * arg.window_strides_[0] +
+y * arg.window_dilations_[0] - arg.in_left_pads_[0];
 for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
 {
-ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
+ck::index_t wi = wo * arg.window_strides_[1] +
+x * arg.window_dilations_[1] - arg.in_left_pads_[1];
 if(hi >= 0 &&
 hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
 wi >= 0 &&
@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
 Tensor<IndexDataType>& out_indices,
 const std::vector<ck::index_t>& window_spatial_lengths,
 const std::vector<ck::index_t>& window_strides,
+const std::vector<ck::index_t>& window_dilations,
 const std::vector<ck::index_t>& in_left_pads,
 const std::vector<ck::index_t>& in_right_pads)
 {
@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
 out_indices,
 window_spatial_lengths,
 window_strides,
+window_dilations,
 in_left_pads,
 in_right_pads};
 }
...
@@ -23,6 +23,11 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
 DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
 instances);
+void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances);
 void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
 std::vector<std::unique_ptr<
 DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -33,6 +38,11 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
 DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
 instances);
+void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances);
 void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
 std::vector<std::unique_ptr<
 DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -43,6 +53,11 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
 DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
 instances);
+void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances);
 void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
 std::vector<std::unique_ptr<
 DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -53,6 +68,11 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
 DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
 instances);
+void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances);
 void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
 std::vector<std::unique_ptr<
 DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -354,6 +374,7 @@ struct DeviceOperationInstanceFactory<
 #ifdef DL_KERNELS
 add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
 add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
+add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
 #endif
 add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
 }
@@ -364,6 +385,7 @@ struct DeviceOperationInstanceFactory<
 #ifdef DL_KERNELS
 add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
 add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
+add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
 #endif
 add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
 add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
@@ -375,6 +397,7 @@ struct DeviceOperationInstanceFactory<
 #ifdef DL_KERNELS
 add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
 add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
+add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs);
 #endif
 add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
 }
@@ -385,6 +408,7 @@ struct DeviceOperationInstanceFactory<
 #ifdef DL_KERNELS
 add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
 add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
+add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs);
 #endif
 add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
 }
...
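The new dpp8 adders above are consumed the same way as the existing DL instance adders; a minimal sketch, assuming the library headers and the aliases visible in this header (Row, F16, PassThrough, DeviceGemm); the function name, op_ptrs, and the printing loop are illustrative only, not part of this diff:

// Hedged usage sketch: collect the DL-DPP8 f16 row/row/row GEMM instances and list them.
#include <iostream>
#include <memory>
#include <vector>

void list_dl_dpp8_f16_rrr_gemm_instances()
{
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>
        op_ptrs;
    add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl; // each instance reports its config
}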
@@ -63,6 +63,7 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+// TODO: After enable, add instance for small conv.K and conv.C
 #endif
 // clang-format on
 >;
@@ -97,6 +98,7 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances =
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
+// TODO: After enable, add instance for small conv.K and conv.C
 #endif
 // clang-format on
 >;
@@ -131,6 +133,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
 DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+// TODO: After enable, add instance for small conv.K and conv.C
 #endif
 // clang-format on
 >;
...