Commit 4fec5ad3 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

parents 24faa1fc 87fd1152
...@@ -15,13 +15,9 @@ ...@@ -15,13 +15,9 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "batchnorm_forward_impl.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'}, {"verify", required_argument, nullptr, 'v'},
...@@ -41,9 +37,10 @@ class BatchNormFwdArg ...@@ -41,9 +37,10 @@ class BatchNormFwdArg
bool updateMovingAverage; bool updateMovingAverage;
bool saveMeanAndInvVariance; bool saveMeanAndInvVariance;
int data_type = 0; int data_type = 0;
int init_method = 2; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
bool use_multiblock_welford = false;
public: public:
void show_usage(const char* cmd) void show_usage(const char* cmd)
...@@ -68,6 +65,7 @@ class BatchNormFwdArg ...@@ -68,6 +65,7 @@ class BatchNormFwdArg
"value, 3=decimal value)" "value, 3=decimal value)"
<< std::endl; << std::endl;
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
}; };
int processArgs(int argc, char* argv[]) int processArgs(int argc, char* argv[])
...@@ -110,14 +108,15 @@ class BatchNormFwdArg ...@@ -110,14 +108,15 @@ class BatchNormFwdArg
}; };
}; };
if(optind + 5 > argc) if(optind + 6 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
data_type = std::atoi(argv[optind++]); data_type = std::atoi(argv[optind++]);
updateMovingAverage = std::atoi(argv[optind++]); updateMovingAverage = std::atoi(argv[optind++]);
saveMeanAndInvVariance = std::atoi(argv[optind++]); saveMeanAndInvVariance = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]); init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind])); time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
return (-1); return (-1);
...@@ -128,7 +127,7 @@ class BatchNormFwdArg ...@@ -128,7 +127,7 @@ class BatchNormFwdArg
using namespace ck; using namespace ck;
template <typename InOutDataType, typename AccDataType> template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_fwd_nhwc_test(bool do_verification, bool bnorm_fwd_nhwc_test(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
...@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides.end(), scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin()); i_scaleBiasMeanVarStrides.begin());
int result = 0; using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
// used for saving meansquare using DeviceBatchNormFwdInstance =
DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() + ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
128); InOutDataType,
AccDataType,
void* p_tmp_mean = workspace.GetDeviceBuffer(); AccDataType, // ScaleDataType
void* p_tmp_meansquare = AccDataType, // BiasDataType
static_cast<char*>(p_tmp_mean) + AccDataType, // MeanVarDataType
(sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64; PassThroughOp, // YElementwiseOp
Rank,
result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>( NumReduceDim,
time_kernel, UseMultiblockInK,
updateMovingAverage, 256,
saveMeanAndInvVariance, 16,
{0, 1, 2}, 16,
1,
2,
0,
1,
1,
1,
1,
1>;
auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
epsilon, epsilon,
PassThroughOp{},
y_dev.GetDeviceBuffer(),
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
p_tmp_mean, averageFactor,
p_tmp_meansquare); updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
if(result < 0) if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
"exiting!"
<< std::endl;
return (false); return (false);
};
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
if(time_kernel)
{
float avg_time = 0.0f;
size_t num_bytes = 0;
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
size_t invariant_length = inOutLengths[3];
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// inputing of x, scale, bias, outputing of y
num_bytes +=
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
// outputing of mean, inv-variance
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
// updating of moving mean, variance
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
else
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
y_ref.mData.data(),
0.1, // exponentialAverageFactor
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
updateMovingAverage ? resultRunningVariance_ref.mData.data()
: nullptr, // resultRunningVariance
epsilon, epsilon,
PassThroughOp{},
y_ref.mData.data(),
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr); saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
averageFactor,
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{ {
std::cout std::cout << "The runtime parameters seems not supported by the BatchNorm reference "
<< "The runtime parameters seems not supported by the BatchNorm instance, exiting!" "instance, exiting!"
<< std::endl; << std::endl;
return (-2); return (false);
}; };
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
...@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
if(saveMeanAndInvVariance) if(saveMeanAndInvVariance)
{ {
using ck::host_common::dumpBufferToFile;
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths); Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths); Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
...@@ -396,70 +464,129 @@ int main(int argc, char* argv[]) ...@@ -396,70 +464,129 @@ int main(int argc, char* argv[])
if(arg.data_type == 0) if(arg.data_type == 0)
{ {
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 1) else if(arg.data_type == 1)
{ {
pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 3) else if(arg.data_type == 3)
{ {
pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 5) else if(arg.data_type == 5)
{ {
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 6) else if(arg.data_type == 6)
{ {
pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
} }
else else
{ {
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true, pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
2, 2,
false, // don't time kernel false, // don't time kernel
{128, 16, 16, 1024}, {128, 16, 6, 512},
true, true,
false, true,
averageFactor, averageFactor,
epsilon); epsilon);
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
2,
false, // don't time kernel
{128, 16, 3, 1024},
true,
true,
averageFactor,
epsilon);
}; };
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
...@@ -14,8 +14,12 @@ ...@@ -14,8 +14,12 @@
#include "batchnorm_common.hpp" #include "batchnorm_common.hpp"
template <typename InOutDataType, template <typename XDataType,
typename YDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
ck::index_t Rank, ck::index_t Rank,
ck::index_t NumBatchNormReduceDim, ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false> bool fastest_dim_is_reduced = false>
...@@ -26,7 +30,9 @@ int bnorm_infer( ...@@ -26,7 +30,9 @@ int bnorm_infer(
const std::array<ck::index_t, Rank> xStrides, const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides, const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides, const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* p_scale, const void* p_scale,
const void* p_bias, const void* p_bias,
...@@ -41,11 +47,11 @@ int bnorm_infer( ...@@ -41,11 +47,11 @@ int bnorm_infer(
"Invalid number of reduced dimensions for batchnorm!"); "Invalid number of reduced dimensions for batchnorm!");
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean, ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// variance, // variance,
// scale, // scale,
// bias, // bias,
ck::Tuple<InOutDataType>, // y ck::Tuple<YDataType>, // y
NormalizeInInfer, NormalizeInInfer,
Rank, Rank,
2, // MPerthread 2, // MPerthread
...@@ -53,14 +59,18 @@ int bnorm_infer( ...@@ -53,14 +59,18 @@ int bnorm_infer(
ck::Sequence<1>>; // scalarPerVector: y ck::Sequence<1>>; // scalarPerVector: y
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims); auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0}; std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};
int i = 0; int i = 0;
for(auto dim : invariantDims) for(auto dim : invariantDims)
{ {
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]); assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i]; aligned_bnScaleStrides[dim] = bnScaleStrides[i];
aligned_bnBiasStrides[dim] = bnBiasStrides[i];
aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
i++; i++;
}; };
...@@ -84,10 +94,10 @@ int bnorm_infer( ...@@ -84,10 +94,10 @@ int bnorm_infer(
auto argument_ptr1 = dev_normalize.MakeArgumentPointer( auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
xyLengths, xyLengths,
{xStrides, {xStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnMeanVarStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnMeanVarStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnScaleStrides,
aligned_scaleBiasMeanVarStrides}, aligned_bnBiasStrides},
{yStrides}, {yStrides},
{p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias}, {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
{p_y}, {p_y},
...@@ -105,8 +115,10 @@ int bnorm_infer( ...@@ -105,8 +115,10 @@ int bnorm_infer(
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel}); avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) + num_bytes += total_length * sizeof(XDataType) +
total_length * sizeof(InOutDataType)); invariantLength *
(sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
total_length * sizeof(YDataType);
if(time_kernel) if(time_kernel)
{ {
......
...@@ -18,11 +18,6 @@ ...@@ -18,11 +18,6 @@
#include "batchnorm_infer_impl.hpp" #include "batchnorm_infer_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'}, {"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'}, {"help", no_argument, nullptr, '?'},
...@@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification,
int result = 0; int result = 0;
result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>( result = bnorm_infer<InOutDataType,
time_kernel, InOutDataType,
{0, 1, 2}, AccDataType,
i_inOutLengths, AccDataType,
i_inOutStrides, AccDataType,
i_inOutStrides, AccDataType,
i_scaleBiasMeanVarLengths, Rank,
i_scaleBiasMeanVarStrides, NumReduceDim,
x_dev.GetDeviceBuffer(), false>(time_kernel,
bnScale_dev.GetDeviceBuffer(), {0, 1, 2},
bnBias_dev.GetDeviceBuffer(), i_inOutLengths,
epsilon, i_inOutStrides,
estimatedMean_dev.GetDeviceBuffer(), i_inOutStrides,
estimatedVariance_dev.GetDeviceBuffer(), i_scaleBiasMeanVarLengths,
y_dev.GetDeviceBuffer()); i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
epsilon,
estimatedMean_dev.GetDeviceBuffer(),
estimatedVariance_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
if(result < 0) if(result < 0)
return (false); return (false);
...@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{}; using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType>;
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
auto argument_ptr_ref = auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths, batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
...@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
i_inOutStrides, i_inOutStrides,
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
......
...@@ -168,6 +168,11 @@ ...@@ -168,6 +168,11 @@
// tuning parameter // tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0 #define CK_WORKAROUND_SWDEV_325164 0
// workaround: disable broken fused attention kernel instance that does not pass validation
// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
// enabled
#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
namespace ck { namespace ck {
enum struct InMemoryDataOperationEnum enum struct InMemoryDataOperationEnum
......
...@@ -14,7 +14,8 @@ namespace ck { ...@@ -14,7 +14,8 @@ namespace ck {
template <typename TensorLengths, template <typename TensorLengths,
typename DimAccessOrder, typename DimAccessOrder,
typename ScalarsPerAccess> // # of scalars per access in each dimension typename ScalarsPerAccess,
bool SnakeCurved = true> // # of scalars per access in each dimension
struct SpaceFillingCurve struct SpaceFillingCurve
{ {
static constexpr index_t nDim = TensorLengths::Size(); static constexpr index_t nDim = TensorLengths::Size();
...@@ -136,9 +137,10 @@ struct SpaceFillingCurve ...@@ -136,9 +137,10 @@ struct SpaceFillingCurve
Index ordered_idx; Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) { static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim] ordered_idx(idim) =
: ordered_access_lengths[idim] - 1 - !SnakeCurved || forward_sweep[idim]
ordered_access_idx[idim]; ? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
}); });
return container_reorder_given_old2new(ordered_idx, dim_access_order) * return container_reorder_given_old2new(ordered_idx, dim_access_order) *
......
...@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{ {
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
...@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2 ...@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
......
...@@ -24,7 +24,8 @@ template <typename ALayout, ...@@ -24,7 +24,8 @@ template <typename ALayout,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
......
...@@ -7,49 +7,60 @@ ...@@ -7,49 +7,60 @@
#include <vector> #include <vector>
#include "device_base.hpp" #include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ALayout, template <index_t NumDimG,
typename B0Layout, index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<> index_t NumDimK,
index_t NumDimO,
typename ADataType, typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
MakeArgumentPointer(const void* p_a, static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
const void* p_b0,
const void* p_b1, virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
void* p_c, const void* p_a,
ck::index_t M, const void* p_b0,
ck::index_t N, const void* p_b1,
ck::index_t K, void* p_c,
ck::index_t O, const std::array<void*, NumAcc0Bias> p_acc0_biases,
ck::index_t Batch, const std::array<void*, NumAcc1Bias> p_acc1_biases,
std::vector<index_t> c_gs_ms_os_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
std::vector<index_t> c_gs_ms_os_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
ck::index_t StrideA, const std::vector<index_t>& b_gs_ns_ks_lengths,
ck::index_t StrideB0, const std::vector<index_t>& b_gs_ns_ks_strides,
ck::index_t StrideB1, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
ck::index_t BatchStrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
ck::index_t BatchStrideB0, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
ck::index_t BatchStrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
AElementwiseOperation a_element_op, const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
B0ElementwiseOperation b0_element_op, const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
Acc0ElementwiseOperation acc0_element_op, const std::array<std::vector<index_t>, NumAcc1Bias>
B1ElementwiseOperation b1_element_op, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
CElementwiseOperation c_element_op) = 0; const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -13,31 +13,36 @@ namespace ck { ...@@ -13,31 +13,36 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim> template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
struct DeviceBatchNormFwd : public BaseOperator struct DeviceBatchNormFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths, const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides, const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides, const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y, void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double exponentialAverageFactor, double exponentialAverageFactor,
void* resultRunningMean, void* resultRunningMean,
void* resultRunningVariance, void* resultRunningVariance) = 0;
double epsilon,
void* resultSaveMean,
void* resultSaveInvVariance) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t Rank, index_t NumBatchNormReduceDim> template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>; using DeviceBatchNormFwdPtr =
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator ...@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
const std::array<index_t, Rank> xStrides, const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides, const std::array<index_t, Rank> yStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
......
...@@ -7,46 +7,50 @@ ...@@ -7,46 +7,50 @@
#include <vector> #include <vector>
#include "device_base.hpp" #include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ALayout, template <index_t NumDimG,
typename B0Layout, index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<> index_t NumDimK,
index_t NumDimO,
typename ADataType, typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{ {
struct ProblemDesc struct ProblemDesc
{ {
// Overall problem shape std::vector<index_t> a_gs_ms_ks_lengths;
index_t M; std::vector<index_t> a_gs_ms_ks_strides;
index_t N;
index_t K;
index_t O;
index_t Batch;
// Stride for A/B0/B1; layout determined by template args std::vector<index_t> b0_gs_ns_ks_lengths;
index_t StrideA; std::vector<index_t> b0_gs_ns_ks_strides;
index_t StrideB0;
index_t StrideB1; std::vector<index_t> b1_gs_os_ns_lengths;
index_t BatchStrideA; std::vector<index_t> b1_gs_os_ns_strides;
index_t BatchStrideB0;
index_t BatchStrideB1;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths; std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides; std::vector<index_t> c_gs_ms_os_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
}; };
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator ...@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
std::vector<const void*> p_b0_vec, std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -54,9 +55,8 @@ __global__ void ...@@ -54,9 +55,8 @@ __global__ void
index_t right = group_count; index_t right = group_count;
index_t group_id = index_t((left + right) / 2); index_t group_id = index_t((left + right) / 2);
while((!(block_id >= arg_ptr[group_id].block_start_ && while(
block_id < arg_ptr[group_id].block_end_)) && (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
left <= right)
{ {
if(block_id < arg_ptr[group_id].block_start_) if(block_id < arg_ptr[group_id].block_start_)
{ {
...@@ -114,14 +114,17 @@ __global__ void ...@@ -114,14 +114,17 @@ __global__ void
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
template <typename ALayout, template <index_t NumDimG,
typename BLayout, // B0Layout index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N> index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -130,6 +133,10 @@ template <typename ALayout, ...@@ -130,6 +133,10 @@ template <typename ALayout,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -170,297 +177,152 @@ template <typename ALayout, ...@@ -170,297 +177,152 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceGroupedGemmSoftmaxGemmPermute<ALayout, : public DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
BLayout, NumDimM,
B1Layout, NumDimN,
CPermuteNumDims_G_M_Gemm1N, NumDimK,
NumDimO,
ADataType, ADataType,
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskingSpec>
{ {
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
using ProblemDesc = "Number of dimension must be greater than 0");
typename DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
BLayout, static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
B1Layout, static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
CPermuteNumDims_G_M_Gemm1N,
ADataType, // TODO ANT: implement bias combination
BDataType, static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
B1DataType,
CDataType, #if 0
AElementwiseOperation, // TODO ANT: use alias
BElementwiseOperation, static constexpr index_t NumDimGemm0M = NumDimM;
AccElementwiseOperation, static constexpr index_t NumDimGemm0N = NumDimN;
B1ElementwiseOperation, static constexpr index_t NumDimGemm0K = NumDimK;
CElementwiseOperation>::ProblemDesc; static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>::ProblemDesc;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder = using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) ASpec,
{ BSpec,
const auto a_grid_desc_mraw_kraw = [&]() { B1Spec,
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) CSpec>;
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
make_tuple(StrideA, I1)); const std::vector<index_t>& a_gs_ms_ks_strides_vec)
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{ {
const auto b1_grid_desc_nraw_kraw = [&]() { return Transform::MakeAGridDescriptor_AK0_M_AK1(
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value) Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
{ Number<AK1>{});
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, const std::vector<index_t>& b_gs_ns_ks_strides_vec)
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB0GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N Number<BK1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB1GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0 constexpr static auto make_MaskOutPredicate()
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{ {
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {} if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{ {
return n >= NRaw_; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{ {
return IsUpperTriangle(m, n) || IsNOutOfBound(n); return MaskOutUpperTrianglePredicate{};
} }
}
private: using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
index_t BatchStrideB, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
index_t BatchStrideB1, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
CGridDesc_G_M_N c_grid_desc_g_m_n) const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
BatchStrideB_(BatchStrideB), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
BatchStrideB1_(BatchStrideB1), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n) c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{ {
} }
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB1_); return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
...@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
} }
private: private:
index_t BatchStrideA_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
index_t BatchStrideB_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
index_t BatchStrideB1_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
}; };
...@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskOutUpperTriangle>; MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct GroupDeviceArg struct GroupDeviceArg
{ {
// problem definiton // lengths for the last dimensions of overall problem for sanity check of vector load/store
index_t M; std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t N;
index_t K;
index_t O;
// Strides for the last dimensions of C for sanity check of vector load/store // strides for the last dimensions of each tensor for sanity check of vector load/store
index_t c_extent_lowest_; std::vector<index_t> a_mz_kz_strides_;
index_t c_stride_lowest_; std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
}; };
...@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
// TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size(); group_count_ = problem_desc_vec.size();
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
...@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
} }
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0; grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
...@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]); const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]); const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1( const auto& problem_desc = problem_desc_vec[i];
problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1); problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N( const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart); const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
problem_desc_vec[i].Batch; const index_t grid_size_grp =
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp; const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride // batch stride
// TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N( a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
problem_desc_vec[i].BatchStrideB0,
problem_desc_vec[i].BatchStrideB1,
c_grid_desc_g_m_n);
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N); const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
...@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
BlockStart, BlockStart,
BlockEnd}); BlockEnd});
group_device_args_.push_back({problem_desc_vec[i].M, group_device_args_.push_back(
problem_desc_vec[i].N, {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc_vec[i].K, problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
problem_desc_vec[i].O, problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
problem_desc_vec[i].c_gs_ms_os_lengths.back(), problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
problem_desc_vec[i].c_gs_ms_os_strides.back(), {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
c_grid_desc_m_n}); problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
{problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
{problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n});
} }
} }
...@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false; return false;
} }
// TODO ANT: Check if tensor specialization & strides mismatch
bool all_has_main_k_block_loop = true; bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false; bool some_has_main_k_block_loop = false;
...@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when // Note: we need raw lengths since threadwise copy can not handle vector load when
// part of vector is out of bounds // part of vector is out of bounds
const auto MRaw = device_arg.M; const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NRaw = device_arg.N; const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KRaw = device_arg.K; const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NRaw = device_arg.O; const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement // Check scalar per vector requirement
const auto a_extent_lowest = const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw; const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b_extent_lowest = const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw; const auto c_extent_lowest = Gemm1NzRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = device_arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
...@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false; return false;
} }
// Check vector store requirement; assumes last dimension in N to be contiguous // Check vector load/store requirement
if(device_arg.c_stride_lowest_ != 1) const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
: device_arg.a_mz_kz_strides_[0];
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
? device_arg.b_nz_kz_strides_[1]
: device_arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? device_arg.b1_nz_kz_strides_[1]
: device_arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
// contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{ {
return false; return false;
} }
...@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
p_b_vec, p_b_vec,
p_b1_vec, p_b1_vec,
p_c_vec, p_c_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec, std::unique_ptr<BaseArgument>
std::vector<const void*> p_b_vec, MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b_vec,
std::vector<void*> p_c_vec, std::vector<const void*> p_b1_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<void*> p_c_vec,
AElementwiseOperation a_element_op, std::vector<std::vector<const void*>> p_acc0_biases_vec,
BElementwiseOperation b_element_op, std::vector<std::vector<const void*>> p_acc1_biases_vec,
AccElementwiseOperation acc_element_op, std::vector<ProblemDesc> problem_desc_vec,
B1ElementwiseOperation b1_element_op, AElementwiseOperation a_element_op,
CElementwiseOperation c_element_op) override BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{ {
return std::make_unique<Argument>(p_a_vec, return std::make_unique<Argument>(p_a_vec,
p_b_vec, p_b_vec,
p_b1_vec, p_b1_vec,
p_c_vec, p_c_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">"; << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -3,27 +3,30 @@ ...@@ -3,27 +3,30 @@
#pragma once #pragma once
#include <vector> #include <array>
#include <memory> #include <memory>
#include <iostream>
#include "ck/utility/common_header.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation> template <index_t Rank,
index_t NumReduceDim,
typename InElementwiseOperation,
typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator struct DeviceReduce : public BaseOperator
{ {
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumOutDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumOutDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
...@@ -36,9 +39,12 @@ struct DeviceReduce : public BaseOperator ...@@ -36,9 +39,12 @@ struct DeviceReduce : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename InElementwiseOperation, typename AccElementwiseOperation> template <index_t Rank,
using DeviceReducePtr = index_t NumReduceDim,
std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>; typename InElementwiseOperation,
typename AccElementwiseOperation>
using DeviceReducePtr = std::unique_ptr<
DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -130,8 +130,11 @@ namespace device { ...@@ -130,8 +130,11 @@ namespace device {
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] // D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] // E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it // NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1 // dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
// //
// Detail- Packed tensor satisfies // Detail- Packed tensor satisfies
// stride_0 = 1 // stride_0 = 1
...@@ -147,7 +150,7 @@ namespace device { ...@@ -147,7 +150,7 @@ namespace device {
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. // essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
// //
// Might need to expose dimension order to the interface to fully support // Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed. // TensorSpecialization::Packed in a traditional sense of "packed" tensor
template <index_t NumDimG, template <index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -116,14 +117,17 @@ __global__ void ...@@ -116,14 +117,17 @@ __global__ void
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
template <typename ALayout, template <index_t NumDimG,
typename BLayout, // B0Layout index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N> index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -132,6 +136,10 @@ template <typename ALayout, ...@@ -132,6 +136,10 @@ template <typename ALayout,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -172,283 +180,135 @@ template <typename ALayout, ...@@ -172,283 +180,135 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout, : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
BLayout, NumDimM,
B1Layout, NumDimN,
CPermuteNumDims_G_M_Gemm1N, NumDimK,
NumDimO,
ADataType, ADataType,
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskingSpec>
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder = using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { return Transform::MakeAGridDescriptor_AK0_M_AK1(
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
{ Number<AK1>{});
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, const std::vector<index_t>& b_gs_ns_ks_strides_vec)
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB0GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N Number<BK1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB1GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0 constexpr static auto make_MaskOutPredicate()
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{ {
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {} if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{ {
return n >= NRaw_; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{ {
return IsUpperTriangle(m, n) || IsNOutOfBound(n); return MaskOutUpperTrianglePredicate{};
} }
}
private: using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
index_t BatchStrideB, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
index_t BatchStrideB1, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
CGridDesc_G_M_N c_grid_desc_g_m_n) const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
BatchStrideB_(BatchStrideB), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
BatchStrideB1_(BatchStrideB1), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n) c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{ {
} }
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB1_); return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
...@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
} }
private: private:
index_t BatchStrideA_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
index_t BatchStrideB_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
index_t BatchStrideB1_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
}; };
...@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskOutUpperTriangle>; MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
// Argument // Argument
// FIXME: constness // FIXME: constness
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const ADataType* p_a_grid, Argument(
const BDataType* p_b_grid, const ADataType* p_a_grid,
const B1DataType* p_b1_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, const B1DataType* p_b1_grid,
index_t MRaw, CDataType* p_c_grid,
index_t NRaw, const std::array<void*, NumAcc0Bias> p_acc0_biases,
index_t KRaw, const std::array<void*, NumAcc1Bias> p_acc1_biases,
index_t Gemm1NRaw, // = ORaw const std::vector<index_t>& a_gs_ms_ks_lengths,
index_t Batch, const std::vector<index_t>& a_gs_ms_ks_strides,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& b_gs_ns_ks_lengths,
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& b_gs_ns_ks_strides,
index_t StrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
index_t StrideB, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
index_t StrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
index_t BatchStrideA, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t BatchStrideB, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
index_t BatchStrideB1, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
BElementwiseOperation b_element_op, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
AccElementwiseOperation acc_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
B1ElementwiseOperation b1_element_op, acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
CElementwiseOperation c_element_op) AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b1_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
c_gs_ms_gemm1ns_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
batch_count_(Batch), c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}, a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
{ {
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_, b1_grid_desc_bk0_n_bk1_,
...@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
} }
} }
// private: void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// a_grid_desc_g_m_k_.Print();
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print();
std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print();
}
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_; const B1DataType* p_b1_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
// element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_; AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_; B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding // check C0 masking and padding
C0MatrixMask c0_matrix_mask_; C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check // For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_; std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t c_extent_lowest_; std::vector<index_t> a_mz_kz_strides_;
index_t c_stride_lowest_; std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
}; };
// Invoker // Invoker
...@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!DeviceOp::IsSupportedArgument(arg))
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! unsupported argument");
} }
const index_t grid_size = const index_t grid_size =
...@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{ {
return false; return false;
} }
// TODO ANT: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0); const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1); const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{ {
return false; return false;
...@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto NRaw = arg.raw_lengths_m_n_k_o_[1]; const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement // Check scalar per vector requirement
const auto a_extent_lowest = const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw; const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b_extent_lowest = const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw; const auto c_extent_lowest = Gemm1NzRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
...@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false; return false;
} }
// Check vector store requirement; assumes last dimension in N to be contiguous // Check vector load/store requirement
if(arg.c_stride_lowest_ != 1) const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b_stride_lowest =
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{ {
return false; return false;
} }
...@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(
const BDataType* p_b, const ADataType* p_a,
const B1DataType* p_b1, const BDataType* p_b,
CDataType* p_c, const B1DataType* p_b1,
index_t MRaw, CDataType* p_c,
index_t NRaw, const std::array<void*, NumAcc0Bias> p_acc0_biases,
index_t KRaw, const std::array<void*, NumAcc1Bias> p_acc1_biases,
index_t Gemm1NRaw, const std::vector<index_t>& a_gs_ms_ks_lengths,
index_t Batch, const std::vector<index_t>& a_gs_ms_ks_strides,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& b_gs_ns_ks_lengths,
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& b_gs_ns_ks_strides,
index_t StrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
index_t StrideB, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
index_t StrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
index_t BatchStrideA, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t BatchStrideB, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
index_t BatchStrideB1, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
BElementwiseOperation b_element_op, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
AccElementwiseOperation acc_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
B1ElementwiseOperation b1_element_op, acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
CElementwiseOperation c_element_op) AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_b1, p_b1,
p_c, p_c,
MRaw, p_acc0_biases,
NRaw, p_acc1_biases,
KRaw, a_gs_ms_ks_lengths,
Gemm1NRaw, a_gs_ms_ks_strides,
Batch, b_gs_ns_ks_lengths,
c_gs_ms_gemm1ns_lengths, b_gs_ns_ks_strides,
c_gs_ms_gemm1ns_strides, b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
StrideA, b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
StrideB, c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
StrideB1, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
BatchStrideA, acc0_biases_gs_ms_ns_lengths,
BatchStrideB, acc0_biases_gs_ms_ns_strides,
BatchStrideB1, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// polymorphic // polymorphic
// FIXME: constness // FIXME: constness
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument> MakeArgumentPointer(
MakeArgumentPointer(const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
const void* p_b1, const void* p_b1,
void* p_c, void* p_c,
index_t MRaw, const std::array<void*, NumAcc0Bias> p_acc0_biases,
index_t NRaw, const std::array<void*, NumAcc1Bias> p_acc1_biases,
index_t KRaw, const std::vector<index_t>& a_gs_ms_ks_lengths,
index_t Gemm1NRaw, const std::vector<index_t>& a_gs_ms_ks_strides,
index_t Batch, const std::vector<index_t>& b_gs_ns_ks_lengths,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& b_gs_ns_ks_strides,
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
index_t StrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
index_t StrideB, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
index_t StrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t BatchStrideA, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
index_t BatchStrideB, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
index_t BatchStrideB1, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
AElementwiseOperation a_element_op, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
BElementwiseOperation b_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
AccElementwiseOperation acc_element_op, acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
B1ElementwiseOperation b1_element_op, AElementwiseOperation a_element_op,
CElementwiseOperation c_element_op) override BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1), static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
MRaw, p_acc0_biases, // cast in struct Argument
NRaw, p_acc1_biases, // cast in struct Argument
KRaw, a_gs_ms_ks_lengths,
Gemm1NRaw, a_gs_ms_ks_strides,
Batch, b_gs_ns_ks_lengths,
c_gs_ms_gemm1ns_lengths, b_gs_ns_ks_strides,
c_gs_ms_gemm1ns_strides, b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
StrideA, b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
StrideB, c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
StrideB1, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
BatchStrideA, acc0_biases_gs_ms_ns_lengths,
BatchStrideB, acc0_biases_gs_ms_ns_strides,
BatchStrideB1, acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">"; << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
...@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskOutUpperTriangle>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
...@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim,
bool UseMultiblockInK,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl
: public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
const std::array<index_t, Rank>& xyStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleXYLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
const auto tupleXYStrides =
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
Number<NumBatchNormReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(raw_grid_desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto grid_desc_m_g =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_g_padded =
transform_tensor_descriptor(grid_desc_m_g,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_g_padded);
};
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad =
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
static auto
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
const std::array<index_t, NumInvariantDim>& strides)
{
const auto tupleLengths =
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
const auto tupleStrides =
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
auto grid_desc_m = transform_tensor_descriptor(
raw_grid_desc,
make_tuple(make_merge_transform(tupleLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded =
transform_tensor_descriptor(grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, mPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (grid_desc_m_padded);
};
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* p_scale,
const BiasDataType* p_bias,
const YElementwiseOp y_elementwise_op,
double epsilon,
YDataType* p_y,
MeanVarDataType* resultSaveMean,
MeanVarDataType* resultSaveInvVariance,
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_scale_(p_scale),
p_bias_(p_bias),
y_elementwise_op_(y_elementwise_op),
p_y_(p_y),
resultSaveMean_(resultSaveMean),
resultSaveInvVariance_(resultSaveInvVariance),
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance)
{
xyLengths_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
xStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
yStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);
std::tie(invariant_length_, reduce_length_) =
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
updateMovingAverage_ =
(resultRunningMean != nullptr && resultRunningVariance != nullptr);
saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
if(UseMultiblockInK)
{
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128
if(testBlkGroupSize <= 128)
break;
iterations++;
};
blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration_ = iterations;
}
else
{
blkGroupSize_ = 1;
numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
x_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
scale_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
bias_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
mean_var_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
}
AccDataType epsilon_;
AccDataType averageFactor_;
bool updateMovingAverage_;
bool saveMeanInvVariance_;
std::array<index_t, Rank> xyLengths_;
std::array<index_t, Rank> xStrides_;
std::array<index_t, Rank> yStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
const XDataType* p_x_;
const ScaleDataType* p_scale_;
const BiasDataType* p_bias_;
const YElementwiseOp y_elementwise_op_;
YDataType* p_y_;
MeanVarDataType* resultSaveMean_;
MeanVarDataType* resultSaveInvVariance_;
MeanVarDataType* resultRunningMean_;
MeanVarDataType* resultRunningVariance_;
long_index_t invariant_length_;
long_index_t reduce_length_;
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
XYGridDesc_M_K x_grid_desc_m_k_;
XYGridDesc_M_K y_grid_desc_m_k_;
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
void* workspace_mean_;
void* workspace_variance_;
void* workspace_count_;
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// workspace for welford intermediate mean
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
}
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// setup buffer used for intermediate welford mean
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford varirance
pArg_->workspace_variance_ =
reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
index_t variance_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welfor count
pArg_->workspace_count_ =
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
};
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0;
if(UseMultiblockInK && arg.blkGroupSize_ > 1)
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
const auto mean_var_count_grid_desc_m_g =
DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
const auto mean_var_count_grid_desc_m_k =
DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize>;
using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
index_t numMeanVarCountBlockTileIteration =
(arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor>;
const auto kern_welford_second_half_batchnorm_forward_final =
kernel_welford_second_half_batchnorm_forward_final<
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M>;
avg_time +=
launch_and_time_kernel(stream_config,
kern_multiblock_welford_first_half,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_));
avg_time +=
launch_and_time_kernel(stream_config,
kern_welford_second_half_batchnorm_forward_final,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_k,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
arg.blkGroupSize_,
arg.numBlockTileIteration_,
numMeanVarCountBlockTileIteration,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_,
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_,
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
}
else
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.numBlockTileIteration_, arg.reduce_length_);
using GridwiseBatchNormForwardWithBlockwiseWelford_ =
GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
GridwiseBatchNormForwardWithBlockwiseWelford_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(stream_config,
kern_batchnorm_fwd,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_, // true or false
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_, // true or false
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
};
return (avg_time);
};
float Run(const BaseArgument* pArg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* pArg) override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
if constexpr(XSrcYDstVectorDim == 0)
{
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
pArg_->yStrides_[NumInvariantDim - 1] != 1)
return false;
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
return false;
}
else
{
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
return false;
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
return false;
};
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
return false;
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
return false;
bool is_valid = true;
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
is_valid = false;
});
if(!is_valid)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double averageFactor,
void* resultRunningMean,
void* resultRunningVariance) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
yStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const BiasDataType*>(p_bias),
y_elementwise_op,
epsilon,
static_cast<YDataType*>(p_y),
static_cast<MeanVarDataType*>(resultSaveMean),
static_cast<MeanVarDataType*>(resultSaveInvVariance),
averageFactor,
static_cast<MeanVarDataType*>(resultRunningMean),
static_cast<MeanVarDataType*>(resultRunningVariance));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -5,9 +5,8 @@ ...@@ -5,9 +5,8 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <array>
#include "ck/utility/common_header.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
...@@ -41,7 +40,8 @@ template <typename InDataType, ...@@ -41,7 +40,8 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation> struct DeviceReduceMultiBlock
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
...@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t NumSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
...@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::vector<index_t>& inStrides, const std::array<index_t, Rank>& inStrides,
int blkGroupSize, int blkGroupSize,
int numBlockTileIteration) int numBlockTileIteration)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths =
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type; using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths = generate_tuple(
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
...@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
...@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{}); const auto length = out_grid_desc_m.GetLength(Number<0>{});
...@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -272,10 +278,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -272,10 +278,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
} }
std::vector<index_t> inLengths_; std::array<index_t, Rank> inLengths_;
std::vector<index_t> inStrides_; std::array<index_t, Rank> inStrides_;
std::vector<index_t> outLengths_; std::array<index_t, NumDstDim> outLengths_;
std::vector<index_t> outStrides_; std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
...@@ -459,11 +465,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -459,11 +465,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -34,7 +35,8 @@ template <typename InDataType, ...@@ -34,7 +35,8 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation> struct DeviceReduceThreadWise
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
...@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t NumSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::vector<index_t>& inStrides) const std::array<index_t, Rank>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths =
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type; using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths = generate_tuple(
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
...@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
...@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
M_BlockTileSize; M_BlockTileSize;
} }
std::vector<index_t> inLengths_; std::array<index_t, Rank> inLengths_;
std::vector<index_t> inStrides_; std::array<index_t, Rank> inStrides_;
std::vector<index_t> outLengths_; std::array<index_t, NumDstDim> outLengths_;
std::vector<index_t> outStrides_; std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
...@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
......
...@@ -8,12 +8,9 @@ ...@@ -8,12 +8,9 @@
#include "ck/utility/reduction_operator.hpp" #include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -50,29 +47,80 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -50,29 +47,80 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
// Used for freeloading of some handy functions from DeviceReduceMultiBlock static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using Reduction = DeviceReduceMultiBlock<InDataType,
AccDataType, static constexpr index_t NumSrcDim = Rank;
OutDataType, static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
Rank, static constexpr bool reduceAllDim = (NumInvariantDim == 0);
NumReduceDim,
reduce::Add, static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
InElementwiseOp, static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
AccElementwiseOp,
InMemoryDataOperationEnum::Set, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
false, // PropagateNan const std::vector<index_t>& inStrides,
false, // OutputIndex int blkGroupSize,
false, // HaveIndexInputIfOutputIndex int numBlockTileIteration)
BlockSize, {
MThreadClusterSize, const auto tupleSrcLengths =
KThreadClusterSize, generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
MThreadSliceSize, const auto tupleSrcStrides =
KThreadSliceSize, generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
InSrcVectorDim,
InSrcVectorSize, const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
1>; // OutDstVectorSize
const auto in_grid_desc_m_k = [&]() {
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType, using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType, OutDataType,
...@@ -102,7 +150,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -102,7 +150,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDstVectorSize, OutDstVectorSize,
true>; true>;
struct Argument : public Reduction::Argument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
...@@ -113,42 +161,60 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -113,42 +161,60 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDataType* out_dev, OutDataType* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op) AccElementwiseOp acc_elementwise_op)
: Reduction::Argument(inLengths, : alpha_{alpha},
inStrides, beta_{beta},
{}, in_dev_{in_dev},
{}, out_dev_{out_dev},
reduceDims, in_elementwise_op_{in_elementwise_op},
0.0f, // alpha acc_elementwise_op_{acc_elementwise_op}
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
in_elementwise_op,
acc_elementwise_op),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{ {
// std::cout << "blkGroupSize= " << this->blkGroupSize inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
// << ", numBlockTileIteration= " << this->numBlockTileIteration inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length << long_index_t invariant_total_length;
// std::endl; long_index_t reduce_total_length;
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = inLengths_[NumInvariantDim - 1];
blkGroupSize = 1;
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
} }
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
InElementwiseOp in_elementwise_op_;
AccElementwiseOp acc_elementwise_op_;
index_t invariant_lowest_length_;
int blkGroupSize;
int numBlockTileIteration;
size_t gridSize;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
bool sweep_once = bool sweep_once =
...@@ -195,15 +261,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -195,15 +261,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_)) if constexpr(InSrcVectorDim == 0)
{ {
return false; if constexpr(NumInvariantDim == 0)
} {
return false;
}
else
{
if(p_arg_->inStrides_[NumInvariantDim - 1] != 1)
return false;
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0) if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0)
return false;
};
}
else
{ {
if(p_arg_->inStrides_[Rank - 1] != 1)
return false;
if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0)
return false;
};
if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0)
return false; return false;
}
return true; return true;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment