Commit 6805df0e authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into gelu

parents 1fdbe3fe e4584d91
......@@ -557,11 +557,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float ave_time = 0;
using Add =
ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>;
using Substract = ck::tensor_operation::binary_element_wise::
Substract<CDataType, CDataType, CDataType>;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
using Add = ck::tensor_operation::element_wise::Add;
using Subtract = ck::tensor_operation::element_wise::Subtract;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
......@@ -573,19 +571,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Subtract,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
CDataType,
CDataType,
CDataType,
......@@ -593,14 +591,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CGridDesc_M,
CGridDesc_M,
Add>;
const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract>;
const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Subtract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
......@@ -653,7 +651,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config,
substract_kernel,
subtract_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -663,7 +661,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Substract{});
Subtract{});
ave_time +=
launch_and_time_kernel(stream_config,
......@@ -764,7 +762,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config,
substract_kernel,
subtract_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -774,7 +772,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Substract{});
Subtract{});
ave_time +=
launch_and_time_kernel(stream_config,
......
......@@ -11,6 +11,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
#include "gridwise_unary_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
......@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
1);
}
// type convert descs
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * 4;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
}
template <index_t Dim>
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
{
const auto desc_m0 = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
}
else
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
}
using TypeConvertFunctor =
ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
using GridwiseUEltwise =
GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
......@@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
true,
true>;
using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
AccDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXdl,
NPerXdl,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
......@@ -910,41 +1008,159 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.block_2_ctile_map_);
};
// run kernel for bf16 with splitk
const auto run_bf16_splitk = [&](const auto& kernel) {
hipGetErrorString(hipMemset(
arg.p_workspace_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(AccDataType)));
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
static_cast<AccDataType*>(arg.p_workspace_),
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};
// kernel for type conversion
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(arg.Conv_K_),
static_cast<std::size_t>(arg.Conv_C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(arg.filter_spatial_lengths_),
std::end(arg.filter_spatial_lengths_));
int tensor_size =
std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies<int>{});
const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size);
GridDesc_M0 a_grid_desc_m0_ =
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
GridDesc_M0 b_grid_desc_m0_ =
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_))
{
throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting");
}
// run kernel for type conversion
void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
const auto Run_type_convert = [&](const auto& kernel) {
float elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(type_convert_grid_size),
dim3(256),
0,
static_cast<AccDataType*>(arg.p_workspace_),
p_c_grid_tmp_bf16_,
a_grid_desc_m0_,
b_grid_desc_m0_,
TypeConvertFunctor{});
return elapsed_time;
};
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
{
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
true>;
Run(kernel);
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
true>;
Run(kernel);
}
else
{
const auto kernel_type_convert =
kernel_unary_elementwise_1d<GridwiseUEltwise,
AccDataType,
InDataType,
GridDesc_M0,
TypeConvertFunctor>;
const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatBf16Splitk,
ADataType, // TODO: distiguish A/B datatype
AccDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
true>;
run_bf16_splitk(kernel_conv);
ave_time += Run_type_convert(kernel_type_convert);
}
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
false>;
Run(kernel);
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
false>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAddFloatBf16Splitk,
ADataType, // TODO: distiguish A/B datatype
AccDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::Block2CTileMap>,
false>;
run_bf16_splitk(kernel);
}
}
}
else
......
......@@ -6,19 +6,18 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename DPtrsGlobal,
typename AElementwiseOperation,
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -29,24 +28,69 @@ struct DeviceGemmReduce : public BaseOperator
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DPtrsGlobal,
typename AElementwiseOperation,
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>>;
DxsReduceAccElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -32,7 +32,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
......@@ -68,12 +68,11 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>
DxsReduceAccElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
......@@ -389,7 +388,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
......@@ -449,7 +448,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op)
DxsReduceAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -498,7 +497,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
};
// Invoker
......@@ -554,7 +553,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -594,7 +593,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -669,7 +668,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op)
DxsReduceAccElementwiseOperation dxs_out_element_op)
{
return Argument{p_a,
p_b,
......@@ -691,27 +690,29 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
p_dxs,
dxs_tuple,
MRaw,
NRaw,
KRaw,
......
......@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
{
grid_size_ = 0;
gemm_descs_args_workspace_ = nullptr;
p_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
......@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_;
};
......@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
}
hipGetErrorString(
hipMemcpy(arg.gemm_descs_args_workspace_,
hipMemcpy(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice));
......@@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation,
true>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
else
{
......@@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation,
false>;
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
return ave_time;
......@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
}
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
}
};
} // namespace device
......
......@@ -35,14 +35,13 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
using IndexDataType = int32_t;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
......@@ -178,13 +177,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
invariant_lowest_length_ = C;
reduce_lowest_length_ = window_spatial_lengths[1];
// TODO: is this correct?
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
{
ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
in_element_op_ = InElementwiseOperation{divider};
acc_element_op_ = AccElementwiseOperation{divider};
}
int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
}
const InDataType* p_in_dev_;
......
......@@ -61,12 +61,9 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static constexpr bool out_type_compatible_with_atomic_op =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
static_assert(
!use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
"The OutDataType must support the atomic operation for using MultiBlock reduction");
static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value,
"The OutDataType must support the specified OutMemoryDataOperation!");
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
"MultiBlock reduction can only be used when outputing index is not required");
......@@ -349,7 +346,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
if constexpr(use_multiblock)
{
const auto identityVal =
ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation);
const auto kernel_pre =
......@@ -492,7 +489,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceMultiBlockAtomicAdd<" << BlockSize << ",";
str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
......
#pragma once
#include <iostream>
#include <vector>
#include "device.hpp"
#include "device_base.hpp"
#include "gridwise_unary_elementwise_1d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename ElementwiseFunctor,
index_t Dim,
index_t ScalarPerVector>
struct DeviceUnaryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
{
const auto m0 = desc_m0.GetLength(I0);
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
const auto desc_m0_pad =
transform_tensor_descriptor(desc_m0,
make_tuple(make_right_pad_transform(m0, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m0_pad;
}
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(Dim > 1)
{
const auto desc_m0 = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
}
else
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
}
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<ADataType,
BDataType,
GridDesc_M0,
ElementwiseFunctor,
ScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
BDataType* p_b,
const std::vector<index_t>& shape,
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
shape_(shape),
functor_(functor),
blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future
{
index_t tensor_size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size);
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
}
const ADataType* p_a_;
BDataType* p_b_;
std::vector<int> shape_;
GridDesc_M0 a_grid_desc_m0_;
GridDesc_M0 b_grid_desc_m0_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
ADataType,
BDataType,
GridDesc_M0,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.functor_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->shape_.back() % ScalarPerVector != 0)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
void* p_b,
std::vector<index_t> shape,
std::vector<index_t> stride_a,
std::vector<index_t> stride_b,
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<BDataType*>(p_b),
shape,
stride_a,
stride_b,
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "ScalarPerVector = " << ScalarPerVector
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -29,6 +29,7 @@
#include "reduction_operator.hpp"
#include "reduction_enums.hpp"
#include "element_wise_operation.hpp"
#include <tuple>
namespace ck {
......@@ -37,77 +38,69 @@ namespace ck {
// The boolean member "indexable" are also provided in reduce_binary_operactor for
// easier checking by the upper-layer codes in the kernels.
template <typename T, ReduceTensorOp Op>
template <ReduceTensorOp Op>
struct reduce_binary_operator;
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::ADD>
template <>
struct reduce_binary_operator<ReduceTensorOp::ADD>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::MUL>
template <>
struct reduce_binary_operator<ReduceTensorOp::MUL>
{
using opType = reduce::Mul<T>;
using dataType = T;
using opType = reduce::Mul;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::MIN>
template <>
struct reduce_binary_operator<ReduceTensorOp::MIN>
{
using opType = reduce::Min<T>;
using dataType = T;
using opType = reduce::Min;
static constexpr bool indexable = true;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::MAX>
template <>
struct reduce_binary_operator<ReduceTensorOp::MAX>
{
using opType = reduce::Max<T>;
using dataType = T;
using opType = reduce::Max;
static constexpr bool indexable = true;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
template <>
struct reduce_binary_operator<ReduceTensorOp::AMAX>
{
using opType = reduce::AMax<T>;
using dataType = T;
using opType = reduce::AMax;
static constexpr bool indexable = true;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::AVG>
template <>
struct reduce_binary_operator<ReduceTensorOp::AVG>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
template <>
struct reduce_binary_operator<ReduceTensorOp::NORM1>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
template <>
struct reduce_binary_operator<ReduceTensorOp::NORM2>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
......@@ -115,53 +108,101 @@ struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
// functor classes.
// The two unary functors are called before and afer the Reduction is executed respectively
template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
template <ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
struct reduce_unary_operator
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T, bool IsFirstReduce>
struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
template <bool IsFirstReduce>
struct reduce_unary_operator<ReduceTensorOp::AVG, IsFirstReduce, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength});
};
};
template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
template <bool IsLastReduce>
struct reduce_unary_operator<ReduceTensorOp::NORM1, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
template <bool IsLastReduce>
struct reduce_unary_operator<ReduceTensorOp::AMAX, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
template <>
struct reduce_unary_operator<ReduceTensorOp::NORM2, true, false>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
template <>
struct reduce_unary_operator<ReduceTensorOp::NORM2, true, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
template <>
struct reduce_unary_operator<ReduceTensorOp::NORM2, false, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
} // end of namespace ck
......
......@@ -28,100 +28,189 @@
namespace ck {
namespace tensor_operation {
namespace binary_element_wise {
template <typename Y, typename X1, typename X2>
struct Add;
namespace element_wise {
template <>
struct Add<double, double, double>
struct Add
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
dst = src1 + src2;
}
};
y = x0 + x1;
};
template <>
struct Add<float, float, float>
{
template <>
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
operator()<double>(double& y, const double& x0, const double& x1) const
{
dst = src1 + src2;
}
};
y = x0 + x1;
};
template <>
struct Add<half_t, half_t, half_t>
{
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 + x1;
};
// Question: should bhalf_t be supported ?
template <>
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
dst = src1 + src2;
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp + x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
};
template <>
struct Add<bhalf_t, bhalf_t, bhalf_t>
struct Subtract
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 + x2;
dst = ck::type_convert<bhalf_t>(y);
}
};
y = x0 - x1;
};
template <typename Y, typename X1, typename X2>
struct Substract;
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 - x1;
};
template <>
struct Substract<double, double, double>
{
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
dst = src1 - src2;
y = x0 - x1;
};
// Question: should bhalf_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp - x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
};
template <>
struct Substract<float, float, float>
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
dst = src1 - src2;
}
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = static_cast<double>(alpha_) * x0 + static_cast<double>(beta_) * x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
};
float alpha_;
float beta_;
};
template <>
struct Substract<half_t, half_t, half_t>
struct AddRelu
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
dst = src1 - src2;
}
const float a = x0 + x1;
y = a > 0.0f ? a : 0.0f;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
const double a = x0 + x1;
y = a > 0.0 ? a : 0.0;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
const half_t a = x0 + x1;
y = a > static_cast<half_t>(0.0f) ? a : static_cast<half_t>(0.0f);
};
};
template <>
struct Substract<bhalf_t, bhalf_t, bhalf_t>
struct AddHardswish
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 - x2;
dst = ck::type_convert<bhalf_t>(y);
}
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
y = c;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
double a = x0 + x1;
double b = a + 3.0;
double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
y = c;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
float a = x0 + x1;
float b = a + 3.0f;
float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
y = c;
};
};
} // namespace binary_element_wise
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
#include "unary_element_wise_operation.hpp"
#include "binary_element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
__host__ __device__ void operator()(float& y, const float& x) const { y = x; }
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; }
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x; }
};
struct FastGelu
{
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ void operator()(float& y, const float& x) const
{
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf;
}
};
struct Add
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
y = x0 + x1;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
// FIXME - Use float (acc type) bias in the future.
y = x0 + x1;
}
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
y = alpha_ * x0 + beta_ * x1;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
// FIXME - Let x0 be acc type
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
}
float alpha_;
float beta_;
};
struct AddRelu
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
const float a = x0 + x1;
y = a > 0 ? a : 0;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
const half_t a = x0 + x1;
y = a > 0 ? a : 0;
}
};
struct AddHardswish
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
};
struct AddReluAdd
{
__host__ __device__ constexpr void
......@@ -162,8 +65,14 @@ struct AddHardswishAdd
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
__host__ __device__ void
operator()(ck::half_t& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
......@@ -177,209 +86,67 @@ struct AddAddFastGelu
const float y = fast_gelu(c + float(d0) + float(d1));
e = ck::type_convert<ck::half_t>(y);
e = type_convert<half_t>(y);
}
};
struct Normalize
{
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
__host__ __device__ constexpr void operator()(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
{
float variance = mean_square - (mean * mean);
y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
}
float epsilon_;
};
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;
template <>
struct UnaryIdentic<float, float, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = x; };
};
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <>
struct UnaryIdentic<float, float, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
template <typename T>
__host__ __device__ constexpr void operator()(
T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const;
__host__ __device__ void operator()(float& y, const float& x) const
template <>
__host__ __device__ constexpr void operator()<float>(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
{
y = x / type_convert<float>(divider_);
};
int32_t divider_ = 1;
};
using ck::math::sqrt;
template <>
struct UnaryIdentic<half_t, half_t, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<double, double, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = x; };
};
template <>
struct UnaryIdentic<double, double, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(double& y, const double& x) const
{
y = x / type_convert<double>(divider_);
float variance = mean_square - (mean * mean);
y = ((x - mean) / sqrt(variance + static_cast<float>(epsilon_))) * gamma + beta;
};
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<int32_t, int32_t, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<int32_t, int32_t, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<int8_t, int8_t, false>
{
__host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
};
template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;
template <>
struct UnarySquare<float, float, false>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = x * x; };
};
template <>
struct UnarySquare<float, float, true>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(float& y, const float& x) const
template <>
__host__ __device__ constexpr void operator()<double>(double& y,
const double& x,
const double& mean,
const double& mean_square,
const double& gamma,
const double& beta) const
{
y = x * x / type_convert<float>(divider_);
};
int32_t divider_ = 1;
};
template <>
struct UnarySquare<double, double, false>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = x * x; };
};
using ck::math::sqrt;
template <>
struct UnarySquare<double, double, true>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(double& y, const double& x) const
{
y = x * x / type_convert<double>(divider_);
double variance = mean_square - (mean * mean);
y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
};
int32_t divider_ = 1;
};
template <typename Y, typename X>
struct UnaryAbs;
template <>
struct UnaryAbs<float, float>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
};
template <>
struct UnaryAbs<half_t, half_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
};
template <>
struct UnaryAbs<double, double>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
};
template <>
struct UnaryAbs<int8_t, int8_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
double epsilon_;
};
template <typename Y, typename X>
struct UnarySqrt;
struct UnaryTypeConvert;
template <>
struct UnarySqrt<float, float>
struct UnaryTypeConvert<float, ck::bhalf_t>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
{
y = ck::type_convert<float, ck::bhalf_t>(x);
};
};
template <>
struct UnarySqrt<double, double>
struct UnaryTypeConvert<ck::bhalf_t, float>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
{
y = ck::math::sqrt(x);
y = ck::type_convert<ck::bhalf_t, float>(x);
};
};
......
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x;
};
};
struct UnaryDivide
{
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
int32_t divider_ = 1;
};
struct UnarySquare
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value,
"Data type is not supported by this operation!");
y = x * x;
};
};
struct UnaryAbs
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = ck::math::abs(x);
};
};
struct UnarySqrt
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value,
"Data type is not supported by this operation!");
y = ck::math::sqrt(x);
};
};
struct Relu
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{
float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck::type_convert<bhalf_t>(y_f32);
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
struct FastGelu
{
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf;
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -171,15 +171,15 @@ struct GridwiseReduction_mk_to_m_multiblock
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
......@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......
......@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation,
PropagateNan>;
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
......@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
(void)acc_elementwise_op;
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
......
......@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -41,7 +41,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsAccElementwiseOperation dxs_out_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -96,7 +96,7 @@ template <typename FloatAB,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
......@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsAccElementwiseOperation& dxs_out_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
......@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
const auto d_identityVal = DReduceOperation::GetIdentityValue();
const auto d_identityVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_identityVal; });
......
......@@ -791,8 +791,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
void* p_shared = static_cast<void*>(p_shared_block);
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared_block),
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
......
......@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
{
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
constexpr auto I0 = Number<0>{};
......
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwiseUEltwise,
typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor>
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
}
template <typename ADataType,
typename BDataType,
typename GridDesc_M0,
typename ElementwiseFunctor,
index_t ScalarPerVector>
struct GridwiseUnaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * ScalarPerVector);
}
__host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0)
{
return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
{
const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
return grid_size;
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
BDataType* __restrict__ p_b_global,
const GridDesc_M0 a_grid_desc_m0,
const GridDesc_M0 b_grid_desc_m0,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
GridDesc_M0,
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
ScalarPerVector,
1, // SrcScalarStrideInVector
false>{a_grid_desc_m0, thread_store_global_offset};
auto b_global_write =
ThreadwiseTensorSliceTransfer_v1r3<BDataType,
BDataType,
decltype(thread_desc_m0),
GridDesc_M0,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
ScalarPerVector,
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
b_grid_desc_m0, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto m0 = b_grid_desc_m0.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = m0 / (loop_step);
do
{
// read and process ScalarPerVector elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
});
b_global_write.Run(thread_desc_m0,
make_tuple(I0), // SrcSliceOriginIdx
b_thread_buf,
b_grid_desc_m0,
b_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment