Unverified commit dc70e3e1, authored by arai713, committed by GitHub

Merge branch 'develop' into gridwise_2d

parents 10947a54 8ee36118
@@ -15,13 +15,9 @@
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
-#include "batchnorm_forward_impl.hpp"
+#include "ck/library/utility/host_common_util.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-template <typename InOutDataType, typename AccDataType>
-using ReferenceBatchNormFwdInstance =
-    ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                             AccDataType>;

 static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
                                        {"verify", required_argument, nullptr, 'v'},
@@ -41,9 +37,10 @@ class BatchNormFwdArg
    bool updateMovingAverage;
    bool saveMeanAndInvVariance;

    int data_type    = 0;
    int init_method  = 2;
    bool time_kernel = false;
+   bool use_multiblock_welford = false;

    public:
    void show_usage(const char* cmd)
@@ -68,6 +65,7 @@ class BatchNormFwdArg
                     "value, 3=decimal value)"
                  << std::endl;
        std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
+       std::cout << "Arg6: use multi-block welford (0=no, 1=yes)" << std::endl;
    };

    int processArgs(int argc, char* argv[])
@@ -110,14 +108,15 @@ class BatchNormFwdArg
            };
        };

-       if(optind + 5 > argc)
+       if(optind + 6 > argc)
            throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");

        data_type              = std::atoi(argv[optind++]);
        updateMovingAverage    = std::atoi(argv[optind++]);
        saveMeanAndInvVariance = std::atoi(argv[optind++]);
        init_method            = std::atoi(argv[optind++]);
-       time_kernel            = static_cast<bool>(std::atoi(argv[optind]));
+       time_kernel            = static_cast<bool>(std::atoi(argv[optind++]));
+       use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));

        if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
            return (-1);
@@ -128,7 +127,7 @@ class BatchNormFwdArg
using namespace ck;

-template <typename InOutDataType, typename AccDataType>
+template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_fwd_nhwc_test(bool do_verification,
                         int init_method,
                         bool time_kernel,
@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
              scaleBiasMeanVarStrides.end(),
              i_scaleBiasMeanVarStrides.begin());

-   int result = 0;
-
-   // used for saving meansquare
-   DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
-                       128);
-
-   void* p_tmp_mean = workspace.GetDeviceBuffer();
-   void* p_tmp_meansquare =
-       static_cast<char*>(p_tmp_mean) +
-       (sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
-
-   result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
-       time_kernel,
-       updateMovingAverage,
-       saveMeanAndInvVariance,
-       {0, 1, 2},
+   using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
+
+   using DeviceBatchNormFwdInstance =
+       ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
+                                                            InOutDataType,
+                                                            AccDataType,
+                                                            AccDataType,   // ScaleDataType
+                                                            AccDataType,   // BiasDataType
+                                                            AccDataType,   // MeanVarDataType
+                                                            PassThroughOp, // YElementwiseOp
+                                                            Rank,
+                                                            NumReduceDim,
+                                                            UseMultiblockInK,
+                                                            256,
+                                                            16,
+                                                            16,
+                                                            1,
+                                                            2,
+                                                            0,
+                                                            1,
+                                                            1,
+                                                            1,
+                                                            1,
+                                                            1>;
+
+   auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
+
+   auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
       i_inOutLengths,
       i_inOutStrides,
       i_inOutStrides,
+      {0, 1, 2},
       i_scaleBiasMeanVarLengths,
       i_scaleBiasMeanVarStrides,
+      i_scaleBiasMeanVarStrides,
+      i_scaleBiasMeanVarStrides,
       x_dev.GetDeviceBuffer(),
       bnScale_dev.GetDeviceBuffer(),
       bnBias_dev.GetDeviceBuffer(),
-      y_dev.GetDeviceBuffer(),
-      averageFactor,
-      updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
-      updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
       epsilon,
+      PassThroughOp{},
+      y_dev.GetDeviceBuffer(),
       saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
       saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
-      p_tmp_mean,
-      p_tmp_meansquare);
+      averageFactor,
+      updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
+      updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);

-   if(result < 0)
+   if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
+   {
+       std::cout << "The runtime parameters seem not supported by the BatchNorm device instance, "
+                    "exiting!"
+                 << std::endl;
        return (false);
+   };
+
+   size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
+
+   DeviceMem workspace_dev(workspace_sz);
+
+   batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
+
+   auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
+
+   if(time_kernel)
+   {
+       float avg_time   = 0.0f;
+       size_t num_bytes = 0;
+
+       size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
+       size_t invariant_length = inOutLengths[3];
+
+       avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+       // inputting of x, scale, bias, outputting of y
+       num_bytes +=
+           total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
+
+       // outputting of mean, inv-variance
+       num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
+
+       // updating of moving mean, variance
+       num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
+
+       float gb_per_sec = num_bytes / 1.E6 / avg_time;
+
+       std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
+   }
+   else
+       (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    bool pass = true;

    if(do_verification)
    {
-       auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
+       using ReferenceBatchNormFwdInstance =
+           ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
+                                                                                    InOutDataType,
+                                                                                    AccDataType,
+                                                                                    AccDataType,
+                                                                                    AccDataType,
+                                                                                    AccDataType,
+                                                                                    PassThroughOp>;
+
+       auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};

        auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
            i_inOutLengths,
            i_inOutStrides,
            i_inOutStrides,
+           {0, 1, 2},
            i_scaleBiasMeanVarLengths,
            i_scaleBiasMeanVarStrides,
+           i_scaleBiasMeanVarStrides,
+           i_scaleBiasMeanVarStrides,
            x.mData.data(),
            bnScale.mData.data(),
            bnBias.mData.data(),
-           y_ref.mData.data(),
-           0.1, // exponentialAverageFactor
-           updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
-           updateMovingAverage ? resultRunningVariance_ref.mData.data()
-                               : nullptr, // resultRunningVariance
            epsilon,
+           PassThroughOp{},
+           y_ref.mData.data(),
            saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
-           saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
+           saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
+           averageFactor,
+           updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
+           updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);

        if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
        {
-           std::cout
-               << "The runtime parameters seem not supported by the BatchNorm instance, exiting!"
-               << std::endl;
+           std::cout << "The runtime parameters seem not supported by the BatchNorm reference "
+                        "instance, exiting!"
+                     << std::endl;

-           return (-2);
+           return (false);
        };

        auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
        if(saveMeanAndInvVariance)
        {
+           using ck::host_common::dumpBufferToFile;
+
            Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
            Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
@@ -396,70 +464,129 @@ int main(int argc, char* argv[])
        if(arg.data_type == 0)
        {
-           pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification,
-                                                         arg.init_method,
-                                                         arg.time_kernel,
-                                                         arg.inOutLengths,
-                                                         arg.updateMovingAverage,
-                                                         arg.saveMeanAndInvVariance,
-                                                         averageFactor,
-                                                         epsilon);
+           if(arg.use_multiblock_welford)
+               pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
+                                                                   arg.init_method,
+                                                                   arg.time_kernel,
+                                                                   arg.inOutLengths,
+                                                                   arg.updateMovingAverage,
+                                                                   arg.saveMeanAndInvVariance,
+                                                                   averageFactor,
+                                                                   epsilon);
+           else
+               pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
+                                                                    arg.init_method,
+                                                                    arg.time_kernel,
+                                                                    arg.inOutLengths,
+                                                                    arg.updateMovingAverage,
+                                                                    arg.saveMeanAndInvVariance,
+                                                                    averageFactor,
+                                                                    epsilon);
        }
        else if(arg.data_type == 1)
        {
-           pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification,
-                                                    arg.init_method,
-                                                    arg.time_kernel,
-                                                    arg.inOutLengths,
-                                                    arg.updateMovingAverage,
-                                                    arg.saveMeanAndInvVariance,
-                                                    averageFactor,
-                                                    epsilon);
+           if(arg.use_multiblock_welford)
+               pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
+                                                              arg.init_method,
+                                                              arg.time_kernel,
+                                                              arg.inOutLengths,
+                                                              arg.updateMovingAverage,
+                                                              arg.saveMeanAndInvVariance,
+                                                              averageFactor,
+                                                              epsilon);
+           else
+               pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
+                                                               arg.init_method,
+                                                               arg.time_kernel,
+                                                               arg.inOutLengths,
+                                                               arg.updateMovingAverage,
+                                                               arg.saveMeanAndInvVariance,
+                                                               averageFactor,
+                                                               epsilon);
        }
        else if(arg.data_type == 3)
        {
-           pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification,
-                                                     arg.init_method,
-                                                     arg.time_kernel,
-                                                     arg.inOutLengths,
-                                                     arg.updateMovingAverage,
-                                                     arg.saveMeanAndInvVariance,
-                                                     averageFactor,
-                                                     epsilon);
+           if(arg.use_multiblock_welford)
+               pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
+                                                               arg.init_method,
+                                                               arg.time_kernel,
+                                                               arg.inOutLengths,
+                                                               arg.updateMovingAverage,
+                                                               arg.saveMeanAndInvVariance,
+                                                               averageFactor,
+                                                               epsilon);
+           else
+               pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
+                                                                arg.init_method,
+                                                                arg.time_kernel,
+                                                                arg.inOutLengths,
+                                                                arg.updateMovingAverage,
+                                                                arg.saveMeanAndInvVariance,
+                                                                averageFactor,
+                                                                epsilon);
        }
        else if(arg.data_type == 5)
        {
-           pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification,
-                                                          arg.init_method,
-                                                          arg.time_kernel,
-                                                          arg.inOutLengths,
-                                                          arg.updateMovingAverage,
-                                                          arg.saveMeanAndInvVariance,
-                                                          averageFactor,
-                                                          epsilon);
+           if(arg.use_multiblock_welford)
+               pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
+                                                                    arg.init_method,
+                                                                    arg.time_kernel,
+                                                                    arg.inOutLengths,
+                                                                    arg.updateMovingAverage,
+                                                                    arg.saveMeanAndInvVariance,
+                                                                    averageFactor,
+                                                                    epsilon);
+           else
+               pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
+                                                                     arg.init_method,
+                                                                     arg.time_kernel,
+                                                                     arg.inOutLengths,
+                                                                     arg.updateMovingAverage,
+                                                                     arg.saveMeanAndInvVariance,
+                                                                     averageFactor,
+                                                                     epsilon);
        }
        else if(arg.data_type == 6)
        {
-           pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification,
-                                                      arg.init_method,
-                                                      arg.time_kernel,
-                                                      arg.inOutLengths,
-                                                      arg.updateMovingAverage,
-                                                      arg.saveMeanAndInvVariance,
-                                                      averageFactor,
-                                                      epsilon);
+           if(arg.use_multiblock_welford)
+               pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
+                                                                arg.init_method,
+                                                                arg.time_kernel,
+                                                                arg.inOutLengths,
+                                                                arg.updateMovingAverage,
+                                                                arg.saveMeanAndInvVariance,
+                                                                averageFactor,
+                                                                epsilon);
+           else
+               pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
+                                                                 arg.init_method,
+                                                                 arg.time_kernel,
+                                                                 arg.inOutLengths,
+                                                                 arg.updateMovingAverage,
+                                                                 arg.saveMeanAndInvVariance,
+                                                                 averageFactor,
+                                                                 epsilon);
        }
    }
    else
    {
-       pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true,
-                                                     2,
-                                                     false, // don't time kernel
-                                                     {128, 16, 16, 1024},
-                                                     true,
-                                                     false,
-                                                     averageFactor,
-                                                     epsilon);
+       pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
+                                                           2,
+                                                           false, // don't time kernel
+                                                           {128, 16, 6, 512},
+                                                           true,
+                                                           true,
+                                                           averageFactor,
+                                                           epsilon);
+
+       pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
+                                                                    2,
+                                                                    false, // don't time kernel
+                                                                    {128, 16, 3, 1024},
+                                                                    true,
+                                                                    true,
+                                                                    averageFactor,
+                                                                    epsilon);
    };

    return (pass ? 0 : 1);
...
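Editorial note on the new use_multiblock_welford switch: the UseMultiblockInK template parameter selects a variant in which the per-channel mean/variance reduction is split across several workgroups and the partial results are merged afterwards; both paths rest on Welford's single-pass algorithm. A minimal host-side sketch of the statistics the forward pass produces, illustrative only and not part of this commit (names such as welford_mean_var are placeholders):

// Welford's single-pass mean/variance, as used conceptually by the batchnorm forward pass.
#include <cmath>
#include <cstddef>
#include <vector>

static void welford_mean_var(const std::vector<float>& x, float& mean, float& variance)
{
    float m = 0.f, m2 = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        float delta = x[i] - m;
        m += delta / static_cast<float>(i + 1); // running mean
        m2 += delta * (x[i] - m);               // running sum of squared deviations
    }
    mean     = m;
    variance = m2 / static_cast<float>(x.size()); // biased variance, as batchnorm uses
}

// The saved/running outputs of the example then follow the usual definitions:
//   resultSaveInvVariance = 1 / sqrt(variance + epsilon)
//   runningMean           = (1 - averageFactor) * runningMean     + averageFactor * mean
//   runningVariance       = (1 - averageFactor) * runningVariance + averageFactor * variance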
@@ -14,8 +14,12 @@
 #include "batchnorm_common.hpp"

-template <typename InOutDataType,
+template <typename XDataType,
+          typename YDataType,
           typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
           ck::index_t Rank,
           ck::index_t NumBatchNormReduceDim,
           bool fastest_dim_is_reduced = false>
@@ -26,7 +30,9 @@ int bnorm_infer(
    const std::array<ck::index_t, Rank> xStrides,
    const std::array<ck::index_t, Rank> yStrides,
    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-   const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+   const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+   const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+   const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
    const void* p_x,
    const void* p_scale,
    const void* p_bias,
@@ -41,11 +47,11 @@ int bnorm_infer(
        "Invalid number of reduced dimensions for batchnorm!");

    using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
-       ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
+       ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
                                                                                      // variance,
                                                                                      // scale,
                                                                                      // bias,
-       ck::Tuple<InOutDataType>, // y
+       ck::Tuple<YDataType>, // y
        NormalizeInInfer,
        Rank,
        2, // MPerthread
@@ -53,14 +59,18 @@ int bnorm_infer(
        ck::Sequence<1>>; // scalarPerVector: y

    auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);

-   std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
+   std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
+   std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
+   std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};

    int i = 0;

    for(auto dim : invariantDims)
    {
        assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);

-       aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
+       aligned_bnScaleStrides[dim]   = bnScaleStrides[i];
+       aligned_bnBiasStrides[dim]    = bnBiasStrides[i];
+       aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
        i++;
    };
@@ -84,10 +94,10 @@ int bnorm_infer(
    auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
        xyLengths,
        {xStrides,
-        aligned_scaleBiasMeanVarStrides,
-        aligned_scaleBiasMeanVarStrides,
-        aligned_scaleBiasMeanVarStrides,
-        aligned_scaleBiasMeanVarStrides},
+        aligned_bnMeanVarStrides,
+        aligned_bnMeanVarStrides,
+        aligned_bnScaleStrides,
+        aligned_bnBiasStrides},
        {yStrides},
        {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
        {p_y},
@@ -105,8 +115,10 @@ int bnorm_infer(
        avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});

-       num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
-                     total_length * sizeof(InOutDataType));
+       num_bytes += total_length * sizeof(XDataType) +
+                    invariantLength *
+                        (sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
+                    total_length * sizeof(YDataType);

    if(time_kernel)
    {
...
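Editorial note: the DeviceElementwise instance above feeds five tensors {x, mean, variance, scale, bias} into the NormalizeInInfer functor. The per-element computation is the standard batchnorm inference formula; a self-contained sketch, illustrative only and not part of this commit:

// Standard batchnorm inference normalization applied per element.
#include <cmath>

template <typename T>
T normalize_in_infer(T x, T mean, T variance, T scale, T bias, T epsilon)
{
    // epsilon is a small constant (e.g. 1e-5) that keeps the denominator away from zero
    return scale * (x - mean) / std::sqrt(variance + epsilon) + bias;
}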
@@ -18,11 +18,6 @@
 #include "batchnorm_infer_impl.hpp"

-template <typename InOutDataType, typename AccDataType>
-using ReferenceBatchNormInferInstance =
-    ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                               AccDataType>;
-
 static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
                                        {"verify", required_argument, nullptr, 'v'},
                                        {"help", no_argument, nullptr, '?'},
@@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification,
    int result = 0;

-   result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
-       time_kernel,
-       {0, 1, 2},
-       i_inOutLengths,
-       i_inOutStrides,
-       i_inOutStrides,
-       i_scaleBiasMeanVarLengths,
-       i_scaleBiasMeanVarStrides,
-       x_dev.GetDeviceBuffer(),
-       bnScale_dev.GetDeviceBuffer(),
-       bnBias_dev.GetDeviceBuffer(),
-       epsilon,
-       estimatedMean_dev.GetDeviceBuffer(),
-       estimatedVariance_dev.GetDeviceBuffer(),
-       y_dev.GetDeviceBuffer());
+   result = bnorm_infer<InOutDataType,
+                        InOutDataType,
+                        AccDataType,
+                        AccDataType,
+                        AccDataType,
+                        AccDataType,
+                        Rank,
+                        NumReduceDim,
+                        false>(time_kernel,
+                               {0, 1, 2},
+                               i_inOutLengths,
+                               i_inOutStrides,
+                               i_inOutStrides,
+                               i_scaleBiasMeanVarLengths,
+                               i_scaleBiasMeanVarStrides,
+                               i_scaleBiasMeanVarStrides,
+                               i_scaleBiasMeanVarStrides,
+                               x_dev.GetDeviceBuffer(),
+                               bnScale_dev.GetDeviceBuffer(),
+                               bnBias_dev.GetDeviceBuffer(),
+                               epsilon,
+                               estimatedMean_dev.GetDeviceBuffer(),
+                               estimatedVariance_dev.GetDeviceBuffer(),
+                               y_dev.GetDeviceBuffer());

    if(result < 0)
        return (false);
@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
    if(do_verification)
    {
-       auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};
+       using ReferenceBatchNormInferInstance =
+           ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
+               InOutDataType,
+               InOutDataType,
+               AccDataType,
+               AccDataType,
+               AccDataType,
+               AccDataType>;
+
+       auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};

        auto argument_ptr_ref =
            batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
                                                   i_inOutStrides,
                                                   i_scaleBiasMeanVarLengths,
                                                   i_scaleBiasMeanVarStrides,
+                                                  i_scaleBiasMeanVarStrides,
+                                                  i_scaleBiasMeanVarStrides,
                                                   x.mData.data(),
                                                   bnScale.mData.data(),
                                                   bnBias.mData.data(),
...
@@ -159,6 +159,11 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0

+// workaround: disable broken fused attention kernel instance that does not pass validation
+// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
+// enabled
+#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
+
 namespace ck {

 enum struct InMemoryDataOperationEnum
...
@@ -14,7 +14,8 @@ namespace ck {
 template <typename TensorLengths,
           typename DimAccessOrder,
-          typename ScalarsPerAccess> // # of scalars per access in each dimension
+          typename ScalarsPerAccess,
+          bool SnakeCurved = true> // # of scalars per access in each dimension
 struct SpaceFillingCurve
 {
     static constexpr index_t nDim = TensorLengths::Size();
@@ -136,9 +137,10 @@ struct SpaceFillingCurve
        Index ordered_idx;

        static_for<0, nDim, 1>{}([&](auto idim) {
-           ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim]
-                                                   : ordered_access_lengths[idim] - 1 -
-                                                         ordered_access_idx[idim];
+           ordered_idx(idim) =
+               !SnakeCurved || forward_sweep[idim]
+                   ? ordered_access_idx[idim]
+                   : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
        });

        return container_reorder_given_old2new(ordered_idx, dim_access_order) *
...
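Editorial note: the new SnakeCurved flag decides whether every other pass over a dimension is walked backwards (boustrophedon, the previous default behaviour) or whether the index is plain row-major. A standalone sketch of the two access orders for a 2x4 tile, illustrative only and not part of this commit:

// Prints the visit order with and without snaking; with snaking, consecutive accesses
// stay spatially adjacent at the row boundary.
#include <cstdio>

int main()
{
    const int rows = 2, cols = 4;
    for(int snake = 0; snake <= 1; ++snake)
    {
        std::printf("SnakeCurved = %d:", snake);
        for(int r = 0; r < rows; ++r)
            for(int c = 0; c < cols; ++c)
            {
                // mirror the column index when sweeping backwards through an odd row
                int cc = (snake && (r % 2 == 1)) ? (cols - 1 - c) : c;
                std::printf(" (%d,%d)", r, cc);
            }
        std::printf("\n");
    }
    return 0;
}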
@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        return make_tuple(c_thread_m, c_thread_n);
    }

+   template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+   __device__ static auto
+   CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+   {
+       const auto wave_idx = GetWaveIdx();
+
+       const auto waveId_m = wave_idx[I0];
+       const auto waveId_n = wave_idx[I1];
+
+       const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
+
+       return make_tuple(Number<m0>{},
+                         Number<n0>{},
+                         waveId_m,
+                         waveId_n,
+                         blk_idx[I0],
+                         blk_idx[I1],
+                         blk_idx[I2],
+                         blk_idx[I3]);
+   }
+
    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
        return make_tuple(c_thread_m, c_thread_n);
    }

+   template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+   __device__ static auto
+   CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+   {
+       const auto wave_idx = GetWaveIdx();
+
+       const auto waveId_m = wave_idx[I0];
+       const auto waveId_n = wave_idx[I1];
+
+       const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
+
+       return make_tuple(
+           m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
+   }
+
    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
...
@@ -24,7 +24,8 @@ template <typename ALayout,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
-         typename CElementwiseOperation>
+         typename CElementwiseOperation,
+         bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
...
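Editorial note: MaskOutUpperTriangle refers to the intermediate scores S = acc0_op(A * B0) of the fused GEMM-softmax-GEMM. With the mask enabled, entries above the diagonal (n > m) are forced to negative infinity before the softmax, which is the usual causal attention mask. A host-side sketch, illustrative only and not part of this commit:

// Masks the upper triangle of a score matrix in place; masked scores contribute zero
// probability after the row-wise softmax.
#include <cstddef>
#include <limits>
#include <vector>

void mask_upper_triangle(std::vector<std::vector<float>>& s)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for(std::size_t m = 0; m < s.size(); ++m)
        for(std::size_t n = m + 1; n < s[m].size(); ++n)
            s[m][n] = neg_inf;
}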
@@ -7,49 +7,60 @@
 #include <vector>

 #include "device_base.hpp"
+#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename Acc0ElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
 struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
 {
-   virtual std::unique_ptr<BaseArgument>
-   MakeArgumentPointer(const void* p_a,
-                       const void* p_b0,
-                       const void* p_b1,
-                       void* p_c,
-                       ck::index_t M,
-                       ck::index_t N,
-                       ck::index_t K,
-                       ck::index_t O,
-                       ck::index_t Batch,
-                       std::vector<index_t> c_gs_ms_os_lengths,
-                       std::vector<index_t> c_gs_ms_os_strides,
-                       ck::index_t StrideA,
-                       ck::index_t StrideB0,
-                       ck::index_t StrideB1,
-                       ck::index_t BatchStrideA,
-                       ck::index_t BatchStrideB0,
-                       ck::index_t BatchStrideB1,
-                       AElementwiseOperation a_element_op,
-                       B0ElementwiseOperation b0_element_op,
-                       Acc0ElementwiseOperation acc0_element_op,
-                       B1ElementwiseOperation b1_element_op,
-                       CElementwiseOperation c_element_op) = 0;
+   static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
+   static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+
+   virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
+       const void* p_a,
+       const void* p_b0,
+       const void* p_b1,
+       void* p_c,
+       const std::array<void*, NumAcc0Bias> p_acc0_biases,
+       const std::array<void*, NumAcc1Bias> p_acc1_biases,
+       const std::vector<index_t>& a_gs_ms_ks_lengths,
+       const std::vector<index_t>& a_gs_ms_ks_strides,
+       const std::vector<index_t>& b_gs_ns_ks_lengths,
+       const std::vector<index_t>& b_gs_ns_ks_strides,
+       const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+       const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+       const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
+       const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
+       const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
+       const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
+       const std::array<std::vector<index_t>, NumAcc1Bias>
+           acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
+       const std::array<std::vector<index_t>, NumAcc1Bias>
+           acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+       AElementwiseOperation a_element_op,
+       B0ElementwiseOperation b0_element_op,
+       Acc0ElementwiseOperation acc0_element_op,
+       B1ElementwiseOperation b1_element_op,
+       CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
...
@@ -13,31 +13,36 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <index_t Rank, index_t NumBatchNormReduceDim>
+template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
 struct DeviceBatchNormFwd : public BaseOperator
 {
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
+       const std::array<int, NumBatchNormReduceDim> reduceDims,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-       const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+       const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+       const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+       const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
        const void* p_x,
        const void* bnScale,
        const void* bnBias,
+       double epsilon,
+       const YElementwiseOp y_elementwise_op,
        void* p_y,
+       void* resultSaveMean,
+       void* resultSaveInvVariance,
        double exponentialAverageFactor,
        void* resultRunningMean,
-       void* resultRunningVariance,
-       double epsilon,
-       void* resultSaveMean,
-       void* resultSaveInvVariance) = 0;
+       void* resultRunningVariance) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <index_t Rank, index_t NumBatchNormReduceDim>
-using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;
+template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
+using DeviceBatchNormFwdPtr =
+    std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;

 } // namespace device
 } // namespace tensor_operation
...
@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
-       const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
+       const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
+       const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
+       const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
        const void* p_x,
        const void* bnScale,
        const void* bnBias,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Convolution Forward:
// input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedConvFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, // input image
const void* p_b, // weight
void* p_c, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
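
Editorial note: the DeviceGroupedConvFwd interface above takes spatial lengths, filter strides, dilations and left/right pads separately; the output spatial lengths follow the standard convolution shape formula. A self-contained sketch, illustrative only and not part of this commit (names such as conv_out_lengths are placeholders, not CK identifiers):

// Wo = (Wi + left_pad + right_pad - dilation * (X - 1) - 1) / stride + 1, per spatial dim
#include <array>

template <int NDimSpatial>
std::array<int, NDimSpatial> conv_out_lengths(const std::array<int, NDimSpatial>& wi,
                                              const std::array<int, NDimSpatial>& x,
                                              const std::array<int, NDimSpatial>& stride,
                                              const std::array<int, NDimSpatial>& dilation,
                                              const std::array<int, NDimSpatial>& left_pad,
                                              const std::array<int, NDimSpatial>& right_pad)
{
    std::array<int, NDimSpatial> wo{};
    for(int d = 0; d < NDimSpatial; ++d)
        wo[d] = (wi[d] + left_pad[d] + right_pad[d] - dilation[d] * (x[d] - 1) - 1) / stride[d] + 1;
    return wo;
}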
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC)
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batches, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
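// Illustrative note (editorial, not part of this commit): with the batch/group offsets described
// above, one grid serves every group. Each workgroup derives its group index g_idx from its block
// id and then shifts the raw pointers before running the plain GEMM, e.g.
//
//   const ABDataType* p_a = p_a_grid + compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
//   const ABDataType* p_b = p_b_grid + compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
//   CDataType*        p_c = p_c_grid + compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
//
// which is what kernel_grouped_conv_fwd_dl below does with the __builtin_amdgcn_readfirstlane
// helpers.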
template <typename GridwiseGemm,
typename ABDataType,
typename CDataType,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_dl(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
__shared__ ABDataType p_shared[shared_block_size];
GridwiseGemm::Run(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m0_m1_k1;
ignore = b_grid_desc_k0_n0_n1_k1;
ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
compute_ptr_offset_of_batch.GetCPtrOffset(0);
#endif
}
} // namespace
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <
index_t NDimSpatial,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimSpatial,
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
const auto AK0 = K / K1;
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
const auto BK0 = K / K1;
return transform_tensor_descriptor(
wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename CLay>
static auto
MakeCGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
c_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
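
    // Illustrative note (editorial, not part of this commit): the three descriptor builders above
    // realize the forward convolution as an implicit GEMM, roughly
    //   GemmM = N * Ho * Wo (output pixels), GemmN = K (output channels), GemmK = C * Y * X,
    // with GemmK further split into (K0, K1) so that the innermost K1 elements can be loaded as
    // one vector.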
// desc for problem definition
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
void* p_c,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_c_grid_{static_cast<CDataType*>(p_c)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N<CLayout>(c_g_n_k_wos_lengths,
c_g_n_k_wos_strides)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
compute_ptr_offset_of_batch_{
a_g_n_c_wis_strides[0], b_g_k_c_xs_strides[0], c_g_n_k_wos_strides[0]},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
c_g_n_k_wos_lengths_{c_g_n_k_wos_lengths},
c_g_n_k_wos_strides_{c_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideC_ = c_g_n_k_wos_strides[0];
// populate desc for Ds/E
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_ak0_m_ak1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_bk0_n_bk1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
std::cout << "num_group: " << num_group_ << std::endl;
std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl;
std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl;
std::cout << "A[m0, m10, m11, n0, n10, n11]: " << c_grid_desc_m0_m10_m11_n0_n10_n11_
<< std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
        // tensor descriptors for problem definition
index_t num_group_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
// block-to-e-tile map
DefaultBlock2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_.GetLength(I0),
arg.c_grid_desc_m_n_.GetLength(I1)) *
arg.num_group_;
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel =
kernel_grouped_conv_fwd_dl<GridwiseGemm,
                                               ADataType, // TODO: distinguish A/B datatype
CDataType,
DeviceOp::AGridDesc_K0_M0_M1_K1,
DeviceOp::BGridDesc_K0_N0_N1_K1,
DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11,
DefaultBlock2CTileMap,
ComputePtrOffsetOfStridedBatch,
has_main_loop,
has_double_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X
<< " ConvStride = " << ConvStride << " LeftPad = " << LeftPad
<< " RightPad = " << RightPad << std::endl;
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X
<< " LeftPad = " << LeftPad << " RightPad = " << RightPad
<< std::endl;
return false;
}
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
{
return false;
}
if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
{
return false;
}
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
{
return false;
}
if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
{
return false;
}
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
{
return false;
}
}
else
{
return false;
}
// check vector access of C
if constexpr(is_same_v<CLayout, ctc::G_NW_K> || is_same_v<CLayout, ctc::G_NHW_K> ||
is_same_v<CLayout, ctc::G_NDHW_K> || is_same_v<CLayout, ctc::GNWK> ||
is_same_v<CLayout, ctc::GNHWK> || is_same_v<CLayout, ctc::GNDHWK> ||
is_same_v<CLayout, ctc::NWGK> || is_same_v<CLayout, ctc::NHWGK> ||
is_same_v<CLayout, ctc::NDHWGK>)
{
const index_t K = arg.c_g_n_k_wos_lengths_[2];
if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5))
{
return false;
}
}
else
{
return false;
}
// check Gridwise GEMM
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
}
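// Illustrative sketch, not part of this file: the A/B vector-access checks above reduce
// to three conditions on a K0_M0_M1_K1 (or K0_N0_N1_K1) source-vector-lengths descriptor.
// A standalone helper expressing the same rule, with hypothetical names:
#include <array>

bool supports_vector_access(int C,          // innermost tensor length (channel count)
                            int K1,         // K1 of the block descriptor
                            int K0PerBlock, // K0 tile size per block
                            const std::array<int, 4>& src_vector_lengths) // {K0, M0, M1, K1}
{
    // the M0/M1 (or N0/N1) positions may not be vectorized
    if(src_vector_lengths[1] != 1 || src_vector_lengths[2] != 1)
        return false;
    // the vector lengths must evenly divide the K1 and K0PerBlock tile sizes
    if(K1 % src_vector_lengths[3] != 0 || K0PerBlock % src_vector_lengths[0] != 0)
        return false;
    // and the raw channel count must be divisible by the combined K vector length
    return C % (src_vector_lengths[0] * src_vector_lengths[3]) == 0;
}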
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
void* p_c,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
return Argument{p_a,
p_b,
p_c,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_c,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
c_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
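// Illustrative sketch, not part of this file: typical host-side flow for an operator of
// this family, assuming "DeviceOpInstance" is some concrete instantiation of the struct
// above and that device buffers p_in/p_wei/p_out are already allocated. The length,
// stride, and padding arrays mirror the MakeArgument signature; the element ops are
// assumed to be the usual PassThrough functors.
//
//   DeviceOpInstance device_op;
//
//   auto argument = device_op.MakeArgument(p_in, p_wei, p_out,
//                                          a_g_n_c_wis_lengths, a_g_n_c_wis_strides,
//                                          b_g_k_c_xs_lengths,  b_g_k_c_xs_strides,
//                                          c_g_n_k_wos_lengths, c_g_n_k_wos_strides,
//                                          conv_filter_strides, conv_filter_dilations,
//                                          input_left_pads, input_right_pads,
//                                          PassThrough{}, PassThrough{}, PassThrough{});
//
//   if(!DeviceOpInstance::IsSupportedArgument(argument))
//       throw std::runtime_error("this device op instance does not support this problem");
//
//   auto invoker  = device_op.MakeInvoker();
//   float time_ms = invoker.Run(argument, StreamConfig{nullptr, true});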
...@@ -7,46 +7,50 @@ ...@@ -7,46 +7,50 @@
#include <vector> #include <vector>
#include "device_base.hpp" #include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ALayout, template <index_t NumDimG,
typename B0Layout, index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<> index_t NumDimK,
index_t NumDimO,
typename ADataType, typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{ {
struct ProblemDesc struct ProblemDesc
{ {
// Overall problem shape std::vector<index_t> a_gs_ms_ks_lengths;
index_t M; std::vector<index_t> a_gs_ms_ks_strides;
index_t N;
index_t K;
index_t O;
index_t Batch;
// Stride for A/B0/B1; layout determined by template args std::vector<index_t> b0_gs_ns_ks_lengths;
index_t StrideA; std::vector<index_t> b0_gs_ns_ks_strides;
index_t StrideB0;
index_t StrideB1; std::vector<index_t> b1_gs_os_ns_lengths;
index_t BatchStrideA; std::vector<index_t> b1_gs_os_ns_strides;
index_t BatchStrideB0;
index_t BatchStrideB1;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths; std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides; std::vector<index_t> c_gs_ms_os_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
}; };
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator ...@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
std::vector<const void*> p_b0_vec, std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -54,9 +55,8 @@ __global__ void ...@@ -54,9 +55,8 @@ __global__ void
index_t right = group_count; index_t right = group_count;
index_t group_id = index_t((left + right) / 2); index_t group_id = index_t((left + right) / 2);
while((!(block_id >= arg_ptr[group_id].block_start_ && while(
block_id < arg_ptr[group_id].block_end_)) && (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
left <= right)
{ {
if(block_id < arg_ptr[group_id].block_start_) if(block_id < arg_ptr[group_id].block_start_)
{ {
...@@ -114,14 +114,17 @@ __global__ void ...@@ -114,14 +114,17 @@ __global__ void
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
template <typename ALayout, template <index_t NumDimG,
typename BLayout, // B0Layout index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N> index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -130,6 +133,10 @@ template <typename ALayout, ...@@ -130,6 +133,10 @@ template <typename ALayout,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -170,297 +177,152 @@ template <typename ALayout, ...@@ -170,297 +177,152 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceGroupedGemmSoftmaxGemmPermute<ALayout, : public DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
BLayout, NumDimM,
B1Layout, NumDimN,
CPermuteNumDims_G_M_Gemm1N, NumDimK,
NumDimO,
ADataType, ADataType,
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskingSpec>
{ {
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
    using ProblemDesc =                                                     "Number of dimensions must be greater than 0");
typename DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
BLayout, static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
B1Layout, static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
CPermuteNumDims_G_M_Gemm1N,
ADataType, // TODO ANT: implement bias combination
BDataType, static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
B1DataType,
CDataType, #if 0
AElementwiseOperation, // TODO ANT: use alias
BElementwiseOperation, static constexpr index_t NumDimGemm0M = NumDimM;
AccElementwiseOperation, static constexpr index_t NumDimGemm0N = NumDimN;
B1ElementwiseOperation, static constexpr index_t NumDimGemm0K = NumDimK;
CElementwiseOperation>::ProblemDesc; static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>::ProblemDesc;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder = using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) ASpec,
{ BSpec,
const auto a_grid_desc_mraw_kraw = [&]() { B1Spec,
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) CSpec>;
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
make_tuple(StrideA, I1)); const std::vector<index_t>& a_gs_ms_ks_strides_vec)
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{ {
const auto b1_grid_desc_nraw_kraw = [&]() { return Transform::MakeAGridDescriptor_AK0_M_AK1(
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value) Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
{ Number<AK1>{});
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, const std::vector<index_t>& b_gs_ns_ks_strides_vec)
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB0GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N Number<BK1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB1GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0 constexpr static auto make_MaskOutPredicate()
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{ {
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {} if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{ {
return n >= NRaw_; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{ {
return IsUpperTriangle(m, n) || IsNOutOfBound(n); return MaskOutUpperTrianglePredicate{};
} }
}
private: using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
// index_t MRaw_;
index_t NRaw_;
};
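// Illustrative sketch, not part of this file: what the upper-triangle mask above means
// for the Gemm0 (softmax) scores. Positions whose column index n exceeds the row index m,
// or that lie beyond the unpadded extent NRaw, are excluded before the softmax. A
// standalone toy version of the predicate:
#include <cstdio>

// returns true when score (m, n) must be replaced by -infinity before softmax
bool is_masked(int m, int n, int n_raw)
{
    const bool upper_triangle = n > m;      // future positions (causal masking)
    const bool out_of_bound   = n >= n_raw; // padded columns
    return upper_triangle || out_of_bound;
}

int main()
{
    const int n_raw = 4;
    for(int m = 0; m < 4; ++m)
    {
        for(int n = 0; n < 6; ++n)
            std::printf("%c ", is_masked(m, n, n_raw) ? 'x' : 'o');
        std::printf("\n");
    }
    // prints a lower-triangular pattern of 'o' with everything above the
    // diagonal (and the two padded columns) marked 'x'
}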
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
index_t BatchStrideB, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
index_t BatchStrideB1, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
CGridDesc_G_M_N c_grid_desc_g_m_n) const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
BatchStrideB_(BatchStrideB), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
BatchStrideB1_(BatchStrideB1), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n) c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{ {
} }
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB1_); return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
...@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
} }
private: private:
index_t BatchStrideA_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
index_t BatchStrideB_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
index_t BatchStrideB1_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
}; };
...@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskOutUpperTriangle>; MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct GroupDeviceArg struct GroupDeviceArg
{ {
        // problem definition                                                          // lengths for the last dimensions of overall problem for sanity check of vector load/store
index_t M; std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t N;
index_t K;
index_t O;
// Strides for the last dimensions of C for sanity check of vector load/store // strides for the last dimensions of each tensor for sanity check of vector load/store
index_t c_extent_lowest_; std::vector<index_t> a_mz_kz_strides_;
index_t c_stride_lowest_; std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
// for gridwise gemm check
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
}; };
...@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
// TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size(); group_count_ = problem_desc_vec.size();
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
...@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
} }
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0; grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
...@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]); const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]); const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1( const auto& problem_desc = problem_desc_vec[i];
problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1); problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N( const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart); const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
problem_desc_vec[i].Batch; const index_t grid_size_grp =
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp; const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride // batch stride
// TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N( a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
problem_desc_vec[i].BatchStrideB0,
problem_desc_vec[i].BatchStrideB1,
c_grid_desc_g_m_n);
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N); const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
// so on
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
{
throw std::runtime_error(
"wrong! number of biases in function argument does not "
"match that in template argument");
}
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
...@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
BlockStart, BlockStart,
BlockEnd}); BlockEnd});
group_device_args_.push_back({problem_desc_vec[i].M, group_device_args_.push_back(
problem_desc_vec[i].N, {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc_vec[i].K, problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
problem_desc_vec[i].O, problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
problem_desc_vec[i].c_gs_ms_os_lengths.back(), problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
problem_desc_vec[i].c_gs_ms_os_strides.back(), {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
c_grid_desc_m_n}); problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
{problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
{problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
{problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n});
} }
} }
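// Illustrative sketch, not part of this file: how the per-group [BlockStart, BlockEnd)
// ranges built above are consumed. The host accumulates one contiguous block range per
// group; on the device, each workgroup binary-searches its group by block id. A
// standalone host-side toy of both halves, with hypothetical per-group grid sizes:
#include <cstdio>
#include <vector>

struct GroupRange
{
    int block_start_;
    int block_end_;
};

int find_group(const std::vector<GroupRange>& ranges, int block_id)
{
    int left = 0, right = static_cast<int>(ranges.size());
    int group_id = (left + right) / 2;
    while(!(block_id >= ranges[group_id].block_start_ && block_id < ranges[group_id].block_end_))
    {
        if(block_id < ranges[group_id].block_start_)
            right = group_id;
        else
            left = group_id + 1;
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    // pretend per-group grid sizes; the running sum yields each group's block range
    const std::vector<int> grid_size_per_group = {12, 4, 9};
    std::vector<GroupRange> ranges;
    int grid_size = 0;
    for(int g : grid_size_per_group)
    {
        ranges.push_back({grid_size, grid_size + g});
        grid_size += g;
    }
    std::printf("block 0  -> group %d\n", find_group(ranges, 0));  // group 0
    std::printf("block 13 -> group %d\n", find_group(ranges, 13)); // group 1
    std::printf("block 24 -> group %d\n", find_group(ranges, 24)); // group 2
}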
...@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false; return false;
} }
// TODO ANT: Check if tensor specialization & strides mismatch
bool all_has_main_k_block_loop = true; bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false; bool some_has_main_k_block_loop = false;
...@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when // Note: we need raw lengths since threadwise copy can not handle vector load when
// part of vector is out of bounds // part of vector is out of bounds
const auto MRaw = device_arg.M; const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NRaw = device_arg.N; const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KRaw = device_arg.K; const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NRaw = device_arg.O; const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement // Check scalar per vector requirement
const auto a_extent_lowest = const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw; const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b_extent_lowest = const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw; const auto c_extent_lowest = Gemm1NzRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = device_arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
...@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false; return false;
} }
// Check vector store requirement; assumes last dimension in N to be contiguous // Check vector load/store requirement
if(device_arg.c_stride_lowest_ != 1) const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
: device_arg.a_mz_kz_strides_[0];
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
? device_arg.b_nz_kz_strides_[1]
: device_arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? device_arg.b1_nz_kz_strides_[1]
: device_arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
// contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{ {
return false; return false;
} }
...@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
p_b_vec, p_b_vec,
p_b1_vec, p_b1_vec,
p_c_vec, p_c_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec, std::unique_ptr<BaseArgument>
std::vector<const void*> p_b_vec, MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b_vec,
std::vector<void*> p_c_vec, std::vector<const void*> p_b1_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<void*> p_c_vec,
AElementwiseOperation a_element_op, std::vector<std::vector<const void*>> p_acc0_biases_vec,
BElementwiseOperation b_element_op, std::vector<std::vector<const void*>> p_acc1_biases_vec,
AccElementwiseOperation acc_element_op, std::vector<ProblemDesc> problem_desc_vec,
B1ElementwiseOperation b1_element_op, AElementwiseOperation a_element_op,
CElementwiseOperation c_element_op) override BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{ {
return std::make_unique<Argument>(p_a_vec, return std::make_unique<Argument>(p_a_vec,
p_b_vec, p_b_vec,
p_b1_vec, p_b1_vec,
p_c_vec, p_c_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">"; << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -130,8 +130,11 @@ namespace device { ...@@ -130,8 +130,11 @@ namespace device {
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] // D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] // E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it // NOTE: A TensorSpecialization::Packed tensor is "packed" in the sense that the inner
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1                  // dimensions within a dimension group (e.g. [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered, not in the sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
// //
// Detail- Packed tensor satisfies // Detail- Packed tensor satisfies
// stride_0 = 1 // stride_0 = 1
...@@ -147,7 +150,7 @@ namespace device { ...@@ -147,7 +150,7 @@ namespace device {
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. // essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
// //
// Might need to expose dimension order to the interface to fully support // Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed. // TensorSpecialization::Packed in a traditional sense of "packed" tensor
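// Illustrative sketch, not part of this file: the "packed" condition discussed above, in
// isolation. Following the convention stated in the comment (stride_0 = 1, each later
// stride is the previous stride times the previous length), a tensor whose given strides
// match these derived values is packed within that dimension group.
#include <cstdio>
#include <vector>

std::vector<long> packed_strides(const std::vector<long>& lengths)
{
    // stride_0 = 1, stride_i = stride_{i-1} * length_{i-1}
    std::vector<long> strides(lengths.size(), 1);
    for(std::size_t i = 1; i < lengths.size(); ++i)
        strides[i] = strides[i - 1] * lengths[i - 1];
    return strides;
}

bool is_packed(const std::vector<long>& lengths, const std::vector<long>& strides)
{
    return strides == packed_strides(lengths);
}

int main()
{
    // e.g. a dimension group of lengths [2, 3, 4]: packed strides are [1, 2, 6]
    const std::vector<long> lengths = {2, 3, 4};
    const auto strides = packed_strides(lengths);
    std::printf("strides: %ld %ld %ld, packed: %d\n",
                strides[0], strides[1], strides[2], is_packed(lengths, strides));
}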
template <index_t NumDimG, template <index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -116,14 +117,17 @@ __global__ void ...@@ -116,14 +117,17 @@ __global__ void
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
template <typename ALayout, template <index_t NumDimG,
typename BLayout, // B0Layout index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N> index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -132,6 +136,10 @@ template <typename ALayout, ...@@ -132,6 +136,10 @@ template <typename ALayout,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -172,283 +180,135 @@ template <typename ALayout, ...@@ -172,283 +180,135 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout, : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
BLayout, NumDimM,
B1Layout, NumDimN,
CPermuteNumDims_G_M_Gemm1N, NumDimK,
NumDimO,
ADataType, ADataType,
BDataType, BDataType,
B1DataType, B1DataType,
CDataType, CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskingSpec>
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder = using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { return Transform::MakeAGridDescriptor_AK0_M_AK1(
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
{ Number<AK1>{});
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, const std::vector<index_t>& b_gs_ns_ks_strides_vec)
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB0GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N Number<BK1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] static auto
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec, MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec) const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{ {
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); return Transform::MakeB1GridDescriptor_BK0_N_BK1(
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
} }
using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));

using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
    C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}

    __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }

    __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
    {
        return n >= NRaw_;
    }

    __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
    {
        return IsUpperTriangle(m, n) || IsNOutOfBound(n);
    }

    private:
    // index_t MRaw_;
    index_t NRaw_;
};

constexpr static auto make_MaskOutPredicate()
{
    if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
    {
        return MaskDisabledPredicate{};
    }
    else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
    {
        return MaskOutUpperTrianglePredicate{};
    }
}

using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
index_t BatchStrideB, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
index_t BatchStrideB1, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
CGridDesc_G_M_N c_grid_desc_g_m_n) const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
BatchStrideB_(BatchStrideB), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
BatchStrideB1_(BatchStrideB1), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n) c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{ {
} }
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideA_); return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB_); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideB1_); return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
...@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
} }
private: private:
index_t BatchStrideA_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
index_t BatchStrideB_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
index_t BatchStrideB1_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
}; };
...@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskOutUpperTriangle>; MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
// Argument // Argument
// FIXME: constness // FIXME: constness
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const ADataType* p_a_grid, Argument(
const BDataType* p_b_grid, const ADataType* p_a_grid,
const B1DataType* p_b1_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, const B1DataType* p_b1_grid,
index_t MRaw, CDataType* p_c_grid,
index_t NRaw, const std::array<void*, NumAcc0Bias> p_acc0_biases,
index_t KRaw, const std::array<void*, NumAcc1Bias> p_acc1_biases,
index_t Gemm1NRaw, // = ORaw const std::vector<index_t>& a_gs_ms_ks_lengths,
index_t Batch, const std::vector<index_t>& a_gs_ms_ks_strides,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& b_gs_ns_ks_lengths,
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& b_gs_ns_ks_strides,
index_t StrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
index_t StrideB, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
index_t StrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
index_t BatchStrideA, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t BatchStrideB, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
index_t BatchStrideB1, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
BElementwiseOperation b_element_op, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
AccElementwiseOperation acc_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
B1ElementwiseOperation b1_element_op, acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
CElementwiseOperation c_element_op) AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b1_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
c_gs_ms_gemm1ns_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
batch_count_(Batch), c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}, a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
{ {
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_, b1_grid_desc_bk0_n_bk1_,
...@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
} }
} }
// private: void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// a_grid_desc_g_m_k_.Print();
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print();
std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print();
}
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_; const B1DataType* p_b1_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
// element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_; AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_; B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding // check C0 masking and padding
C0MatrixMask c0_matrix_mask_; C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check // For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_; std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
index_t c_extent_lowest_; std::vector<index_t> a_mz_kz_strides_;
index_t c_stride_lowest_; std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
}; };
// Invoker // Invoker
...@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!DeviceOp::IsSupportedArgument(arg))
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! unsupported argument");
} }
const index_t grid_size = const index_t grid_size =
...@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{ {
return false; return false;
} }
// TODO ANT: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0); const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1); const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{ {
return false; return false;
...@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto NRaw = arg.raw_lengths_m_n_k_o_[1]; const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement // Check scalar per vector requirement
const auto a_extent_lowest = const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw; const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b_extent_lowest = const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw; const auto c_extent_lowest = Gemm1NzRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
...@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false; return false;
} }
// Check vector store requirement; assumes last dimension in N to be contiguous // Check vector load/store requirement
if(arg.c_stride_lowest_ != 1) const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b_stride_lowest =
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{ {
return false; return false;
} }
...@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(
const BDataType* p_b, const ADataType* p_a,
const B1DataType* p_b1, const BDataType* p_b,
CDataType* p_c, const B1DataType* p_b1,
index_t MRaw, CDataType* p_c,
index_t NRaw, const std::array<void*, NumAcc0Bias> p_acc0_biases,
index_t KRaw, const std::array<void*, NumAcc1Bias> p_acc1_biases,
index_t Gemm1NRaw, const std::vector<index_t>& a_gs_ms_ks_lengths,
index_t Batch, const std::vector<index_t>& a_gs_ms_ks_strides,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& b_gs_ns_ks_lengths,
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& b_gs_ns_ks_strides,
index_t StrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
index_t StrideB, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
index_t StrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
index_t BatchStrideA, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t BatchStrideB, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
index_t BatchStrideB1, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
AElementwiseOperation a_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
BElementwiseOperation b_element_op, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
AccElementwiseOperation acc_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
B1ElementwiseOperation b1_element_op, acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
CElementwiseOperation c_element_op) AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_b1, p_b1,
p_c, p_c,
MRaw, p_acc0_biases,
NRaw, p_acc1_biases,
KRaw, a_gs_ms_ks_lengths,
Gemm1NRaw, a_gs_ms_ks_strides,
Batch, b_gs_ns_ks_lengths,
c_gs_ms_gemm1ns_lengths, b_gs_ns_ks_strides,
c_gs_ms_gemm1ns_strides, b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
StrideA, b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
StrideB, c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
StrideB1, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
BatchStrideA, acc0_biases_gs_ms_ns_lengths,
BatchStrideB, acc0_biases_gs_ms_ns_strides,
BatchStrideB1, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// polymorphic // polymorphic
// FIXME: constness // FIXME: constness
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument> MakeArgumentPointer(
MakeArgumentPointer(const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
const void* p_b1, const void* p_b1,
void* p_c, void* p_c,
index_t MRaw, const std::array<void*, NumAcc0Bias> p_acc0_biases,
index_t NRaw, const std::array<void*, NumAcc1Bias> p_acc1_biases,
index_t KRaw, const std::vector<index_t>& a_gs_ms_ks_lengths,
index_t Gemm1NRaw, const std::vector<index_t>& a_gs_ms_ks_strides,
index_t Batch, const std::vector<index_t>& b_gs_ns_ks_lengths,
std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& b_gs_ns_ks_strides,
std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
index_t StrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
index_t StrideB, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
index_t StrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
index_t BatchStrideA, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
index_t BatchStrideB, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
index_t BatchStrideB1, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
AElementwiseOperation a_element_op, acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
BElementwiseOperation b_element_op, const std::array<std::vector<ck::index_t>, NumAcc1Bias>
AccElementwiseOperation acc_element_op, acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
B1ElementwiseOperation b1_element_op, AElementwiseOperation a_element_op,
CElementwiseOperation c_element_op) override BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1), static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
MRaw, p_acc0_biases, // cast in struct Argument
NRaw, p_acc1_biases, // cast in struct Argument
KRaw, a_gs_ms_ks_lengths,
Gemm1NRaw, a_gs_ms_ks_strides,
Batch, b_gs_ns_ks_lengths,
c_gs_ms_gemm1ns_lengths, b_gs_ns_ks_strides,
c_gs_ms_gemm1ns_strides, b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
StrideA, b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
StrideB, c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
StrideB1, c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
BatchStrideA, acc0_biases_gs_ms_ns_lengths,
BatchStrideB, acc0_biases_gs_ms_ns_strides,
BatchStrideB1, acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op, acc_element_op,
...@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">"; << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
...@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskOutUpperTriangle>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
...@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim,
bool UseMultiblockInK,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl
: public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
const std::array<index_t, Rank>& xyStrides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleXYLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
const auto tupleXYStrides =
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
Number<NumBatchNormReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(raw_grid_desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
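// Worked example (illustrative, not from the original source): for an NHWC input whose
// lengths have already been shuffled so the invariant dimension comes first, i.e.
// xyLengths = {C, N, H, W}, the transform above yields a 2-D [M, K] view with M = C and
// K = N * H * W. M is then right-padded to a multiple of M_BlockTileSize and K to
// blkGroupSize * K_BlockTileSize * numBlockTileIteration, so every block sees full tiles.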
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto grid_desc_m_g =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_g_padded =
transform_tensor_descriptor(grid_desc_m_g,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_g_padded);
};
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
const auto reduceLength = blkGroupSize;
const auto grid_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto kPad =
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, mPad),
make_right_pad_transform(reduceLength, kPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (grid_desc_m_k_padded);
};
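// Note: the two descriptors above view the same workspace differently. The [M, G]
// descriptor is what the multiblock first-half kernel writes: one partial mean,
// variance and count per (M tile, block group). The [M, K] descriptor is how the
// second-half kernel reads those partials back, treating the blkGroupSize partial
// results as a reduction dimension of length K = blkGroupSize, padded up to a
// multiple of KThreadClusterSize.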
static auto
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
const std::array<index_t, NumInvariantDim>& strides)
{
const auto tupleLengths =
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
const auto tupleStrides =
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
auto grid_desc_m = transform_tensor_descriptor(
raw_grid_desc,
make_tuple(make_merge_transform(tupleLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto mPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded =
transform_tensor_descriptor(grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, mPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (grid_desc_m_padded);
};
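// Example: for NHWC batchnorm the scale/bias/mean/variance tensors are 1-D of length C;
// the merged length is right-padded to a multiple of M_BlockTileSize so that it lines up
// with the padded M dimension of the X/Y descriptors above.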
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const XDataType* p_x,
const ScaleDataType* p_scale,
const BiasDataType* p_bias,
const YElementwiseOp y_elementwise_op,
double epsilon,
YDataType* p_y,
MeanVarDataType* resultSaveMean,
MeanVarDataType* resultSaveInvVariance,
double averageFactor,
MeanVarDataType* resultRunningMean,
MeanVarDataType* resultRunningVariance)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_scale_(p_scale),
p_bias_(p_bias),
y_elementwise_op_(y_elementwise_op),
p_y_(p_y),
resultSaveMean_(resultSaveMean),
resultSaveInvVariance_(resultSaveInvVariance),
resultRunningMean_(resultRunningMean),
resultRunningVariance_(resultRunningVariance)
{
xyLengths_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
xStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
yStrides_ =
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);
std::tie(invariant_length_, reduce_length_) =
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
averageFactor_ = type_convert<AccDataType>(averageFactor);
updateMovingAverage_ =
(resultRunningMean != nullptr && resultRunningVariance != nullptr);
saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr);
if(UseMultiblockInK)
{
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize to be no more than 128
if(testBlkGroupSize <= 128)
break;
iterations++;
};
blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration_ = iterations;
}
else
{
blkGroupSize_ = 1;
numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
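// Worked example (illustrative numbers only): with K_BlockTileSize = 128 and
// reduce_length_ = 131072, the loop above stops at iterations = 8, because
// ceil(131072 / (128 * 8)) = 128 <= 128. That gives blkGroupSize_ = 128 cooperating
// blocks per M tile, numBlockTileIteration_ = 8 K tiles per block, and
// gridSize_ = ceil(invariant_length_ / M_BlockTileSize) * 128.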
x_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
scale_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
bias_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
mean_var_grid_desc_m_ =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
}
AccDataType epsilon_;
AccDataType averageFactor_;
bool updateMovingAverage_;
bool saveMeanInvVariance_;
std::array<index_t, Rank> xyLengths_;
std::array<index_t, Rank> xStrides_;
std::array<index_t, Rank> yStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
const XDataType* p_x_;
const ScaleDataType* p_scale_;
const BiasDataType* p_bias_;
const YElementwiseOp y_elementwise_op_;
YDataType* p_y_;
MeanVarDataType* resultSaveMean_;
MeanVarDataType* resultSaveInvVariance_;
MeanVarDataType* resultRunningMean_;
MeanVarDataType* resultRunningVariance_;
long_index_t invariant_length_;
long_index_t reduce_length_;
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
XYGridDesc_M_K x_grid_desc_m_k_;
XYGridDesc_M_K y_grid_desc_m_k_;
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
void* workspace_mean_;
void* workspace_variance_;
void* workspace_count_;
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// workspace for welford intermediate mean
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate variance
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
// workspace for welford intermediate count
workspace_size +=
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
}
return (workspace_size);
};
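// Note: the returned size covers three buffers laid out back to back -- partial mean,
// partial variance and partial count, each holding invariant_length_ * blkGroupSize_
// elements -- plus 64 bytes of slack per buffer so that SetWorkSpacePointer below can
// round every section up to a 64-byte multiple when carving up the workspace.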
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
{
// setup buffer used for intermediate welford mean
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford variance
pArg_->workspace_variance_ =
reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
index_t variance_space_sz =
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count
pArg_->workspace_count_ =
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
};
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0;
if(UseMultiblockInK && arg.blkGroupSize_ > 1)
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
const auto mean_var_count_grid_desc_m_g =
DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
const auto mean_var_count_grid_desc_m_k =
DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
arg.invariant_length_, arg.blkGroupSize_);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using GridwiseMultiblockWelfordFirstHalf_ =
GridwiseMultiblockWelfordFirstHalf<XDataType,
AccDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize>;
using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
index_t numMeanVarCountBlockTileIteration =
(arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
const auto kern_multiblock_welford_first_half =
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
XDataType,
MeanVarDataType,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_G,
GetReduceCountPerThreadFunctor>;
const auto kern_welford_second_half_batchnorm_forward_final =
kernel_welford_second_half_batchnorm_forward_final<
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
MeanVarCountGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M>;
avg_time +=
launch_and_time_kernel(stream_config,
kern_multiblock_welford_first_half,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.p_x_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_));
avg_time +=
launch_and_time_kernel(stream_config,
kern_welford_second_half_batchnorm_forward_final,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
mean_var_count_grid_desc_m_k,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
arg.blkGroupSize_,
arg.numBlockTileIteration_,
numMeanVarCountBlockTileIteration,
arg.epsilon_,
static_cast<MeanVarDataType*>(arg.workspace_mean_),
static_cast<MeanVarDataType*>(arg.workspace_variance_),
static_cast<int32_t*>(arg.workspace_count_),
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_,
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_,
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
}
else
{
using GetReduceCountPerThreadFunctor =
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
arg.numBlockTileIteration_, arg.reduce_length_);
using GridwiseBatchNormForwardWithBlockwiseWelford_ =
GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcYDstVectorDim,
XSrcVectorSize,
YDstVectorSize,
ScaleSrcVectorSize,
BiasSrcVectorSize,
MeanVarSrcDstVectorSize>;
const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
GridwiseBatchNormForwardWithBlockwiseWelford_,
XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
YElementwiseOp,
XYGridDesc_M_K,
ScaleBiasMeanVarGridDesc_M,
ScaleBiasMeanVarGridDesc_M,
GetReduceCountPerThreadFunctor>;
avg_time += launch_and_time_kernel(stream_config,
kern_batchnorm_fwd,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.scale_grid_desc_m_,
arg.bias_grid_desc_m_,
arg.mean_var_grid_desc_m_,
get_reduce_count_per_thread,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.p_x_,
arg.p_scale_,
arg.p_bias_,
arg.y_elementwise_op_,
arg.p_y_,
arg.updateMovingAverage_, // true or false
arg.averageFactor_,
arg.resultRunningMean_,
arg.resultRunningVariance_,
arg.saveMeanInvVariance_, // true or false
arg.resultSaveMean_,
arg.resultSaveInvVariance_);
};
return (avg_time);
};
float Run(const BaseArgument* pArg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* pArg) override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
if constexpr(XSrcYDstVectorDim == 0)
{
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
pArg_->yStrides_[NumInvariantDim - 1] != 1)
return false;
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
return false;
}
else
{
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
return false;
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
return false;
};
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
return false;
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
return false;
bool is_valid = true;
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
is_valid = false;
});
if(!is_valid)
return false;
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double averageFactor,
void* resultRunningMean,
void* resultRunningVariance) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
yStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const BiasDataType*>(p_bias),
y_elementwise_op,
epsilon,
static_cast<YDataType*>(p_y),
static_cast<MeanVarDataType*>(resultSaveMean),
static_cast<MeanVarDataType*>(resultSaveInvVariance),
averageFactor,
static_cast<MeanVarDataType*>(resultRunningMean),
static_cast<MeanVarDataType*>(resultRunningVariance));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
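// ---------------------------------------------------------------------------------------
// Illustrative host-side usage sketch (an addition for clarity, not part of the original
// header). It shows one way DeviceBatchNormFwdImpl could be instantiated for an NHWC
// tensor with the channel dimension kept invariant. All tuning values, tensor sizes and
// pointer names below are assumptions chosen only to satisfy the static_asserts and the
// IsSupportedArgument() checks above; they are not a recommended configuration.

#include <array>
#include <cstddef>
#include <hip/hip_runtime.h>

#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// float in/out and accumulation, rank-4 NHWC input, reduce over N/H/W, multi-block
// Welford enabled; 256 threads arranged as 16x16 clusters with 2-element slices.
using DeviceBnormFwdNHWC = ck::tensor_operation::device::DeviceBatchNormFwdImpl<
    float, float, float, float, float, float, // X / Y / Acc / Scale / Bias / MeanVar
    PassThrough,                               // Y elementwise op
    4,                                         // Rank
    3,                                         // NumBatchNormReduceDim
    true,                                      // UseMultiblockInK
    256, 16, 16, 2, 2,                         // BlockSize, M/K cluster, M/K slice
    0, 2, 2, 2, 2, 2>;                         // vector dim and vector sizes

// p_x / p_scale / p_bias / p_y are assumed to be device buffers allocated by the caller.
inline float run_batchnorm_fwd_nhwc(const float* p_x, const float* p_scale,
                                    const float* p_bias, float* p_y)
{
    using ck::index_t;

    const std::array<index_t, 4> lengths = {128, 16, 16, 64}; // N, H, W, C
    const std::array<index_t, 4> strides = {16 * 16 * 64, 16 * 64, 64, 1};
    const std::array<int, 3> reduceDims  = {0, 1, 2};
    const std::array<index_t, 1> scaleBiasMeanVarLengths = {64}; // C
    const std::array<index_t, 1> scaleBiasMeanVarStrides = {1};

    auto bnorm   = DeviceBnormFwdNHWC{};
    auto arg_ptr = bnorm.MakeArgumentPointer(lengths, strides, strides, reduceDims,
                                             scaleBiasMeanVarLengths,
                                             scaleBiasMeanVarStrides,
                                             scaleBiasMeanVarStrides,
                                             scaleBiasMeanVarStrides,
                                             p_x, p_scale, p_bias,
                                             1e-5,             // epsilon
                                             PassThrough{}, p_y,
                                             nullptr, nullptr, // no saved mean/inv-variance
                                             0.1,              // averageFactor
                                             nullptr, nullptr); // no running mean/variance

    if(!bnorm.IsSupportedArgument(arg_ptr.get()))
        return -1.0f;

    // multi-block Welford needs a workspace for the partial mean/variance/count buffers
    void* p_workspace         = nullptr;
    const std::size_t ws_size = bnorm.GetWorkSpaceSize(arg_ptr.get());
    if(ws_size > 0)
        (void)hipMalloc(&p_workspace, ws_size);
    bnorm.SetWorkSpacePointer(arg_ptr.get(), p_workspace);

    auto invoker_ptr = bnorm.MakeInvokerPointer();
    const float time = invoker_ptr->Run(arg_ptr.get()); // default StreamConfig

    (void)hipFree(p_workspace);
    return time;
}
// ---------------------------------------------------------------------------------------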
...@@ -214,6 +214,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -214,6 +214,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
K1,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
enum struct MaskingSpecialization
{
MaskDisabled,
MaskOutUpperTriangle
};
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
switch(s)
{
case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
default: return "Unrecognized specialization!";
}
}
struct MaskDisabledPredicate
{
__host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
{
return false;
};
__host__ __device__ constexpr bool
IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
{
return false;
}
};
struct MaskOutUpperTrianglePredicate
{
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
{
return operator()(m + m_tile - 1, n);
}
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
{
return predicate_(m, n) || IsNOutOfBound(n);
}
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
{
return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
}
private:
// index_t MRaw_;
index_t NRaw_;
MaskOutPredicate predicate_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
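// ---------------------------------------------------------------------------------------
// Minimal standalone sketch (an illustration, not part of the original header): it mirrors
// the predicate + C0MatrixMask_impl logic above in plain C++ to show which C0 elements a
// causal (upper-triangle) mask removes, including columns beyond the unpadded NRaw extent.

#include <cstdio>

namespace demo {

struct MaskOutUpperTrianglePredicate
{
    constexpr bool operator()(int m, int n) const { return n > m; }
};

template <typename MaskOutPredicate>
struct C0MatrixMask
{
    explicit C0MatrixMask(int NRaw) : NRaw_(NRaw) {}

    bool IsMaskedElement(int m, int n) const { return predicate_(m, n) || n >= NRaw_; }

    private:
    int NRaw_;
    MaskOutPredicate predicate_;
};

} // namespace demo

int main()
{
    // 4 rows of softmax logits, 6 columns, but only NRaw = 5 columns hold real data
    const demo::C0MatrixMask<demo::MaskOutUpperTrianglePredicate> mask(5);

    for(int m = 0; m < 4; ++m)
    {
        for(int n = 0; n < 6; ++n)
            std::printf("%c ", mask.IsMaskedElement(m, n) ? 'x' : 'o'); // 'x' -> set to -inf
        std::printf("\n");
    }
    // prints a lower-triangular 'o' pattern; column 5 is always 'x' (out of bounds)
    return 0;
}
// ---------------------------------------------------------------------------------------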