Commit 4fec5ad3 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

parents 24faa1fc 87fd1152
......@@ -15,13 +15,9 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "batchnorm_forward_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
......@@ -44,6 +40,7 @@ class BatchNormFwdArg
int data_type = 0;
int init_method = 2;
bool time_kernel = false;
bool use_multiblock_welford = false;
public:
void show_usage(const char* cmd)
......@@ -68,6 +65,7 @@ class BatchNormFwdArg
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
};
int processArgs(int argc, char* argv[])
......@@ -110,14 +108,15 @@ class BatchNormFwdArg
};
};
if(optind + 5 > argc)
if(optind + 6 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
data_type = std::atoi(argv[optind++]);
updateMovingAverage = std::atoi(argv[optind++]);
saveMeanAndInvVariance = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
return (-1);
......@@ -128,7 +127,7 @@ class BatchNormFwdArg
using namespace ck;
template <typename InOutDataType, typename AccDataType>
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_fwd_nhwc_test(bool do_verification,
int init_method,
bool time_kernel,
......@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin());
int result = 0;
// used for saving meansquare
DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
128);
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using DeviceBatchNormFwdInstance =
ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
InOutDataType,
AccDataType,
AccDataType, // ScaleDataType
AccDataType, // BiasDataType
AccDataType, // MeanVarDataType
PassThroughOp, // YElementwiseOp
Rank,
NumReduceDim,
UseMultiblockInK,
256,
16,
16,
1,
2,
0,
1,
1,
1,
1,
1>;
void* p_tmp_mean = workspace.GetDeviceBuffer();
void* p_tmp_meansquare =
static_cast<char*>(p_tmp_mean) +
(sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
time_kernel,
updateMovingAverage,
saveMeanAndInvVariance,
{0, 1, 2},
auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
epsilon,
PassThroughOp{},
y_dev.GetDeviceBuffer(),
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
p_tmp_mean,
p_tmp_meansquare);
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
if(result < 0)
if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
"exiting!"
<< std::endl;
return (false);
};
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
if(time_kernel)
{
float avg_time = 0.0f;
size_t num_bytes = 0;
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
size_t invariant_length = inOutLengths[3];
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// inputing of x, scale, bias, outputing of y
num_bytes +=
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
// outputing of mean, inv-variance
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
// updating of moving mean, variance
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
else
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true;
if(do_verification)
{
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(),
bnScale.mData.data(),
bnBias.mData.data(),
y_ref.mData.data(),
0.1, // exponentialAverageFactor
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
updateMovingAverage ? resultRunningVariance_ref.mData.data()
: nullptr, // resultRunningVariance
epsilon,
PassThroughOp{},
y_ref.mData.data(),
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
averageFactor,
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout
<< "The runtime parameters seems not supported by the BatchNorm instance, exiting!"
std::cout << "The runtime parameters seems not supported by the BatchNorm reference "
"instance, exiting!"
<< std::endl;
return (-2);
return (false);
};
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
......@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
if(saveMeanAndInvVariance)
{
using ck::host_common::dumpBufferToFile;
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
......@@ -396,7 +464,17 @@ int main(int argc, char* argv[])
if(arg.data_type == 0)
{
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification,
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
......@@ -407,7 +485,17 @@ int main(int argc, char* argv[])
}
else if(arg.data_type == 1)
{
pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification,
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
......@@ -418,7 +506,17 @@ int main(int argc, char* argv[])
}
else if(arg.data_type == 3)
{
pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification,
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
......@@ -429,7 +527,17 @@ int main(int argc, char* argv[])
}
else if(arg.data_type == 5)
{
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification,
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
......@@ -440,7 +548,17 @@ int main(int argc, char* argv[])
}
else if(arg.data_type == 6)
{
pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification,
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
......@@ -452,12 +570,21 @@ int main(int argc, char* argv[])
}
else
{
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true,
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
2,
false, // don't time kernel
{128, 16, 16, 1024},
{128, 16, 6, 512},
true,
true,
averageFactor,
epsilon);
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
2,
false, // don't time kernel
{128, 16, 3, 1024},
true,
true,
false,
averageFactor,
epsilon);
};
......
......@@ -14,8 +14,12 @@
#include "batchnorm_common.hpp"
template <typename InOutDataType,
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
ck::index_t Rank,
ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false>
......@@ -26,7 +30,9 @@ int bnorm_infer(
const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
......@@ -41,11 +47,11 @@ int bnorm_infer(
"Invalid number of reduced dimensions for batchnorm!");
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// variance,
// scale,
// bias,
ck::Tuple<InOutDataType>, // y
ck::Tuple<YDataType>, // y
NormalizeInInfer,
Rank,
2, // MPerthread
......@@ -53,14 +59,18 @@ int bnorm_infer(
ck::Sequence<1>>; // scalarPerVector: y
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};
int i = 0;
for(auto dim : invariantDims)
{
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
aligned_bnScaleStrides[dim] = bnScaleStrides[i];
aligned_bnBiasStrides[dim] = bnBiasStrides[i];
aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
i++;
};
......@@ -84,10 +94,10 @@ int bnorm_infer(
auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
xyLengths,
{xStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides},
aligned_bnMeanVarStrides,
aligned_bnMeanVarStrides,
aligned_bnScaleStrides,
aligned_bnBiasStrides},
{yStrides},
{p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
{p_y},
......@@ -105,8 +115,10 @@ int bnorm_infer(
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
total_length * sizeof(InOutDataType));
num_bytes += total_length * sizeof(XDataType) +
invariantLength *
(sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
total_length * sizeof(YDataType);
if(time_kernel)
{
......
......@@ -18,11 +18,6 @@
#include "batchnorm_infer_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
......@@ -236,14 +231,23 @@ bool bnorm_infer_nhwc_test(bool do_verification,
int result = 0;
result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
time_kernel,
result = bnorm_infer<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
Rank,
NumReduceDim,
false>(time_kernel,
{0, 1, 2},
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
......@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if(do_verification)
{
auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType>;
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
......@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(),
bnScale.mData.data(),
bnBias.mData.data(),
......
......@@ -168,6 +168,11 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
// workaround: disable broken fused attention kernel instance that does not pass validation
// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
// enabled
#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
namespace ck {
enum struct InMemoryDataOperationEnum
......
......@@ -14,7 +14,8 @@ namespace ck {
template <typename TensorLengths,
typename DimAccessOrder,
typename ScalarsPerAccess> // # of scalars per access in each dimension
typename ScalarsPerAccess,
bool SnakeCurved = true> // # of scalars per access in each dimension
struct SpaceFillingCurve
{
static constexpr index_t nDim = TensorLengths::Size();
......@@ -136,9 +137,10 @@ struct SpaceFillingCurve
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 -
ordered_access_idx[idim];
ordered_idx(idim) =
!SnakeCurved || forward_sweep[idim]
? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
......
......@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
......@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
......
......@@ -24,7 +24,8 @@ template <typename ALayout,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
......
......@@ -7,44 +7,55 @@
#include <vector>
#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
std::vector<index_t> c_gs_ms_os_lengths,
std::vector<index_t> c_gs_ms_os_strides,
ck::index_t StrideA,
ck::index_t StrideB0,
ck::index_t StrideB1,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB0,
ck::index_t BatchStrideB1,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
......
......@@ -13,31 +13,36 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim>
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
struct DeviceBatchNormFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double exponentialAverageFactor,
void* resultRunningMean,
void* resultRunningVariance,
double epsilon,
void* resultSaveMean,
void* resultSaveInvVariance) = 0;
void* resultRunningVariance) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
using DeviceBatchNormFwdPtr =
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
......
......@@ -7,46 +7,50 @@
#include <vector>
#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{
struct ProblemDesc
{
// Overall problem shape
index_t M;
index_t N;
index_t K;
index_t O;
index_t Batch;
std::vector<index_t> a_gs_ms_ks_lengths;
std::vector<index_t> a_gs_ms_ks_strides;
// Stride for A/B0/B1; layout determined by template args
index_t StrideA;
index_t StrideB0;
index_t StrideB1;
index_t BatchStrideA;
index_t BatchStrideB0;
index_t BatchStrideB1;
std::vector<index_t> b0_gs_ns_ks_lengths;
std::vector<index_t> b0_gs_ns_ks_strides;
std::vector<index_t> b1_gs_os_ns_lengths;
std::vector<index_t> b1_gs_os_ns_strides;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
};
virtual std::unique_ptr<BaseArgument>
......@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
......
......@@ -3,27 +3,30 @@
#pragma once
#include <vector>
#include <array>
#include <memory>
#include <iostream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation>
template <index_t Rank,
index_t NumReduceDim,
typename InElementwiseOperation,
typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator
{
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
const void* in_dev,
......@@ -36,9 +39,12 @@ struct DeviceReduce : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation, typename AccElementwiseOperation>
using DeviceReducePtr =
std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>;
template <index_t Rank,
index_t NumReduceDim,
typename InElementwiseOperation,
typename AccElementwiseOperation>
using DeviceReducePtr = std::unique_ptr<
DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -130,8 +130,11 @@ namespace device {
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1
// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
//
// Detail- Packed tensor satisfies
// stride_0 = 1
......@@ -147,7 +150,7 @@ namespace device {
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed.
// TensorSpecialization::Packed in a traditional sense of "packed" tensor
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
......
......@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
......@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
CElementwiseOperation,
MaskOutUpperTriangle>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
......@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
......@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
......
......@@ -5,9 +5,8 @@
#include <iostream>
#include <sstream>
#include <array>
#include "ck/utility/common_header.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
......@@ -41,7 +40,8 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceMultiBlock
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
......@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::array<index_t, Rank>& inStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto tupleSrcLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
......@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
......@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
......@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::array<index_t, NumDstDim>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
......@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (out_grid_desc_m_padded);
};
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumDstDim>& outLengths,
const std::array<index_t, NumDstDim>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{});
......@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
Argument(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
......@@ -272,10 +278,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
}
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
std::vector<index_t> outLengths_;
std::vector<index_t> outStrides_;
std::array<index_t, Rank> inLengths_;
std::array<index_t, Rank> inStrides_;
std::array<index_t, NumDstDim> outLengths_;
std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_;
AccDataType beta_;
......@@ -459,11 +465,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
const void* in_dev,
......
......@@ -5,6 +5,7 @@
#include <iostream>
#include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -34,7 +35,8 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceThreadWise
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
......@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides)
static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::array<index_t, Rank>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto tupleSrcLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
......@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
......@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
......@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::array<index_t, NumDstDim>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
......@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
Argument(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
......@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
M_BlockTileSize;
}
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
std::vector<index_t> outLengths_;
std::vector<index_t> outStrides_;
std::array<index_t, Rank> inLengths_;
std::array<index_t, Rank> inStrides_;
std::array<index_t, NumDstDim> outLengths_;
std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_;
AccDataType beta_;
......@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
const void* in_dev,
......
......@@ -8,12 +8,9 @@
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -50,29 +47,80 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
using Reduction = DeviceReduceMultiBlock<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
reduce::Add,
InElementwiseOp,
AccElementwiseOp,
InMemoryDataOperationEnum::Set,
false, // PropagateNan
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
1>; // OutDstVectorSize
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleSrcLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
......@@ -102,7 +150,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDstVectorSize,
true>;
struct Argument : public Reduction::Argument
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
......@@ -113,42 +161,60 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
: Reduction::Argument(inLengths,
inStrides,
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
in_elementwise_op,
acc_elementwise_op),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
: alpha_{alpha},
beta_{beta},
in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
long_index_t invariant_total_length;
long_index_t reduce_total_length;
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = inLengths_[NumInvariantDim - 1];
blkGroupSize = 1;
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
}
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
InElementwiseOp in_elementwise_op_;
AccElementwiseOp acc_elementwise_op_;
index_t invariant_lowest_length_;
int blkGroupSize;
int numBlockTileIteration;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
bool sweep_once =
......@@ -195,15 +261,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_))
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return false;
}
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
else
{
if(p_arg_->inStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0)
return false;
};
}
else
{
if(p_arg_->inStrides_[Rank - 1] != 1)
return false;
if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0)
return false;
};
if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0)
return false;
return true;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment