Unverified Commit dc70e3e1 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into gridwise_2d

parents 10947a54 8ee36118
......@@ -15,13 +15,9 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "batchnorm_forward_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
......@@ -41,9 +37,10 @@ class BatchNormFwdArg
bool updateMovingAverage;
bool saveMeanAndInvVariance;
int data_type = 0;
int init_method = 2;
bool time_kernel = false;
int data_type = 0;
int init_method = 2;
bool time_kernel = false;
bool use_multiblock_welford = false;
public:
void show_usage(const char* cmd)
......@@ -68,6 +65,7 @@ class BatchNormFwdArg
"value, 3=decimal value)"
<< std::endl;
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
};
int processArgs(int argc, char* argv[])
......@@ -110,14 +108,15 @@ class BatchNormFwdArg
};
};
if(optind + 5 > argc)
if(optind + 6 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
data_type = std::atoi(argv[optind++]);
updateMovingAverage = std::atoi(argv[optind++]);
saveMeanAndInvVariance = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
return (-1);
......@@ -128,7 +127,7 @@ class BatchNormFwdArg
using namespace ck;
template <typename InOutDataType, typename AccDataType>
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_fwd_nhwc_test(bool do_verification,
int init_method,
bool time_kernel,
......@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin());
int result = 0;
// used for saving meansquare
DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
128);
void* p_tmp_mean = workspace.GetDeviceBuffer();
void* p_tmp_meansquare =
static_cast<char*>(p_tmp_mean) +
(sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
time_kernel,
updateMovingAverage,
saveMeanAndInvVariance,
{0, 1, 2},
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using DeviceBatchNormFwdInstance =
ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
InOutDataType,
AccDataType,
AccDataType, // ScaleDataType
AccDataType, // BiasDataType
AccDataType, // MeanVarDataType
PassThroughOp, // YElementwiseOp
Rank,
NumReduceDim,
UseMultiblockInK,
256,
16,
16,
1,
2,
0,
1,
1,
1,
1,
1>;
auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
epsilon,
PassThroughOp{},
y_dev.GetDeviceBuffer(),
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
p_tmp_mean,
p_tmp_meansquare);
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
if(result < 0)
if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
"exiting!"
<< std::endl;
return (false);
};
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
if(time_kernel)
{
float avg_time = 0.0f;
size_t num_bytes = 0;
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
size_t invariant_length = inOutLengths[3];
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// inputing of x, scale, bias, outputing of y
num_bytes +=
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
// outputing of mean, inv-variance
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
// updating of moving mean, variance
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
else
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true;
if(do_verification)
{
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(),
bnScale.mData.data(),
bnBias.mData.data(),
y_ref.mData.data(),
0.1, // exponentialAverageFactor
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
updateMovingAverage ? resultRunningVariance_ref.mData.data()
: nullptr, // resultRunningVariance
epsilon,
PassThroughOp{},
y_ref.mData.data(),
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
averageFactor,
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout
<< "The runtime parameters seems not supported by the BatchNorm instance, exiting!"
<< std::endl;
return (-2);
std::cout << "The runtime parameters seems not supported by the BatchNorm reference "
"instance, exiting!"
<< std::endl;
return (false);
};
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
......@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
if(saveMeanAndInvVariance)
{
using ck::host_common::dumpBufferToFile;
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
......@@ -396,70 +464,129 @@ int main(int argc, char* argv[])
if(arg.data_type == 0)
{
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 1)
{
pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 3)
{
pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 5)
{
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
else if(arg.data_type == 6)
{
pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
}
}
else
{
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true,
2,
false, // don't time kernel
{128, 16, 16, 1024},
true,
false,
averageFactor,
epsilon);
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
2,
false, // don't time kernel
{128, 16, 6, 512},
true,
true,
averageFactor,
epsilon);
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
2,
false, // don't time kernel
{128, 16, 3, 1024},
true,
true,
averageFactor,
epsilon);
};
return (pass ? 0 : 1);
......
......@@ -14,8 +14,12 @@
#include "batchnorm_common.hpp"
template <typename InOutDataType,
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
ck::index_t Rank,
ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false>
......@@ -26,7 +30,9 @@ int bnorm_infer(
const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
......@@ -41,11 +47,11 @@ int bnorm_infer(
"Invalid number of reduced dimensions for batchnorm!");
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// variance,
// scale,
// bias,
ck::Tuple<InOutDataType>, // y
ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// variance,
// scale,
// bias,
ck::Tuple<YDataType>, // y
NormalizeInInfer,
Rank,
2, // MPerthread
......@@ -53,14 +59,18 @@ int bnorm_infer(
ck::Sequence<1>>; // scalarPerVector: y
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};
int i = 0;
for(auto dim : invariantDims)
{
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
aligned_bnScaleStrides[dim] = bnScaleStrides[i];
aligned_bnBiasStrides[dim] = bnBiasStrides[i];
aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
i++;
};
......@@ -84,10 +94,10 @@ int bnorm_infer(
auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
xyLengths,
{xStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides},
aligned_bnMeanVarStrides,
aligned_bnMeanVarStrides,
aligned_bnScaleStrides,
aligned_bnBiasStrides},
{yStrides},
{p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
{p_y},
......@@ -105,8 +115,10 @@ int bnorm_infer(
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
total_length * sizeof(InOutDataType));
num_bytes += total_length * sizeof(XDataType) +
invariantLength *
(sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
total_length * sizeof(YDataType);
if(time_kernel)
{
......
......@@ -18,11 +18,6 @@
#include "batchnorm_infer_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
......@@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification,
int result = 0;
result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
time_kernel,
{0, 1, 2},
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
epsilon,
estimatedMean_dev.GetDeviceBuffer(),
estimatedVariance_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
result = bnorm_infer<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
Rank,
NumReduceDim,
false>(time_kernel,
{0, 1, 2},
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
epsilon,
estimatedMean_dev.GetDeviceBuffer(),
estimatedVariance_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
if(result < 0)
return (false);
......@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if(do_verification)
{
auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType>;
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
......@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
i_inOutStrides,
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(),
bnScale.mData.data(),
bnBias.mData.data(),
......
......@@ -159,6 +159,11 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
// workaround: disable broken fused attention kernel instance that does not pass validation
// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
// enabled
#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
namespace ck {
enum struct InMemoryDataOperationEnum
......
......@@ -14,7 +14,8 @@ namespace ck {
template <typename TensorLengths,
typename DimAccessOrder,
typename ScalarsPerAccess> // # of scalars per access in each dimension
typename ScalarsPerAccess,
bool SnakeCurved = true> // # of scalars per access in each dimension
struct SpaceFillingCurve
{
static constexpr index_t nDim = TensorLengths::Size();
......@@ -136,9 +137,10 @@ struct SpaceFillingCurve
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 -
ordered_access_idx[idim];
ordered_idx(idim) =
!SnakeCurved || forward_sweep[idim]
? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
......
......@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
......@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
......
......@@ -24,7 +24,8 @@ template <typename ALayout,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
......
......@@ -7,49 +7,60 @@
#include <vector>
#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
std::vector<index_t> c_gs_ms_os_lengths,
std::vector<index_t> c_gs_ms_os_strides,
ck::index_t StrideA,
ck::index_t StrideB0,
ck::index_t StrideB1,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB0,
ck::index_t BatchStrideB1,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -13,31 +13,36 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim>
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
struct DeviceBatchNormFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double exponentialAverageFactor,
void* resultRunningMean,
void* resultRunningVariance,
double epsilon,
void* resultSaveMean,
void* resultSaveInvVariance) = 0;
void* resultRunningVariance) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
using DeviceBatchNormFwdPtr =
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* bnScale,
const void* bnBias,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedConvFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, // input image
const void* p_b, // weight
void* p_c, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC)
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename ABDataType,
typename CDataType,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_dl(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
__shared__ ABDataType p_shared[shared_block_size];
GridwiseGemm::Run(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m0_m1_k1;
ignore = b_grid_desc_k0_n0_n1_k1;
ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
compute_ptr_offset_of_batch.GetCPtrOffset(0);
#endif
}
} // namespace
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimentions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <
index_t NDimSpatial,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimSpatial,
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
const auto AK0 = K / K1;
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
const auto BK0 = K / K1;
return transform_tensor_descriptor(
wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename CLay>
static auto
MakeCGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
c_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
// desc for problem definition
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
void* p_c,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_c_grid_{static_cast<CDataType*>(p_c)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N<CLayout>(c_g_n_k_wos_lengths,
c_g_n_k_wos_strides)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
compute_ptr_offset_of_batch_{
a_g_n_c_wis_strides[0], b_g_k_c_xs_strides[0], c_g_n_k_wos_strides[0]},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
c_g_n_k_wos_lengths_{c_g_n_k_wos_lengths},
c_g_n_k_wos_strides_{c_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideC_ = c_g_n_k_wos_strides[0];
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_ak0_m_ak1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_bk0_n_bk1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
std::cout << "num_group: " << num_group_ << std::endl;
std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl;
std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl;
std::cout << "A[m0, m10, m11, n0, n10, n11]: " << c_grid_desc_m0_m10_m11_n0_n10_n11_
<< std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
// tensor descriptors for problem definiton
index_t num_group_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
// block-to-e-tile map
DefaultBlock2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_.GetLength(I0),
arg.c_grid_desc_m_n_.GetLength(I1)) *
arg.num_group_;
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel =
kernel_grouped_conv_fwd_dl<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DeviceOp::AGridDesc_K0_M0_M1_K1,
DeviceOp::BGridDesc_K0_N0_N1_K1,
DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11,
DefaultBlock2CTileMap,
ComputePtrOffsetOfStridedBatch,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X
<< " ConvStride = " << ConvStride << " LeftPad = " << LeftPad
<< " RightPad = " << RightPad << std::endl;
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X
<< " LeftPad = " << LeftPad << " RightPad = " << RightPad
<< std::endl;
return false;
}
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
{
return false;
}
if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
{
return false;
}
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
{
return false;
}
if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
{
return false;
}
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
{
return false;
}
}
else
{
return false;
}
// check vector access of C
if constexpr(is_same_v<CLayout, ctc::G_NW_K> || is_same_v<CLayout, ctc::G_NHW_K> ||
is_same_v<CLayout, ctc::G_NDHW_K> || is_same_v<CLayout, ctc::GNWK> ||
is_same_v<CLayout, ctc::GNHWK> || is_same_v<CLayout, ctc::GNDHWK> ||
is_same_v<CLayout, ctc::NWGK> || is_same_v<CLayout, ctc::NHWGK> ||
is_same_v<CLayout, ctc::NDHWGK>)
{
const index_t K = arg.c_g_n_k_wos_lengths_[2];
if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5))
{
return false;
}
}
else
{
return false;
}
// check Gridwise GEMM
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
void* p_c,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
return Argument{p_a,
p_b,
p_c,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_c,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
c_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,46 +7,50 @@
#include <vector>
#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{
struct ProblemDesc
{
// Overall problem shape
index_t M;
index_t N;
index_t K;
index_t O;
index_t Batch;
std::vector<index_t> a_gs_ms_ks_lengths;
std::vector<index_t> a_gs_ms_ks_strides;
// Stride for A/B0/B1; layout determined by template args
index_t StrideA;
index_t StrideB0;
index_t StrideB1;
index_t BatchStrideA;
index_t BatchStrideB0;
index_t BatchStrideB1;
std::vector<index_t> b0_gs_ns_ks_lengths;
std::vector<index_t> b0_gs_ns_ks_strides;
std::vector<index_t> b1_gs_os_ns_lengths;
std::vector<index_t> b1_gs_os_ns_strides;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
};
virtual std::unique_ptr<BaseArgument>
......@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
......
......@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -54,9 +55,8 @@ __global__ void
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= arg_ptr[group_id].block_start_ &&
block_id < arg_ptr[group_id].block_end_)) &&
left <= right)
while(
(!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
{
if(block_id < arg_ptr[group_id].block_start_)
{
......@@ -114,14 +114,17 @@ __global__ void
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
typename BLayout, // B0Layout
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
......@@ -130,6 +133,10 @@ template <typename ALayout,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -170,297 +177,152 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
BLayout,
B1Layout,
CPermuteNumDims_G_M_Gemm1N,
: public DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
CElementwiseOperation,
MaskingSpec>
{
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
using ProblemDesc =
typename DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
BLayout,
B1Layout,
CPermuteNumDims_G_M_Gemm1N,
ADataType,
BDataType,
B1DataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>::ProblemDesc;
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>::ProblemDesc;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
Number<BK1>{});
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
constexpr static auto make_MaskOutPredicate()
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
return n >= NRaw_;
return MaskDisabledPredicate{};
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
return MaskOutUpperTrianglePredicate{};
}
private:
// index_t MRaw_;
index_t NRaw_;
};
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
CGridDesc_G_M_N c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideB1_(BatchStrideB1),
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
......@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideB1_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
};
......@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct GroupDeviceArg
{
// problem definiton
index_t M;
index_t N;
index_t K;
index_t O;
// lengths for the last dimensions of overall problem for sanity check of vector load/store
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
// Strides for the last dimensions of C for sanity check of vector load/store
index_t c_extent_lowest_;
index_t c_stride_lowest_;
// strides for the last dimensions of each tensor for sanity check of vector load/store
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_;
};
......@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_element_op_{b1_element_op},
c_element_op_{c_element_op}
{
// TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size();
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
......@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
}
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++)
......@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(
problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1);
const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N(
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
const auto& problem_desc = problem_desc_vec[i];
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) *
problem_desc_vec[i].Batch;
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
// TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure
const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N(
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
problem_desc_vec[i].BatchStrideB0,
problem_desc_vec[i].BatchStrideB1,
c_grid_desc_g_m_n);
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_b1_grid,
......@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
BlockStart,
BlockEnd});
group_device_args_.push_back({problem_desc_vec[i].M,
problem_desc_vec[i].N,
problem_desc_vec[i].K,
problem_desc_vec[i].O,
problem_desc_vec[i].c_gs_ms_os_lengths.back(),
problem_desc_vec[i].c_gs_ms_os_strides.back(),
c_grid_desc_m_n});
group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
{problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
{problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
{problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n});
}
}
......@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false;
}
// TODO ANT: Check if tensor specialization & strides mismatch
bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false;
......@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when
// part of vector is out of bounds
const auto MRaw = device_arg.M;
const auto NRaw = device_arg.N;
const auto KRaw = device_arg.K;
const auto Gemm1NRaw = device_arg.O;
const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = device_arg.c_extent_lowest_;
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
const auto c_extent_lowest = Gemm1NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
......@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false;
}
// Check vector store requirement; assumes last dimension in N to be contiguous
if(device_arg.c_stride_lowest_ != 1)
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
: device_arg.a_mz_kz_strides_[0];
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
? device_arg.b_nz_kz_strides_[1]
: device_arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? device_arg.b1_nz_kz_strides_[1]
: device_arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
// contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
return false;
}
......@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
p_b_vec,
p_b1_vec,
p_c_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(p_a_vec,
p_b_vec,
p_b1_vec,
p_c_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">";
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on
return str.str();
......
......@@ -130,8 +130,11 @@ namespace device {
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1
// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
//
// Detail- Packed tensor satisfies
// stride_0 = 1
......@@ -147,7 +150,7 @@ namespace device {
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed.
// TensorSpecialization::Packed in a traditional sense of "packed" tensor
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
......
......@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -116,14 +117,17 @@ __global__ void
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
typename BLayout, // B0Layout
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
......@@ -132,6 +136,10 @@ template <typename ALayout,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -172,283 +180,135 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
BLayout,
B1Layout,
CPermuteNumDims_G_M_Gemm1N,
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
Number<BK1>{});
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
constexpr static auto make_MaskOutPredicate()
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
return n >= NRaw_;
return MaskDisabledPredicate{};
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
return MaskOutUpperTrianglePredicate{};
}
private:
// index_t MRaw_;
index_t NRaw_;
};
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
CGridDesc_G_M_N c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideB1_(BatchStrideB1),
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
......@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideB1_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
};
......@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
// Argument
// FIXME: constness
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw, // = ORaw
index_t Batch,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
Argument(
const ADataType* p_a_grid,
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
b1_grid_desc_bk0_n_bk1_{
DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
......@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
batch_count_(Batch),
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
......@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
}
// private:
void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// a_grid_desc_g_m_k_.Print();
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print();
std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print();
}
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
index_t c_extent_lowest_;
index_t c_stride_lowest_;
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
};
// Invoker
......@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
throw std::runtime_error("wrong! unsupported argument");
}
const index_t grid_size =
......@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
// TODO ANT: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
......@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = arg.c_extent_lowest_;
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
const auto c_extent_lowest = Gemm1NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
......@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false;
}
// Check vector store requirement; assumes last dimension in N to be contiguous
if(arg.c_stride_lowest_ != 1)
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b_stride_lowest =
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
return false;
}
......@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
static auto MakeArgument(
const ADataType* p_a,
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_b1,
p_c,
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides,
StrideA,
StrideB,
StrideB1,
BatchStrideA,
BatchStrideB,
BatchStrideB1,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
......@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_b1,
void* p_c,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t Gemm1NRaw,
index_t Batch,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t StrideA,
index_t StrideB,
index_t StrideB1,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const void* p_b1,
void* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides,
StrideA,
StrideB,
StrideB1,
BatchStrideA,
BatchStrideB,
BatchStrideB1,
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
......@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">";
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on
return str.str();
......
......@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
......@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
CElementwiseOperation,
MaskOutUpperTriangle>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
......@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
......@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim,
bool UseMultiblockInK,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl
: public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
const std::array<index_t, Rank>& xyStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleXYLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
const auto tupleXYStrides =
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
Number<NumBatchNormReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(raw_grid_desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto grid_desc_m_g =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_g_padded =
transform_tensor_descriptor(grid_desc_m_g,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_g_padded);
};
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad =
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
const std::array<index_t, NumInvariantDim>& strides)
{
const auto tupleLengths =
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
const auto tupleStrides =
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
auto grid_desc_m = transform_tensor_descriptor(
raw_grid_desc,
make_tuple(make_merge_transform(tupleLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded =
transform_tensor_descriptor(grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, mPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (grid_desc_m_padded);
};
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* p_scale,
const BiasDataType* p_bias,
const YElementwiseOp y_elementwise_op,
double epsilon,
YDataType* p_y,
MeanVarDataType* resultSaveMean,
MeanVarDataType* resultSaveInvVariance,
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_scale_(p_scale),
p_bias_(p_bias),
y_elementwise_op_(y_elementwise_op),
p_y_(p_y),
resultSaveMean_(resultSaveMean),
resultSaveInvVariance_(resultSaveInvVariance),
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance)
{
xyLengths_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
xStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
yStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);
std::tie(invariant_length_, reduce_length_) =
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
updateMovingAverage_ =
(resultRunningMean != nullptr && resultRunningVariance != nullptr);
saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
if(UseMultiblockInK)
{
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128
if(testBlkGroupSize <= 128)
break;
iterations++;
};
blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration_ = iterations;
}
else
{
blkGroupSize_ = 1;
numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
x_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
scale_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
bias_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
mean_var_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
}
AccDataType epsilon_;
AccDataType averageFactor_;
bool updateMovingAverage_;
bool saveMeanInvVariance_;
std::array<index_t, Rank> xyLengths_;
std::array<index_t, Rank> xStrides_;
std::array<index_t, Rank> yStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
const XDataType* p_x_;
const ScaleDataType* p_scale_;
const BiasDataType* p_bias_;
const YElementwiseOp y_elementwise_op_;
YDataType* p_y_;
MeanVarDataType* resultSaveMean_;
MeanVarDataType* resultSaveInvVariance_;
MeanVarDataType* resultRunningMean_;
MeanVarDataType* resultRunningVariance_;
long_index_t invariant_length_;
long_index_t reduce_length_;
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
XYGridDesc_M_K x_grid_desc_m_k_;
XYGridDesc_M_K y_grid_desc_m_k_;
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
void* workspace_mean_;
void* workspace_variance_;
void* workspace_count_;
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// workspace for welford intermediate mean
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
}
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// setup buffer used for intermediate welford mean
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford varirance
pArg_->workspace_variance_ =
reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
index_t variance_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welfor count
pArg_->workspace_count_ =
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
};
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0;
if(UseMultiblockInK && arg.blkGroupSize_ > 1)
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
const auto mean_var_count_grid_desc_m_g =
DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
const auto mean_var_count_grid_desc_m_k =
DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize>;
using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
index_t numMeanVarCountBlockTileIteration =
(arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor>;
const auto kern_welford_second_half_batchnorm_forward_final =
kernel_welford_second_half_batchnorm_forward_final<
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M>;
avg_time +=
launch_and_time_kernel(stream_config,
kern_multiblock_welford_first_half,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_));
avg_time +=
launch_and_time_kernel(stream_config,
kern_welford_second_half_batchnorm_forward_final,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_k,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
arg.blkGroupSize_,
arg.numBlockTileIteration_,
numMeanVarCountBlockTileIteration,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_,
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_,
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
}
else
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.numBlockTileIteration_, arg.reduce_length_);
using GridwiseBatchNormForwardWithBlockwiseWelford_ =
GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
GridwiseBatchNormForwardWithBlockwiseWelford_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(stream_config,
kern_batchnorm_fwd,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_, // true or false
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_, // true or false
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
};
return (avg_time);
};
float Run(const BaseArgument* pArg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* pArg) override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
if constexpr(XSrcYDstVectorDim == 0)
{
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
pArg_->yStrides_[NumInvariantDim - 1] != 1)
return false;
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
return false;
}
else
{
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
return false;
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
return false;
};
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
return false;
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
return false;
bool is_valid = true;
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
is_valid = false;
});
if(!is_valid)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double averageFactor,
void* resultRunningMean,
void* resultRunningVariance) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
yStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const BiasDataType*>(p_bias),
y_elementwise_op,
epsilon,
static_cast<YDataType*>(p_y),
static_cast<MeanVarDataType*>(resultSaveMean),
static_cast<MeanVarDataType*>(resultSaveInvVariance),
averageFactor,
static_cast<MeanVarDataType*>(resultRunningMean),
static_cast<MeanVarDataType*>(resultRunningVariance));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -214,6 +214,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
enum struct MaskingSpecialization
{
MaskDisabled,
MaskOutUpperTriangle
};
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
switch(s)
{
case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
default: return "Unrecognized specialization!";
}
}
struct MaskDisabledPredicate
{
__host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
{
return false;
};
__host__ __device__ constexpr bool
IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
{
return false;
}
};
struct MaskOutUpperTrianglePredicate
{
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
{
return operator()(m + m_tile - 1, n);
}
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
{
return predicate_(m, n) || IsNOutOfBound(n);
}
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
{
return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
}
private:
// index_t MRaw_;
index_t NRaw_;
MaskOutPredicate predicate_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment