Commit 4fec5ad3 authored by aska-0096

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

parents 24faa1fc 87fd1152
...@@ -15,13 +15,9 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "batchnorm_forward_impl.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'}, {"verify", required_argument, nullptr, 'v'},
...@@ -44,6 +40,7 @@ class BatchNormFwdArg
int data_type = 0; int data_type = 0;
int init_method = 2; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
bool use_multiblock_welford = false;
public: public:
void show_usage(const char* cmd) void show_usage(const char* cmd)
...@@ -68,6 +65,7 @@ class BatchNormFwdArg
"value, 3=decimal value)" "value, 3=decimal value)"
<< std::endl; << std::endl;
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
}; };
int processArgs(int argc, char* argv[]) int processArgs(int argc, char* argv[])
...@@ -110,14 +108,15 @@ class BatchNormFwdArg
}; };
}; };
if(optind + 5 > argc) if(optind + 6 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!"); throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");
data_type = std::atoi(argv[optind++]); data_type = std::atoi(argv[optind++]);
updateMovingAverage = std::atoi(argv[optind++]); updateMovingAverage = std::atoi(argv[optind++]);
saveMeanAndInvVariance = std::atoi(argv[optind++]); saveMeanAndInvVariance = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]); init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind])); time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
return (-1); return (-1);
...@@ -128,7 +127,7 @@ class BatchNormFwdArg
using namespace ck; using namespace ck;
template <typename InOutDataType, typename AccDataType> template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_fwd_nhwc_test(bool do_verification, bool bnorm_fwd_nhwc_test(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
...@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides.end(), scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin()); i_scaleBiasMeanVarStrides.begin());
int result = 0; using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
// used for saving meansquare using DeviceBatchNormFwdInstance =
DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() + ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
128); InOutDataType,
AccDataType,
AccDataType, // ScaleDataType
AccDataType, // BiasDataType
AccDataType, // MeanVarDataType
PassThroughOp, // YElementwiseOp
Rank,
NumReduceDim,
UseMultiblockInK,
256,
16,
16,
1,
2,
0,
1,
1,
1,
1,
1>;
void* p_tmp_mean = workspace.GetDeviceBuffer(); auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
void* p_tmp_meansquare =
static_cast<char*>(p_tmp_mean) +
(sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>( auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
time_kernel,
updateMovingAverage,
saveMeanAndInvVariance,
{0, 1, 2},
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
epsilon, epsilon,
PassThroughOp{},
y_dev.GetDeviceBuffer(),
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
p_tmp_mean, averageFactor,
p_tmp_meansquare); updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
if(result < 0) if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
"exiting!"
<< std::endl;
return (false); return (false);
};
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
if(time_kernel)
{
float avg_time = 0.0f;
size_t num_bytes = 0;
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
size_t invariant_length = inOutLengths[3];
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// input of x, scale and bias; output of y
num_bytes +=
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
// output of mean and inv-variance
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
// updating of moving mean, variance
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
else
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
y_ref.mData.data(),
0.1, // exponentialAverageFactor
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
updateMovingAverage ? resultRunningVariance_ref.mData.data()
: nullptr, // resultRunningVariance
epsilon, epsilon,
PassThroughOp{},
y_ref.mData.data(),
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr); saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
averageFactor,
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{ {
std::cout std::cout << "The runtime parameters seem not to be supported by the BatchNorm reference "
<< "The runtime parameters seems not supported by the BatchNorm instance, exiting!" "instance, exiting!"
<< std::endl; << std::endl;
return (-2); return (false);
}; };
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
...@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
if(saveMeanAndInvVariance) if(saveMeanAndInvVariance)
{ {
using ck::host_common::dumpBufferToFile;
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths); Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths); Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
...@@ -396,7 +464,17 @@ int main(int argc, char* argv[])
if(arg.data_type == 0) if(arg.data_type == 0)
{ {
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
arg.init_method, arg.init_method,
arg.time_kernel, arg.time_kernel,
arg.inOutLengths, arg.inOutLengths,
...@@ -407,7 +485,17 @@ int main(int argc, char* argv[])
} }
else if(arg.data_type == 1) else if(arg.data_type == 1)
{ {
pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification, if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
arg.init_method, arg.init_method,
arg.time_kernel, arg.time_kernel,
arg.inOutLengths, arg.inOutLengths,
...@@ -418,7 +506,17 @@ int main(int argc, char* argv[])
} }
else if(arg.data_type == 3) else if(arg.data_type == 3)
{ {
pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
arg.init_method, arg.init_method,
arg.time_kernel, arg.time_kernel,
arg.inOutLengths, arg.inOutLengths,
...@@ -429,7 +527,17 @@ int main(int argc, char* argv[])
} }
else if(arg.data_type == 5) else if(arg.data_type == 5)
{ {
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
arg.init_method, arg.init_method,
arg.time_kernel, arg.time_kernel,
arg.inOutLengths, arg.inOutLengths,
...@@ -440,7 +548,17 @@ int main(int argc, char* argv[])
} }
else if(arg.data_type == 6) else if(arg.data_type == 6)
{ {
pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification, if(arg.use_multiblock_welford)
pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
arg.init_method, arg.init_method,
arg.time_kernel, arg.time_kernel,
arg.inOutLengths, arg.inOutLengths,
...@@ -452,12 +570,21 @@ int main(int argc, char* argv[])
} }
else else
{ {
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true, pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
2, 2,
false, // don't time kernel false, // don't time kernel
{128, 16, 16, 1024}, {128, 16, 6, 512},
true,
true,
averageFactor,
epsilon);
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
2,
false, // don't time kernel
{128, 16, 3, 1024},
true,
true, true,
false,
averageFactor, averageFactor,
epsilon); epsilon);
}; };
......
...@@ -14,8 +14,12 @@
#include "batchnorm_common.hpp" #include "batchnorm_common.hpp"
template <typename InOutDataType, template <typename XDataType,
typename YDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
ck::index_t Rank, ck::index_t Rank,
ck::index_t NumBatchNormReduceDim, ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false> bool fastest_dim_is_reduced = false>
...@@ -26,7 +30,9 @@ int bnorm_infer(
const std::array<ck::index_t, Rank> xStrides, const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides, const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides, const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* p_scale, const void* p_scale,
const void* p_bias, const void* p_bias,
...@@ -41,11 +47,11 @@ int bnorm_infer(
"Invalid number of reduced dimensions for batchnorm!"); "Invalid number of reduced dimensions for batchnorm!");
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean, ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// variance, // variance,
// scale, // scale,
// bias, // bias,
ck::Tuple<InOutDataType>, // y ck::Tuple<YDataType>, // y
NormalizeInInfer, NormalizeInInfer,
Rank, Rank,
2, // MPerthread 2, // MPerthread
...@@ -53,14 +59,18 @@ int bnorm_infer(
ck::Sequence<1>>; // scalarPerVector: y ck::Sequence<1>>; // scalarPerVector: y
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims); auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0}; std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};
int i = 0; int i = 0;
for(auto dim : invariantDims) for(auto dim : invariantDims)
{ {
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]); assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i]; aligned_bnScaleStrides[dim] = bnScaleStrides[i];
aligned_bnBiasStrides[dim] = bnBiasStrides[i];
aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
i++; i++;
}; };
...@@ -84,10 +94,10 @@ int bnorm_infer(
auto argument_ptr1 = dev_normalize.MakeArgumentPointer( auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
xyLengths, xyLengths,
{xStrides, {xStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnMeanVarStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnMeanVarStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnScaleStrides,
aligned_scaleBiasMeanVarStrides}, aligned_bnBiasStrides},
{yStrides}, {yStrides},
{p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias}, {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
{p_y}, {p_y},
...@@ -105,8 +115,10 @@ int bnorm_infer(
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel}); avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) + num_bytes += total_length * sizeof(XDataType) +
total_length * sizeof(InOutDataType)); invariantLength *
(sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
total_length * sizeof(YDataType);
if(time_kernel) if(time_kernel)
{ {
......
...@@ -18,11 +18,6 @@
#include "batchnorm_infer_impl.hpp" #include "batchnorm_infer_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'}, {"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'}, {"help", no_argument, nullptr, '?'},
...@@ -236,14 +231,23 @@ bool bnorm_infer_nhwc_test(bool do_verification,
int result = 0; int result = 0;
result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>( result = bnorm_infer<InOutDataType,
time_kernel, InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
Rank,
NumReduceDim,
false>(time_kernel,
{0, 1, 2}, {0, 1, 2},
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(),
...@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{}; using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType>;
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
auto argument_ptr_ref = auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths, batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
...@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
i_inOutStrides, i_inOutStrides,
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
......
...@@ -168,6 +168,11 @@
// tuning parameter // tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0 #define CK_WORKAROUND_SWDEV_325164 0
// workaround: disable broken fused attention kernel instance that does not pass validation
// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
// enabled
#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1
namespace ck { namespace ck {
enum struct InMemoryDataOperationEnum enum struct InMemoryDataOperationEnum
......
...@@ -14,7 +14,8 @@ namespace ck {
template <typename TensorLengths, template <typename TensorLengths,
typename DimAccessOrder, typename DimAccessOrder,
typename ScalarsPerAccess> // # of scalars per access in each dimension typename ScalarsPerAccess, // # of scalars per access in each dimension
bool SnakeCurved = true>
struct SpaceFillingCurve struct SpaceFillingCurve
{ {
static constexpr index_t nDim = TensorLengths::Size(); static constexpr index_t nDim = TensorLengths::Size();
...@@ -136,9 +137,10 @@ struct SpaceFillingCurve
Index ordered_idx; Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) { static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim] ordered_idx(idim) =
: ordered_access_lengths[idim] - 1 - !SnakeCurved || forward_sweep[idim]
ordered_access_idx[idim]; ? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
}); });
return container_reorder_given_old2new(ordered_idx, dim_access_order) * return container_reorder_given_old2new(ordered_idx, dim_access_order) *
......
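For orientation (this is not part of the commit): the new SnakeCurved template parameter decides whether the fastest-moving dimension reverses direction on every other sweep. When SnakeCurved is false, the condition !SnakeCurved || forward_sweep[idim] is always true, so every sweep runs forward. A hypothetical standalone illustration over an assumed 2 x 4 access grid:

#include <cstdio>

int main()
{
    constexpr int rows = 2, cols = 4;

    std::printf("plain sweep (SnakeCurved = false):\n");
    for(int r = 0; r < rows; ++r)
        for(int c = 0; c < cols; ++c)
            std::printf("(%d,%d) ", r, c);

    std::printf("\nsnake-curved sweep (SnakeCurved = true):\n");
    for(int r = 0; r < rows; ++r)
        for(int c = 0; c < cols; ++c)
        {
            // reverse the fast dimension on odd sweeps, as the snake curve does
            const int cc = (r % 2 == 0) ? c : cols - 1 - c;
            std::printf("(%d,%d) ", r, cc);
        }
    std::printf("\n");
    return 0;
}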
...@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{ {
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
...@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
......
...@@ -24,7 +24,8 @@ template <typename ALayout,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
......
...@@ -7,44 +7,55 @@
#include <vector> #include <vector>
#include "device_base.hpp" #include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ALayout, template <index_t NumDimG,
typename B0Layout, index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<> index_t NumDimK,
index_t NumDimO,
typename ADataType, typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
MakeArgumentPointer(const void* p_a, static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b0, const void* p_b0,
const void* p_b1, const void* p_b1,
void* p_c, void* p_c,
ck::index_t M, const std::array<void*, NumAcc0Bias> p_acc0_biases,
ck::index_t N, const std::array<void*, NumAcc1Bias> p_acc1_biases,
ck::index_t K, const std::vector<index_t>& a_gs_ms_ks_lengths,
ck::index_t O, const std::vector<index_t>& a_gs_ms_ks_strides,
ck::index_t Batch, const std::vector<index_t>& b_gs_ns_ks_lengths,
std::vector<index_t> c_gs_ms_os_lengths, const std::vector<index_t>& b_gs_ns_ks_strides,
std::vector<index_t> c_gs_ms_os_strides, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
ck::index_t StrideA, const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
ck::index_t StrideB0, const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
ck::index_t StrideB1, const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
ck::index_t BatchStrideA, const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
ck::index_t BatchStrideB0, const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
ck::index_t BatchStrideB1, const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
......
...@@ -13,31 +13,36 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim> template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
struct DeviceBatchNormFwd : public BaseOperator struct DeviceBatchNormFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const std::array<index_t, Rank> xyLengths, const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides, const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides, const std::array<index_t, Rank> yStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
double epsilon,
const YElementwiseOp y_elementwise_op,
void* p_y, void* p_y,
void* resultSaveMean,
void* resultSaveInvVariance,
double exponentialAverageFactor, double exponentialAverageFactor,
void* resultRunningMean, void* resultRunningMean,
void* resultRunningVariance, void* resultRunningVariance) = 0;
double epsilon,
void* resultSaveMean,
void* resultSaveInvVariance) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t Rank, index_t NumBatchNormReduceDim> template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>; using DeviceBatchNormFwdPtr =
std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
const std::array<index_t, Rank> xStrides, const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> yStrides, const std::array<index_t, Rank> yStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides, const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* bnScale, const void* bnScale,
const void* bnBias, const void* bnBias,
......
...@@ -7,46 +7,50 @@
#include <vector> #include <vector>
#include "device_base.hpp" #include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename ALayout, template <index_t NumDimG,
typename B0Layout, index_t NumDimM,
typename B1Layout, index_t NumDimN,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<> index_t NumDimK,
index_t NumDimO,
typename ADataType, typename ADataType,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{ {
struct ProblemDesc struct ProblemDesc
{ {
// Overall problem shape std::vector<index_t> a_gs_ms_ks_lengths;
index_t M; std::vector<index_t> a_gs_ms_ks_strides;
index_t N;
index_t K;
index_t O;
index_t Batch;
// Stride for A/B0/B1; layout determined by template args std::vector<index_t> b0_gs_ns_ks_lengths;
index_t StrideA; std::vector<index_t> b0_gs_ns_ks_strides;
index_t StrideB0;
index_t StrideB1; std::vector<index_t> b1_gs_os_ns_lengths;
index_t BatchStrideA; std::vector<index_t> b1_gs_os_ns_strides;
index_t BatchStrideB0;
index_t BatchStrideB1;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths; std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides; std::vector<index_t> c_gs_ms_os_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
}; };
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
std::vector<const void*> p_b0_vec, std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec, std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
......
...@@ -3,27 +3,30 @@
#pragma once #pragma once
#include <vector> #include <array>
#include <memory> #include <memory>
#include <iostream>
#include "ck/utility/common_header.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation> template <index_t Rank,
index_t NumReduceDim,
typename InElementwiseOperation,
typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator struct DeviceReduce : public BaseOperator
{ {
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumOutDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumOutDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
...@@ -36,9 +39,12 @@ struct DeviceReduce : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename InElementwiseOperation, typename AccElementwiseOperation> template <index_t Rank,
using DeviceReducePtr = index_t NumReduceDim,
std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>; typename InElementwiseOperation,
typename AccElementwiseOperation>
using DeviceReducePtr = std::unique_ptr<
DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -130,8 +130,11 @@ namespace device {
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] // D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] // E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it // NOTE: TensorSpecialization::Packed specialized tensor is "packed" in the sense that the inner
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1 // dimensions within a dimension group (e.g. [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered, not in the sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1
// //
// Detail- Packed tensor satisfies // Detail- Packed tensor satisfies
// stride_0 = 1 // stride_0 = 1
...@@ -147,7 +150,7 @@ namespace device {
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. // essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
// //
// Might need to expose dimension order to the interface to fully support // Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed. // TensorSpecialization::Packed in the traditional sense of a "packed" tensor.
template <index_t NumDimG, template <index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
......
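As a rough illustration of the "packed" property described in the note above (stride_0 = 1, with each following stride accumulating the preceding lengths), here is a hypothetical standalone snippet; the example lengths and the dimension-0-fastest convention are assumptions made for illustration, not taken from the library:

#include <array>
#include <cstdio>

int main()
{
    // assumed example lengths of a rank-4 tensor
    const std::array<int, 4> lengths = {4, 3, 2, 5};
    std::array<int, 4> strides{};

    // packed layout: innermost stride is 1, every outer stride is the running
    // product of the inner lengths, so there are no gaps between elements
    strides[0] = 1;
    for(std::size_t i = 1; i < lengths.size(); ++i)
        strides[i] = strides[i - 1] * lengths[i - 1];

    for(std::size_t i = 0; i < lengths.size(); ++i)
        std::printf("dim %zu: length %d, stride %d\n", i, lengths[i], strides[i]);
    return 0;
}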
...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
...@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
MaskOutUpperTriangle>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;
...@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
......
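The inline C0MatrixMask removed above is replaced by C0MatrixMask_impl parameterized with MaskOutUpperTrianglePredicate or MaskDisabledPredicate. A minimal standalone sketch of that predicate-based shape, using illustrative *Sketch names rather than the library's real types, could look like:

#include <cstdio>

struct MaskOutUpperTrianglePredicateSketch
{
    // mask every element above the diagonal (causal masking)
    bool operator()(int m, int n) const { return n > m; }
};

struct MaskDisabledPredicateSketch
{
    // never mask by position
    bool operator()(int, int) const { return false; }
};

template <typename MaskPredicate>
struct C0MatrixMaskSketch
{
    explicit C0MatrixMaskSketch(int NRaw) : NRaw_(NRaw) {}

    bool IsMaskedElement(int m, int n) const
    {
        // out-of-bound N is always masked; the predicate adds the optional causal mask
        return MaskPredicate{}(m, n) || n >= NRaw_;
    }

    int NRaw_;
};

int main()
{
    C0MatrixMaskSketch<MaskOutUpperTrianglePredicateSketch> causal_mask(8);
    std::printf("masked(2, 5) = %d\n", causal_mask.IsMaskedElement(2, 5)); // 1: above the diagonal
    std::printf("masked(5, 2) = %d\n", causal_mask.IsMaskedElement(5, 2)); // 0: on or below the diagonal
    return 0;
}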
...@@ -5,9 +5,8 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <array>
#include "ck/utility/common_header.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
...@@ -41,7 +40,8 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation> struct DeviceReduceMultiBlock
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
...@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t NumSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
...@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::vector<index_t>& inStrides, const std::array<index_t, Rank>& inStrides,
int blkGroupSize, int blkGroupSize,
int numBlockTileIteration) int numBlockTileIteration)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths =
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type; using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths = generate_tuple(
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
...@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
...@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{}); const auto length = out_grid_desc_m.GetLength(Number<0>{});
...@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -272,10 +278,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
} }
std::vector<index_t> inLengths_; std::array<index_t, Rank> inLengths_;
std::vector<index_t> inStrides_; std::array<index_t, Rank> inStrides_;
std::vector<index_t> outLengths_; std::array<index_t, NumDstDim> outLengths_;
std::vector<index_t> outStrides_; std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
...@@ -459,11 +465,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
......
...@@ -5,6 +5,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -34,7 +35,8 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation> struct DeviceReduceThreadWise
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
...@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t NumSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::vector<index_t>& inStrides) const std::array<index_t, Rank>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths =
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type; using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths = generate_tuple(
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
...@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
...@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
M_BlockTileSize; M_BlockTileSize;
} }
std::vector<index_t> inLengths_; std::array<index_t, Rank> inLengths_;
std::vector<index_t> inStrides_; std::array<index_t, Rank> inStrides_;
std::vector<index_t> outLengths_; std::array<index_t, NumDstDim> outLengths_;
std::vector<index_t> outStrides_; std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
...@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
......
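Note on the hunk above: DeviceReduceThreadWise switches its length/stride interface from std::vector<index_t> to fixed-size std::array (std::array<index_t, Rank> for the input, std::array<index_t, NumDstDim> for the output), and make_tuple_from_array is replaced by generate_tuple so the descriptors are built from statically sized tuples. Below is a minimal standalone sketch of the underlying idea, written in plain C++ with std::index_sequence rather than CK's generate_tuple; the helper name to_length_tuple is made up purely for illustration.

#include <array>
#include <cstddef>
#include <tuple>
#include <utility>

// Turn a fixed-size array of lengths into a tuple whose arity equals Rank,
// mirroring what generate_tuple([&](auto I){ return inLengths[I]; }, Number<Rank>{}) achieves.
template <std::size_t Rank, std::size_t... Is>
constexpr auto to_length_tuple_impl(const std::array<int, Rank>& lengths,
                                    std::index_sequence<Is...>)
{
    return std::make_tuple(lengths[Is]...);
}

template <std::size_t Rank>
constexpr auto to_length_tuple(const std::array<int, Rank>& lengths)
{
    return to_length_tuple_impl(lengths, std::make_index_sequence<Rank>{});
}

int main()
{
    constexpr std::array<int, 4> in_lengths{8, 16, 32, 64}; // e.g. N, H, W, C
    constexpr auto tuple_lengths = to_length_tuple(in_lengths);
    static_assert(std::get<3>(tuple_lengths) == 64, "innermost length preserved");
    return 0;
}

Because the array extent is a compile-time constant, the tuple arity (and hence the descriptor rank) is fixed at template-instantiation time, which a runtime-sized std::vector cannot guarantee.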
...@@ -8,12 +8,9 @@ ...@@ -8,12 +8,9 @@
#include "ck/utility/reduction_operator.hpp" #include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -50,29 +47,80 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -50,29 +47,80 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
// Used for freeloading of some handy functions from DeviceReduceMultiBlock static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using Reduction = DeviceReduceMultiBlock<InDataType,
AccDataType, static constexpr index_t NumSrcDim = Rank;
OutDataType, static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
Rank, static constexpr bool reduceAllDim = (NumInvariantDim == 0);
NumReduceDim,
reduce::Add, static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
InElementwiseOp, static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
AccElementwiseOp,
InMemoryDataOperationEnum::Set, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
false, // PropagateNan const std::vector<index_t>& inStrides,
false, // OutputIndex int blkGroupSize,
false, // HaveIndexInputIfOutputIndex int numBlockTileIteration)
BlockSize, {
MThreadClusterSize, const auto tupleSrcLengths =
KThreadClusterSize, generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
MThreadSliceSize, const auto tupleSrcStrides =
KThreadSliceSize, generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
InSrcVectorDim,
InSrcVectorSize, const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
1>; // OutDstVectorSize
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType, using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType, OutDataType,
...@@ -102,7 +150,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -102,7 +150,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDstVectorSize, OutDstVectorSize,
true>; true>;
struct Argument : public Reduction::Argument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
...@@ -113,42 +161,60 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -113,42 +161,60 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDataType* out_dev, OutDataType* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op) AccElementwiseOp acc_elementwise_op)
: Reduction::Argument(inLengths, : alpha_{alpha},
inStrides, beta_{beta},
{}, in_dev_{in_dev},
{}, out_dev_{out_dev},
reduceDims, in_elementwise_op_{in_elementwise_op},
0.0f, // alpha acc_elementwise_op_{acc_elementwise_op}
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
in_elementwise_op,
acc_elementwise_op),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{ {
// std::cout << "blkGroupSize= " << this->blkGroupSize inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
// << ", numBlockTileIteration= " << this->numBlockTileIteration inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length << long_index_t invariant_total_length;
// std::endl; long_index_t reduce_total_length;
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = inLengths_[NumInvariantDim - 1];
blkGroupSize = 1;
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
} }
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
InElementwiseOp in_elementwise_op_;
AccElementwiseOp acc_elementwise_op_;
index_t invariant_lowest_length_;
int blkGroupSize;
int numBlockTileIteration;
size_t gridSize;
}; };
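For reference, the launch-size arithmetic performed by the new Argument constructor above amounts to two ceiling divisions. The sketch below uses made-up tile sizes (128 and 64) purely as example values, not CK defaults.

#include <cstdint>
#include <iostream>

int main()
{
    const std::int64_t invariant_total_length = 1000; // product of the non-reduced lengths
    const std::int64_t reduce_total_length    = 771;  // product of the reduced lengths

    const int M_BlockTileSize = 128; // MThreadClusterSize * MThreadSliceSize (example value)
    const int K_BlockTileSize = 64;  // KThreadClusterSize * KThreadSliceSize (example value)

    const int blkGroupSize = 1; // the softmax path keeps a single block group per M tile

    // Number of K tiles one workgroup sweeps over the reduce dimension: ceil(771 / 64) = 13
    const std::int64_t numBlockTileIteration =
        (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;

    // Number of M tiles, times blkGroupSize: ceil(1000 / 128) * 1 = 8
    const std::int64_t gridSize =
        (invariant_total_length + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize;

    std::cout << "numBlockTileIteration=" << numBlockTileIteration
              << " gridSize=" << gridSize << '\n';
    return 0;
}

math::integer_least_multiple in the real code computes the same rounded-up multiple before dividing by M_BlockTileSize.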
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
bool sweep_once = bool sweep_once =
...@@ -195,15 +261,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -195,15 +261,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_)) if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{ {
return false; return false;
} }
else
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
{ {
if(p_arg_->inStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0)
return false;
};
} }
else
{
if(p_arg_->inStrides_[Rank - 1] != 1)
return false;
if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0)
return false;
};
if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0)
return false;
return true; return true;
}; };
......
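The rewritten IsSupportedArgument above replaces the check that was previously delegated to DeviceReduceMultiBlock with an explicit vector-access test. Below is a standalone sketch of the same predicate, written as a free function over plain vectors for illustration only; the real code reads the corresponding Argument fields and template parameters, with the lengths already shuffled so the reduce dimensions come last.

#include <vector>

bool is_vector_access_supported(const std::vector<int>& lengths, // invariant dims first, reduce dims last
                                const std::vector<int>& strides,
                                int num_invariant_dim,
                                int in_src_vector_dim, // 0: along the M (invariant) axis, 1: along the K (reduce) axis
                                int in_src_vector_size,
                                int out_dst_vector_size)
{
    const int rank = static_cast<int>(lengths.size());
    const int invariant_lowest_length =
        (num_invariant_dim == 0) ? 1 : lengths[num_invariant_dim - 1];

    if(in_src_vector_dim == 0)
    {
        // Vectorized loads along the innermost invariant dimension.
        if(num_invariant_dim == 0)
            return false;
        if(strides[num_invariant_dim - 1] != 1)
            return false;
        if(invariant_lowest_length % in_src_vector_size != 0)
            return false;
    }
    else
    {
        // Vectorized loads along the innermost (reduce) dimension.
        if(strides[rank - 1] != 1)
            return false;
        if(lengths[rank - 1] % in_src_vector_size != 0)
            return false;
    }

    // Stores are vectorized along the innermost invariant dimension.
    return invariant_lowest_length % out_dst_vector_size == 0;
}

For example, with in_src_vector_dim == 1 the innermost dimension must have unit stride and a length divisible by in_src_vector_size; otherwise the predicate returns false.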