Commit 1abaedd9 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into gpu-invoker

parents bd2b3dd7 cb3fac4d
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
...@@ -216,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -216,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
size_t invariant_total_length = n; size_t invariant_total_length = n;
size_t reduce_total_length = h * w * c; size_t reduce_total_length = h * w * c;
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f); const double alpha = 1.0f;
const AccDataType beta = ck::type_convert<AccDataType>(0.0f); const double beta = 0.0f;
std::size_t num_thread = 1; std::size_t num_thread = 1;
...@@ -253,10 +254,10 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -253,10 +254,10 @@ int mean_meansquare_dual_reduce_test(size_t n,
std::array<ck::index_t, NumOutputDim> i_outLengths; std::array<ck::index_t, NumOutputDim> i_outLengths;
std::array<ck::index_t, NumOutputDim> i_outStrides; std::array<ck::index_t, NumOutputDim> i_outStrides;
std::copy(inLengths.begin(), inLengths.end(), i_inLengths.begin()); ck::ranges::copy(inLengths, i_inLengths.begin());
std::copy(inStrides.begin(), inStrides.end(), i_inStrides.begin()); ck::ranges::copy(inStrides, i_inStrides.begin());
std::copy(outLengths.begin(), outLengths.end(), i_outLengths.begin()); ck::ranges::copy(outLengths, i_outLengths.begin());
std::copy(outStrides.begin(), outStrides.end(), i_outStrides.begin()); ck::ranges::copy(outStrides, i_outStrides.begin());
auto dual_reduce_op = DeviceDualReduce{}; auto dual_reduce_op = DeviceDualReduce{};
...@@ -266,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -266,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
i_outLengths, i_outLengths,
{i_outStrides, i_outStrides}, {i_outStrides, i_outStrides},
reduceDims, reduceDims,
{&alpha, &alpha}, {alpha, alpha},
{&beta, &beta}, {beta, beta},
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
{mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()}, {mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()},
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}), ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
...@@ -305,8 +306,8 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -305,8 +306,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
{ {
mean_dev.FromDevice(mean.mData.data()); mean_dev.FromDevice(mean.mData.data());
meansquare_dev.FromDevice(meansquare.mData.data()); meansquare_dev.FromDevice(meansquare.mData.data());
pass = pass && ck::utils::check_err(mean.mData, mean_ref.mData); pass = pass && ck::utils::check_err(mean, mean_ref);
pass = pass && ck::utils::check_err(meansquare.mData, meansquare_ref.mData); pass = pass && ck::utils::check_err(meansquare, meansquare_ref);
}; };
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_multiblock.hpp" #include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "dual_reduce_common.hpp" #include "dual_reduce_common.hpp"
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_threadwise.hpp" #include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "dual_reduce_common.hpp" #include "dual_reduce_common.hpp"
......
add_example_executable(example_batchnorm_forward batchnorm_forward_nhwc.cpp) add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp)
add_example_executable(example_batchnorm_infer batchnorm_infer_nhwc.cpp) add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp)
add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp)
...@@ -53,4 +53,29 @@ Start running 10 times... ...@@ -53,4 +53,29 @@ Start running 10 times...
Perf: 1.28235 ms, 523.329 GB/s Perf: 1.28235 ms, 523.329 GB/s
``` ```
## Run ```batchnorm backward nhwc```
```bash
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)
Arg2 -- 1/0 to indicate whether to use saved mean and invVariance
Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
Arg4 -- time kernel (0=no, 1=yes)
Arg5: use multi-block welford (0=n0, 1=yes)
./bin/example_batchnorm_backward -D 128,16,3,1024 -v 1 0 0 3 1 1
```
Result
```
./bin/example_batchnorm_backward -D 128,16,3,1024 -v 1 0 0 3 1 1
launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {6144, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.411026 ms, 91.8702 GB/s
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <limits>
#include <iostream>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'},
{nullptr, 0, nullptr, 0}};
class BatchNormBwdArg
{
private:
int option_index = 0;
public:
std::vector<size_t> inOutLengths;
bool do_verification = false;
bool haveSavedMeanInvVar;
int data_type = 0;
int init_method = 3;
bool time_kernel = false;
bool use_multiblock_welford = false;
public:
void show_usage(const char* cmd)
{
// clang-format off
std::cout << "Usage of " << cmd << std::endl;
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl;
std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl;
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
std::cout << "Arg2 -- 1/0 to indicate whether to use saved mean and invVariance" << std::endl;
std::cout << "Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl;
std::cout << "Arg4 -- time kernel (0=no, 1=yes)" << std::endl;
std::cout << "Arg5: use multi-block welford (0=n0, 1=yes)" << std::endl;
// clang-format on
};
int processArgs(int argc, char* argv[])
{
using ck::host_common::getTypeValuesFromString;
int ch;
while(1)
{
ch = getopt_long(argc, argv, "D:v:", long_options, &option_index);
if(ch == -1)
break;
switch(ch)
{
case 'D':
if(!optarg)
throw std::runtime_error("Invalid option format!");
inOutLengths = getTypeValuesFromString<size_t>(optarg);
if(inOutLengths.size() != 4)
throw std::runtime_error(
"NHWC tensor layout should have 4 length values specified!");
break;
case 'v':
if(!optarg)
throw std::runtime_error("Invalid option format!");
do_verification = static_cast<bool>(std::atoi(optarg));
break;
case '?':
if(std::string(long_options[option_index].name) == "help")
{
show_usage(argv[0]);
return (-1);
};
break;
default: show_usage(argv[0]); return (-1);
};
};
if(optind + 5 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
data_type = std::atoi(argv[optind++]);
haveSavedMeanInvVar = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
return (0);
};
};
using namespace ck;
template <typename XDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_bwd_nhwc_test(bool do_verification,
int init_method,
bool time_kernel,
const std::vector<size_t> inOutLengths,
bool haveSavedMeanInvVar,
double epsilon)
{
// for NHWC BatchNorm calculation of mean and meansquare
constexpr index_t Rank = 4;
constexpr index_t NumReduceDim = 3;
using ScaleDataType = XDataType;
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm backward algorithm
Tensor<XDataType> x(inOutLengths);
Tensor<AccDataType> dy(inOutLengths);
Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
Tensor<AccDataType> savedMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> savedInvVar(scaleBiasMeanVarLengths);
// savedVariance is only used for initializing savedInvVar
Tensor<AccDataType> savedVariance(scaleBiasMeanVarLengths);
// output data of the batchnorm backward algorithm
Tensor<AccDataType> dx_ref(inOutLengths);
Tensor<AccDataType> dx(inOutLengths);
Tensor<AccDataType> dscale(scaleBiasMeanVarLengths);
Tensor<AccDataType> dbias(scaleBiasMeanVarLengths);
Tensor<AccDataType> dscale_ref(scaleBiasMeanVarLengths);
Tensor<AccDataType> dbias_ref(scaleBiasMeanVarLengths);
auto inOutStrides = dy.mDesc.GetStrides();
auto scaleBiasMeanVarStrides = dscale.mDesc.GetStrides();
std::size_t num_thread = std::thread::hardware_concurrency();
if(haveSavedMeanInvVar)
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
const float noise_stddev = 0.0001f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
// initialize the savedMean to be values with tiny variation to the mean of the x values
savedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
num_thread);
// initialize the variance to be values with tiny variation to the variance of the x values
savedVariance.GenerateTensorValue(
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
auto it_src = savedVariance.mData.begin();
auto it_dst = savedInvVar.mData.begin();
float tmp_epsilon = std::numeric_limits<float>::epsilon();
while(it_src != savedVariance.mData.end())
{
*it_dst = type_convert<AccDataType>(
1.0f / std::sqrtf(type_convert<float>(*it_src) + tmp_epsilon));
it_src++;
it_dst++;
};
}
else
{
const float x_mean = 0.0f;
const float x_stddev = 1.0f;
// input data in normal distribution
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
};
if(do_verification)
{
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
break;
case 2:
dy.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-2, 2}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-0.2f, 0.2f}, num_thread);
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
}
};
// input data of the batchnorm backward algorithm
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize());
DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize());
// output data of the batchnorm backward algorithm
DeviceMem dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize());
DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
dy_dev.ToDevice(dy.mData.data());
bnScale_dev.ToDevice(bnScale.mData.data());
if(haveSavedMeanInvVar)
{
savedMean_dev.ToDevice(savedMean.mData.data());
savedInvVar_dev.ToDevice(savedInvVar.mData.data());
};
std::array<index_t, Rank> i_inOutLengths;
std::array<index_t, Rank> i_inOutStrides;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;
std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin());
std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin());
std::copy(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
i_scaleBiasMeanVarLengths.begin());
std::copy(scaleBiasMeanVarStrides.begin(),
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin());
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using DeviceBatchNormBwdInstance =
ck::tensor_operation::device::DeviceBatchNormBwdImpl<XDataType,
AccDataType,
AccDataType,
AccDataType,
ScaleDataType, // ScaleDataType
AccDataType, // DscaleDbiasDataType
AccDataType, // MeanVarDataType
PassThroughOp,
Rank,
NumReduceDim,
UseMultiblockInK,
256,
16,
16,
1,
2,
0,
1, // XSrcVectorSize
1, // DySrcVectorSize
1, // DxDstVectorSize
1, // ScaleSrcVectorSize
1, // DscaleDbiasDstVectorSize
1>; // MeanVarSrcVectorSize
auto batchnorm_bwd = DeviceBatchNormBwdInstance{};
auto argument_ptr = batchnorm_bwd.MakeArgumentPointer(
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
dy_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr,
haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr,
epsilon,
PassThroughOp{},
dx_dev.GetDeviceBuffer(),
dscale_dev.GetDeviceBuffer(),
dbias_dev.GetDeviceBuffer());
if(!batchnorm_bwd.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
"exiting!"
<< std::endl;
return (false);
};
size_t workspace_sz = batchnorm_bwd.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
batchnorm_bwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = batchnorm_bwd.MakeInvokerPointer();
if(time_kernel)
{
float avg_time = 0.0f;
size_t num_bytes = 0;
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
size_t invariant_length = inOutLengths[3];
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// inputing of x, dy, scale, outputing of dx, dscale, dbias
num_bytes +=
total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
// outputing of mean, inv-variance
num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
else
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true;
if(do_verification)
{
using ReferenceBatchNormBwdInstance =
ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
AccDataType,
AccDataType,
AccDataType,
ScaleDataType, // ScaleDataType
AccDataType,
AccDataType,
PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
auto argument_ptr_ref = batchNormBwd_ref.MakeArgumentPointer(
i_inOutLengths,
i_inOutStrides,
i_inOutStrides,
i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(),
dy.mData.data(),
bnScale.mData.data(),
haveSavedMeanInvVar ? savedMean.mData.data() : nullptr,
haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr,
epsilon,
PassThroughOp{},
dx_ref.mData.data(),
dscale_ref.mData.data(),
dbias_ref.mData.data());
if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout
<< "The runtime parameters seems not supported by the device instance, exiting!"
<< std::endl;
return (false);
};
auto invoker_ptr_ref = batchNormBwd_ref.MakeInvokerPointer();
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
dx_dev.FromDevice(dx.mData.data());
dscale_dev.FromDevice(dscale.data());
dbias_dev.FromDevice(dbias.data());
// clang-format off
pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4);
pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4);
pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:");
// clang-format on
};
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
bool pass = true;
if(argc > 1)
{
BatchNormBwdArg arg;
if(arg.processArgs(argc, argv) < 0)
return (-1);
if(arg.data_type == 0)
{
if(arg.use_multiblock_welford)
pass = bnorm_bwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
else
pass = bnorm_bwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
}
else if(arg.data_type == 1)
{
if(arg.use_multiblock_welford)
pass = bnorm_bwd_nhwc_test<float, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
else
pass = bnorm_bwd_nhwc_test<float, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
}
else if(arg.data_type == 5)
{
if(arg.use_multiblock_welford)
pass = bnorm_bwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
else
pass = bnorm_bwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
}
else if(arg.data_type == 6)
{
if(arg.use_multiblock_welford)
pass = bnorm_bwd_nhwc_test<double, double, true>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
else
pass = bnorm_bwd_nhwc_test<double, double, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.haveSavedMeanInvVar,
epsilon);
}
}
else
{
pass = bnorm_bwd_nhwc_test<ck::half_t, float, true>(true,
3,
false, // don't time kernel
{128, 16, 6, 512},
false,
epsilon);
pass = pass && bnorm_bwd_nhwc_test<ck::half_t, float, false>(true,
3,
false, // don't time kernel
{128, 16, 3, 1024},
false,
epsilon);
};
return (pass ? 0 : 1);
}
...@@ -10,131 +10,17 @@ ...@@ -10,131 +10,17 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
// binary operation used to calculate invVariance from mean and meansquare
struct InvVariance
{
InvVariance(double epsilon) : epsilon_(epsilon){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& mean, const T& meansquare) const
{
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T tmp_epsilon = type_convert<T>(epsilon_);
y = meansquare - mean * mean;
y = 1.0f / sqrt(tmp_epsilon + y);
};
double epsilon_;
};
// (4-in, 2-out) element-wise operation used to update the moving average of mean and variance
struct MovingAverage
{
MovingAverage(double factor) : factor_(factor){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y0,
T& y1,
const T& mean,
const T& runningMean,
const T& meansquare,
const T& runningVariance) const
{
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
T tmp_factor = type_convert<T>(factor_);
T variance = meansquare - mean * mean;
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
};
double factor_;
};
struct MovingAverageAndInvVariance
{
MovingAverageAndInvVariance(double epsilon, double factor)
: epsilon_(epsilon), factor_(factor){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y0, // resultRunningMean
T& y1, // resultRunningVariance
T& y2, // saveInvVariance
const T& mean,
const T& runningMean,
const T& meansquare,
const T& runningVariance) const
{
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T tmp_epsilon = type_convert<T>(epsilon_);
T tmp_factor = type_convert<T>(factor_);
T variance = meansquare - mean * mean;
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
y2 = 1.0f / sqrt(tmp_epsilon + variance);
};
double epsilon_;
double factor_;
};
struct NormalizeInInfer struct NormalizeInInfer
{ {
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {} NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2> template <typename T1, typename T2, typename T3, typename T4>
__host__ __device__ constexpr void operator()(T1& y, __host__ __device__ constexpr void operator()(T1& y,
const T1& x, const T1& x,
const T2& mean, const T2& mean,
const T2& variance, const T2& variance,
const T2& gamma, const T3& gamma,
const T2& beta) const const T4& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
struct NormalizeInForward
{
NormalizeInForward(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& meansquare,
const T2& gamma,
const T2& beta) const
{ {
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value, static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
...@@ -143,12 +29,13 @@ struct NormalizeInForward ...@@ -143,12 +29,13 @@ struct NormalizeInForward
using ck::math::sqrt; using ck::math::sqrt;
T2 tmp_x, tmp_y; T2 tmp_x, tmp_y;
T2 variance = meansquare - mean * mean;
tmp_x = type_convert<T2>(x); tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta; tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
y = type_convert<T1>(tmp_y); type_convert<T2>(gamma) +
type_convert<T2>(beta);
y = type_convert<T1>(tmp_y);
}; };
double epsilon_; double epsilon_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <vector>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "batchnorm_common.hpp"
template <typename InOutDataType,
typename AccDataType,
ck::index_t Rank,
ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false>
int bnorm_fwd(bool time_kernel,
bool updateMovingAverage,
bool saveMeanAndInvVariance,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, Rank> xyLengths,
const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
const void* p_x,
const void* p_scale,
const void* p_bias,
void* p_y,
double exponentialAverageFactor,
void* p_runningMean,
void* p_runningVariance,
double epsilon,
void* p_saveMean,
void* p_saveInvVariance,
void* p_tmp_mean,
void* p_tmp_meansquare)
{
static_assert(NumBatchNormReduceDim < Rank,
"Invalid number of reduced dimensions for batchnorm!");
constexpr ck::index_t NumScaleBiasMeanVarDim = Rank - NumBatchNormReduceDim;
using InElementwiseOperation_Mean = ck::tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide;
using InElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide;
using DeviceMeanAndMeansquareInstance =
ck::tensor_operation::device::DeviceMultipleReduceMultiBlock<
2,
InOutDataType,
AccDataType,
ck::Tuple<AccDataType, AccDataType>,
Rank,
NumBatchNormReduceDim,
ck::reduce::Add,
ck::Tuple<InElementwiseOperation_Mean, InElementwiseOperation_Meansquare>,
ck::Tuple<AccElementwiseOperation_Mean, AccElementwiseOperation_Meansquare>,
ck::InMemoryDataOperationEnum::Set,
false, // PropagateNan
256,
16,
16,
1,
1,
fastest_dim_is_reduced ? 1 : 0,
1,
ck::Sequence<1, 1>>;
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// meansquare,
// scale, bias
ck::Tuple<InOutDataType>, // y
NormalizeInForward,
Rank,
2, // MPerthread
ck::Sequence<1, 1, 1, 1, 1>, // scalarPerVector: x, mean, meansquare, scale, bias
ck::Sequence<1>>; // scalarPerVector: y
using DeviceInvVarianceInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<AccDataType, AccDataType>, // mean, meansquare
ck::Tuple<AccDataType>, // invVariance
InvVariance,
NumScaleBiasMeanVarDim,
2, // MPerthread
ck::Sequence<1, 1>, // scalarPerVector: mean, meansquare
ck::Sequence<1>>; // scalarPerVector: invVariance
using DeviceMovingAverageInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new mean,
// old moving variance, new
// meansquare
ck::Tuple<AccDataType, AccDataType>, // updated moving mean, updated moving variance
MovingAverage,
NumScaleBiasMeanVarDim,
4, // MPerthread
ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
// variance, new meansquare
ck::Sequence<1, 1>>; // scalarPerVector: updated moving mean, updated moving variance
using DeviceMovingAverageAndInvVarianceInstance =
ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new
// mean, old moving
// variance, new
// meansquare
ck::Tuple<AccDataType, AccDataType, AccDataType>, // updated moving mean, updated moving
// variancem, invVariance
MovingAverageAndInvVariance,
NumScaleBiasMeanVarDim,
4, // MPerthread
ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
// variance, new meansquare
ck::Sequence<1, 1, 1>>; // scalarPerVector: updated moving mean, updated moving variance
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
int i = 0;
for(auto dim : invariantDims)
{
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
i++;
};
int32_t reduceLength = 1;
for(auto dim : reduceDims)
reduceLength *= xyLengths[dim];
int32_t invariantLength = 1;
for(auto dim : invariantDims)
invariantLength *= xyLengths[dim];
size_t total_length = static_cast<size_t>(invariantLength) * reduceLength;
float avg_time = 0.0f;
std::size_t num_bytes = 0;
auto dev_mean_and_meansquare = DeviceMeanAndMeansquareInstance{};
void* p_mean = saveMeanAndInvVariance ? p_saveMean : p_tmp_mean;
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
const AccDataType beta = ck::type_convert<AccDataType>(0.0f);
auto argument_ptr1 = dev_mean_and_meansquare.MakeArgumentPointer(
xyLengths,
xStrides,
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
reduceDims,
{&alpha, &alpha},
{&beta, &beta},
p_x,
{p_mean, p_tmp_meansquare},
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
ck::make_tuple(AccElementwiseOperation_Mean{reduceLength},
AccElementwiseOperation_Meansquare{reduceLength}));
auto dev_normalize = DeviceNormalizeInstance{};
auto argument_ptr2 =
dev_normalize.MakeArgumentPointer(xyLengths,
{xStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides},
{yStrides},
{p_x, p_mean, p_tmp_meansquare, p_scale, p_bias},
{p_y},
NormalizeInForward{epsilon});
if(!dev_mean_and_meansquare.IsSupportedArgument(argument_ptr1.get()) ||
!dev_normalize.IsSupportedArgument(argument_ptr2.get()))
{
std::cout << "The runtime parameters seems not supported by the Devic, exiting!"
<< std::endl;
return (-1);
};
auto invoker_ptr1 = dev_mean_and_meansquare.MakeInvokerPointer();
auto invoker_ptr2 = dev_normalize.MakeInvokerPointer();
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
avg_time += invoker_ptr2->Run(argument_ptr2.get(), StreamConfig{nullptr, time_kernel});
num_bytes +=
(total_length * sizeof(InOutDataType) + invariantLength * 2 * sizeof(AccDataType)) + // No.1
(total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
total_length * sizeof(InOutDataType)); // No.2
if(saveMeanAndInvVariance && updateMovingAverage)
{
auto dev_moving_average_inv_variance = DeviceMovingAverageAndInvVarianceInstance{};
auto argument_ptr3 = dev_moving_average_inv_variance.MakeArgumentPointer(
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides},
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
{p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
{p_runningMean, p_runningVariance, p_saveInvVariance},
MovingAverageAndInvVariance{epsilon, exponentialAverageFactor});
if(!dev_moving_average_inv_variance.IsSupportedArgument(argument_ptr3.get()))
{
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
return (-1);
};
auto invoker_ptr3 = dev_moving_average_inv_variance.MakeInvokerPointer();
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
num_bytes += invariantLength * (4 + 3) * sizeof(AccDataType) * 2; // No.5
}
else if(saveMeanAndInvVariance)
{
auto dev_inv_variance = DeviceInvVarianceInstance{};
auto argument_ptr3 = dev_inv_variance.MakeArgumentPointer(
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
{bnScaleBiasMeanVarStrides},
{p_mean, p_tmp_meansquare},
{p_saveInvVariance},
InvVariance{epsilon});
if(!dev_inv_variance.IsSupportedArgument(argument_ptr3.get()))
{
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
return (-1);
};
auto invoker_ptr3 = dev_inv_variance.MakeInvokerPointer();
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
num_bytes += invariantLength * (2 + 1) * sizeof(AccDataType);
}
else if(updateMovingAverage)
{
auto dev_moving_average = DeviceMovingAverageInstance{};
auto argument_ptr3 = dev_moving_average.MakeArgumentPointer(
bnScaleBiasMeanVarLengths,
{bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides,
bnScaleBiasMeanVarStrides},
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
{p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
{p_runningMean, p_runningVariance},
MovingAverage{exponentialAverageFactor});
if(!dev_moving_average.IsSupportedArgument(argument_ptr3.get()))
{
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
return (-1);
};
auto invoker_ptr3 = dev_moving_average.MakeInvokerPointer();
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
num_bytes += invariantLength * (4 + 2) * sizeof(AccDataType) * 2; // No.5
};
if(time_kernel)
{
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
};
return (0);
};
...@@ -9,20 +9,17 @@ ...@@ -9,20 +9,17 @@
#include <getopt.h> #include <getopt.h>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"
#include "batchnorm_infer_impl.hpp" #include "batchnorm_infer_impl.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'}, {"verify", required_argument, nullptr, 'v'},
{"help", no_argument, nullptr, '?'}, {"help", no_argument, nullptr, '?'},
...@@ -128,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -128,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
constexpr int Rank = 4; constexpr int Rank = 4;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]}; const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm // input data of the batchnorm forward algorithm
...@@ -225,32 +224,37 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -225,32 +224,37 @@ bool bnorm_infer_nhwc_test(bool do_verification,
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths; std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides; std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;
std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin()); ck::ranges::copy(inOutLengths, i_inOutLengths.begin());
std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin()); ck::ranges::copy(inOutStrides, i_inOutStrides.begin());
std::copy(scaleBiasMeanVarLengths.begin(), ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin());
scaleBiasMeanVarLengths.end(), ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin());
i_scaleBiasMeanVarLengths.begin());
std::copy(scaleBiasMeanVarStrides.begin(),
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin());
int result = 0; int result = 0;
result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>( result = bnorm_infer<InOutDataType,
time_kernel, InOutDataType,
{0, 1, 2}, AccDataType,
i_inOutLengths, AccDataType,
i_inOutStrides, AccDataType,
i_inOutStrides, AccDataType,
i_scaleBiasMeanVarLengths, Rank,
i_scaleBiasMeanVarStrides, NumReduceDim,
x_dev.GetDeviceBuffer(), false>(time_kernel,
bnScale_dev.GetDeviceBuffer(), {0, 1, 2},
bnBias_dev.GetDeviceBuffer(), i_inOutLengths,
epsilon, i_inOutStrides,
estimatedMean_dev.GetDeviceBuffer(), i_inOutStrides,
estimatedVariance_dev.GetDeviceBuffer(), i_scaleBiasMeanVarLengths,
y_dev.GetDeviceBuffer()); i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(),
epsilon,
estimatedMean_dev.GetDeviceBuffer(),
estimatedVariance_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
if(result < 0) if(result < 0)
return (false); return (false);
...@@ -259,18 +263,34 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -259,18 +263,34 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{}; using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using ReferenceBatchNormInferInstance =
ck::tensor_operation::host::ReferenceBatchNormInfer<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
auto argument_ptr_ref = auto argument_ptr_ref =
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths, batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2},
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
epsilon, epsilon,
PassThroughOp{},
estimatedMean.mData.data(), estimatedMean.mData.data(),
estimatedVariance.mData.data(), estimatedVariance.mData.data(),
y_ref.mData.data()); y_ref.mData.data());
...@@ -288,7 +308,7 @@ bool bnorm_infer_nhwc_test(bool do_verification, ...@@ -288,7 +308,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
(void)invoker_ptr_ref->Run(argument_ptr_ref.get()); (void)invoker_ptr_ref->Run(argument_ptr_ref.get());
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass = pass && ck::utils::check_err(y.mData, y_ref.mData); pass = pass && ck::utils::check_err(y, y_ref);
}; };
return (pass); return (pass);
......
...@@ -9,19 +9,16 @@ ...@@ -9,19 +9,16 @@
#include <getopt.h> #include <getopt.h>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "batchnorm_forward_impl.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
AccDataType>;
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
{"verify", required_argument, nullptr, 'v'}, {"verify", required_argument, nullptr, 'v'},
...@@ -41,9 +38,10 @@ class BatchNormFwdArg ...@@ -41,9 +38,10 @@ class BatchNormFwdArg
bool updateMovingAverage; bool updateMovingAverage;
bool saveMeanAndInvVariance; bool saveMeanAndInvVariance;
int data_type = 0; int data_type = 0;
int init_method = 2; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
bool use_multiblock_welford = false;
public: public:
void show_usage(const char* cmd) void show_usage(const char* cmd)
...@@ -68,6 +66,7 @@ class BatchNormFwdArg ...@@ -68,6 +66,7 @@ class BatchNormFwdArg
"value, 3=decimal value)" "value, 3=decimal value)"
<< std::endl; << std::endl;
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
}; };
int processArgs(int argc, char* argv[]) int processArgs(int argc, char* argv[])
...@@ -110,14 +109,15 @@ class BatchNormFwdArg ...@@ -110,14 +109,15 @@ class BatchNormFwdArg
}; };
}; };
if(optind + 5 > argc) if(optind + 6 > argc)
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
data_type = std::atoi(argv[optind++]); data_type = std::atoi(argv[optind++]);
updateMovingAverage = std::atoi(argv[optind++]); updateMovingAverage = std::atoi(argv[optind++]);
saveMeanAndInvVariance = std::atoi(argv[optind++]); saveMeanAndInvVariance = std::atoi(argv[optind++]);
init_method = std::atoi(argv[optind++]); init_method = std::atoi(argv[optind++]);
time_kernel = static_cast<bool>(std::atoi(argv[optind])); time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
return (-1); return (-1);
...@@ -128,7 +128,7 @@ class BatchNormFwdArg ...@@ -128,7 +128,7 @@ class BatchNormFwdArg
using namespace ck; using namespace ck;
template <typename InOutDataType, typename AccDataType> template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_fwd_nhwc_test(bool do_verification, bool bnorm_fwd_nhwc_test(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
...@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
constexpr int Rank = 4; constexpr int Rank = 4;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]}; const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
// input data of the batchnorm forward algorithm // input data of the batchnorm forward algorithm
...@@ -264,82 +266,147 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -264,82 +266,147 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths; std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides; std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;
std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin()); ck::ranges::copy(inOutLengths, i_inOutLengths.begin());
std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin()); ck::ranges::copy(inOutStrides, i_inOutStrides.begin());
std::copy(scaleBiasMeanVarLengths.begin(), ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin());
scaleBiasMeanVarLengths.end(), ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin());
i_scaleBiasMeanVarLengths.begin());
std::copy(scaleBiasMeanVarStrides.begin(), using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin()); using DeviceBatchNormFwdInstance =
ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
int result = 0; InOutDataType,
AccDataType,
// used for saving meansquare AccDataType, // ScaleDataType
DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() + AccDataType, // BiasDataType
128); AccDataType, // MeanVarDataType
PassThroughOp, // YElementwiseOp
void* p_tmp_mean = workspace.GetDeviceBuffer(); Rank,
void* p_tmp_meansquare = NumReduceDim,
static_cast<char*>(p_tmp_mean) + UseMultiblockInK,
(sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64; 256,
16,
result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>( 16,
time_kernel, 1,
updateMovingAverage, 2,
saveMeanAndInvVariance, 0,
{0, 1, 2}, 1,
1,
1,
1,
1>;
auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
bnScale_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(),
bnBias_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
averageFactor,
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
epsilon, epsilon,
PassThroughOp{},
y_dev.GetDeviceBuffer(),
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
p_tmp_mean, averageFactor,
p_tmp_meansquare); updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
if(result < 0) if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
"exiting!"
<< std::endl;
return (false); return (false);
};
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
if(time_kernel)
{
float avg_time = 0.0f;
size_t num_bytes = 0;
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
size_t invariant_length = inOutLengths[3];
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// inputing of x, scale, bias, outputing of y
num_bytes +=
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
// outputing of mean, inv-variance
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
// updating of moving mean, variance
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
else
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
using ReferenceBatchNormFwdInstance =
ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
InOutDataType,
AccDataType,
AccDataType,
AccDataType,
AccDataType,
PassThroughOp,
Rank,
NumReduceDim>;
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
i_inOutLengths, i_inOutLengths,
i_inOutStrides, i_inOutStrides,
i_inOutStrides, i_inOutStrides,
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths,
i_scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
i_scaleBiasMeanVarStrides,
x.mData.data(), x.mData.data(),
bnScale.mData.data(), bnScale.mData.data(),
bnBias.mData.data(), bnBias.mData.data(),
y_ref.mData.data(),
0.1, // exponentialAverageFactor
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
updateMovingAverage ? resultRunningVariance_ref.mData.data()
: nullptr, // resultRunningVariance
epsilon, epsilon,
PassThroughOp{},
y_ref.mData.data(),
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr); saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
averageFactor,
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{ {
std::cout std::cout << "The runtime parameters seems not supported by the BatchNorm reference "
<< "The runtime parameters seems not supported by the BatchNorm instance, exiting!" "instance, exiting!"
<< std::endl; << std::endl;
return (-2); return (false);
}; };
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
...@@ -347,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -347,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
(void)invoker_ptr_ref->Run(argument_ptr_ref.get()); (void)invoker_ptr_ref->Run(argument_ptr_ref.get());
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass = pass && ck::utils::check_err(y.mData, y_ref.mData); pass = pass && ck::utils::check_err(y, y_ref);
if(updateMovingAverage) if(updateMovingAverage)
{ {
...@@ -357,23 +424,22 @@ bool bnorm_fwd_nhwc_test(bool do_verification, ...@@ -357,23 +424,22 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
resultRunningMean_dev.FromDevice(resultRunningMean.mData.data()); resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
pass = pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref);
pass && ck::utils::check_err(resultRunningMean.mData, resultRunningMean_ref.mData); pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref);
pass = pass && ck::utils::check_err(resultRunningVariance.mData,
resultRunningVariance_ref.mData);
}; };
if(saveMeanAndInvVariance) if(saveMeanAndInvVariance)
{ {
using ck::host_common::dumpBufferToFile;
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths); Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths); Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
resultSaveMean_dev.FromDevice(resultSaveMean.mData.data()); resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
pass = pass && ck::utils::check_err(resultSaveMean.mData, resultSaveMean_ref.mData); pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref);
pass = pass && ck::utils::check_err(resultSaveInvVariance.mData, pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref);
resultSaveInvVariance_ref.mData);
}; };
}; };
...@@ -396,70 +462,129 @@ int main(int argc, char* argv[]) ...@@ -396,70 +462,129 @@ int main(int argc, char* argv[])
if(arg.data_type == 0) if(arg.data_type == 0)
{ {
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 1) else if(arg.data_type == 1)
{ {
pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 3) else if(arg.data_type == 3)
{ {
pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 5) else if(arg.data_type == 5)
{ {
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
else if(arg.data_type == 6) else if(arg.data_type == 6)
{ {
pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification, if(arg.use_multiblock_welford)
arg.init_method, pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
arg.time_kernel, arg.init_method,
arg.inOutLengths, arg.time_kernel,
arg.updateMovingAverage, arg.inOutLengths,
arg.saveMeanAndInvVariance, arg.updateMovingAverage,
averageFactor, arg.saveMeanAndInvVariance,
epsilon); averageFactor,
epsilon);
else
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
arg.init_method,
arg.time_kernel,
arg.inOutLengths,
arg.updateMovingAverage,
arg.saveMeanAndInvVariance,
averageFactor,
epsilon);
} }
} }
else else
{ {
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true, pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
2, 2,
false, // don't time kernel false, // don't time kernel
{128, 16, 16, 1024}, {128, 16, 6, 512},
true, true,
false, true,
averageFactor, averageFactor,
epsilon); epsilon);
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
2,
false, // don't time kernel
{128, 16, 3, 1024},
true,
true,
averageFactor,
epsilon);
}; };
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
...@@ -10,12 +10,16 @@ ...@@ -10,12 +10,16 @@
#include "ck/utility/sequence.hpp" #include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp" #include "ck/utility/tuple.hpp"
#include "ck/utility/reduction_operator.hpp" #include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "batchnorm_common.hpp" #include "batchnorm_common.hpp"
template <typename InOutDataType, template <typename XDataType,
typename YDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
ck::index_t Rank, ck::index_t Rank,
ck::index_t NumBatchNormReduceDim, ck::index_t NumBatchNormReduceDim,
bool fastest_dim_is_reduced = false> bool fastest_dim_is_reduced = false>
...@@ -26,7 +30,9 @@ int bnorm_infer( ...@@ -26,7 +30,9 @@ int bnorm_infer(
const std::array<ck::index_t, Rank> xStrides, const std::array<ck::index_t, Rank> xStrides,
const std::array<ck::index_t, Rank> yStrides, const std::array<ck::index_t, Rank> yStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths, const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides, const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* p_scale, const void* p_scale,
const void* p_bias, const void* p_bias,
...@@ -40,12 +46,12 @@ int bnorm_infer( ...@@ -40,12 +46,12 @@ int bnorm_infer(
static_assert(NumBatchNormReduceDim < Rank, static_assert(NumBatchNormReduceDim < Rank,
"Invalid number of reduced dimensions for batchnorm!"); "Invalid number of reduced dimensions for batchnorm!");
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean, ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
// variance, // variance,
// scale, // scale,
// bias, // bias,
ck::Tuple<InOutDataType>, // y ck::Tuple<YDataType>, // y
NormalizeInInfer, NormalizeInInfer,
Rank, Rank,
2, // MPerthread 2, // MPerthread
...@@ -53,14 +59,18 @@ int bnorm_infer( ...@@ -53,14 +59,18 @@ int bnorm_infer(
ck::Sequence<1>>; // scalarPerVector: y ck::Sequence<1>>; // scalarPerVector: y
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims); auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0}; std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};
int i = 0; int i = 0;
for(auto dim : invariantDims) for(auto dim : invariantDims)
{ {
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]); assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i]; aligned_bnScaleStrides[dim] = bnScaleStrides[i];
aligned_bnBiasStrides[dim] = bnBiasStrides[i];
aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
i++; i++;
}; };
...@@ -84,10 +94,10 @@ int bnorm_infer( ...@@ -84,10 +94,10 @@ int bnorm_infer(
auto argument_ptr1 = dev_normalize.MakeArgumentPointer( auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
xyLengths, xyLengths,
{xStrides, {xStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnMeanVarStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnMeanVarStrides,
aligned_scaleBiasMeanVarStrides, aligned_bnScaleStrides,
aligned_scaleBiasMeanVarStrides}, aligned_bnBiasStrides},
{yStrides}, {yStrides},
{p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias}, {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
{p_y}, {p_y},
...@@ -105,8 +115,10 @@ int bnorm_infer( ...@@ -105,8 +115,10 @@ int bnorm_infer(
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel}); avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) + num_bytes += total_length * sizeof(XDataType) +
total_length * sizeof(InOutDataType)); invariantLength *
(sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
total_length * sizeof(YDataType);
if(time_kernel) if(time_kernel)
{ {
......
...@@ -34,15 +34,15 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con ...@@ -34,15 +34,15 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), return HostTensorDescriptor({row, col}, {stride, 1_uz});
std::vector<std::size_t>({stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), return HostTensorDescriptor({row, col}, {1_uz, stride});
std::vector<std::size_t>({1, stride}));
} }
}; };
...@@ -146,15 +146,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con ...@@ -146,15 +146,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if(std::is_same<CDataType, ck::half_t>::value) if(std::is_same<CDataType, ck::half_t>::value)
{ {
pass &= ck::utils::check_err(c_m_n_device_result.mData, pass &= ck::utils::check_err(
c_m_n_host_result.mData, c_m_n_device_result, c_m_n_host_result, "fp16 incorrect result", 3e-3, 1e-3);
"fp16 incorrect result",
3e-3,
1e-3);
} }
else else
{ {
pass &= ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
} }
} }
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
......
...@@ -9,7 +9,8 @@ ...@@ -9,7 +9,8 @@
#include <ctime> #include <ctime>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_sparse_embedding3_forward_layernorm.hpp" #include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -18,53 +19,26 @@ ...@@ -18,53 +19,26 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
// using EmbType = float; // clang-format off
// using IndexType = int64_t;
// using GammaDataType = float;
// using BetaDataType = float;
// using AccDataType = float;
// using OutType = float;
using EmbType = ck::half_t; using EmbType = ck::half_t;
using IndexType = int64_t; using IndexType = int64_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using OutType = ck::half_t; using OutType = ck::half_t;
using EmbElementwiseOperation = ck::tensor_operation::element_wise::AddAdd;
// clang-format off using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 256, 1, 1, 3>;
// BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 512, 1, 2, 3>;
using DeviceInstance_fp32_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>; using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 768, 1, 1, 3>;
using DeviceInstance_fp32_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 1>; using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 1024, 1, 2, 3>;
using DeviceInstance_fp32_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>; using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 1536, 1, 2, 3>;
using DeviceInstance_fp32_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 1>; using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 2048, 1, 2, 3>;
using DeviceInstance_fp32_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 1>; using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 4096, 1, 8, 3>;
using DeviceInstance_fp32_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 4>; using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbeddingsForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, EmbElementwiseOperation, 256, 1, 256, 1, 8192, 1, 8, 3>;
using DeviceInstance_fp32_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 4>;
using DeviceInstance_fp32_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 4>;
using DeviceInstance_fp32_e16384 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 16384, 1, 4>;
using DeviceInstance_fp16_e256 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 256, 1, 1>;
using DeviceInstance_fp16_e512 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 512, 1, 2>;
using DeviceInstance_fp16_e768 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 768, 1, 1>;
using DeviceInstance_fp16_e1024 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1024, 1, 2>;
using DeviceInstance_fp16_e1536 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 1536, 1, 2>;
using DeviceInstance_fp16_e2048 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 2048, 1, 2>;
using DeviceInstance_fp16_e4096 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 4096, 1, 8>;
using DeviceInstance_fp16_e8192 = ck::tensor_operation::device::DeviceSparseEmbedding3ForwardLayernorm<EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, 256, 1, 256, 1, 8192, 1, 8>;
template<typename emb_type, ck::index_t dim> struct emb_kernel{}; template<typename emb_type, ck::index_t dim> struct emb_kernel{};
template<> struct emb_kernel<float, 256> { using kernel_type = DeviceInstance_fp32_e256; };
template<> struct emb_kernel<float, 512> { using kernel_type = DeviceInstance_fp32_e512; };
template<> struct emb_kernel<float, 768> { using kernel_type = DeviceInstance_fp32_e768; };
template<> struct emb_kernel<float, 1024> { using kernel_type = DeviceInstance_fp32_e1024;};
template<> struct emb_kernel<float, 1536> { using kernel_type = DeviceInstance_fp32_e1536;};
template<> struct emb_kernel<float, 2048> { using kernel_type = DeviceInstance_fp32_e2048;};
template<> struct emb_kernel<float, 4096> { using kernel_type = DeviceInstance_fp32_e4096;};
template<> struct emb_kernel<float, 8192> { using kernel_type = DeviceInstance_fp32_e8192;};
template<> struct emb_kernel<float, 16384>{ using kernel_type = DeviceInstance_fp32_e16384;};
template<> struct emb_kernel<ck::half_t, 256> { using kernel_type = DeviceInstance_fp16_e256; }; template<> struct emb_kernel<ck::half_t, 256> { using kernel_type = DeviceInstance_fp16_e256; };
template<> struct emb_kernel<ck::half_t, 512> { using kernel_type = DeviceInstance_fp16_e512; }; template<> struct emb_kernel<ck::half_t, 512> { using kernel_type = DeviceInstance_fp16_e512; };
template<> struct emb_kernel<ck::half_t, 768> { using kernel_type = DeviceInstance_fp16_e768; }; template<> struct emb_kernel<ck::half_t, 768> { using kernel_type = DeviceInstance_fp16_e768; };
...@@ -86,12 +60,10 @@ int main() ...@@ -86,12 +60,10 @@ int main()
constexpr auto index_length = 2048; constexpr auto index_length = 2048;
constexpr AccDataType epsilon = 1e-4; constexpr AccDataType epsilon = 1e-4;
auto f_host_tensor_desc_1d = [](std::size_t len_) { auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); };
return HostTensorDescriptor(std::vector<std::size_t>({len_}));
};
auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) { auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) {
return HostTensorDescriptor(std::vector<std::size_t>({rows_, cols_})); return HostTensorDescriptor({rows_, cols_});
}; };
using ReferenceInstance = using ReferenceInstance =
...@@ -154,19 +126,20 @@ int main() ...@@ -154,19 +126,20 @@ int main()
beta_dev.ToDevice(beta.mData.data()); beta_dev.ToDevice(beta.mData.data());
auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{}; auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(), auto argument_ptr = device_instance.MakeArgumentPointer(
emb_a_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
emb_b_dev.GetDeviceBuffer(), {ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
emb_c_dev.GetDeviceBuffer(), ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
index_a_dev.GetDeviceBuffer(), ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
index_b_dev.GetDeviceBuffer(), {ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
index_c_dev.GetDeviceBuffer(), ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
gamma_dev.GetDeviceBuffer(), ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
beta_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
num_rows, beta_dev.GetDeviceBuffer(),
current_dim, current_dim,
index_length, index_length,
epsilon); epsilon,
EmbElementwiseOperation{});
std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
<< std::endl << std::endl
<< std::flush; << std::flush;
...@@ -203,8 +176,7 @@ int main() ...@@ -203,8 +176,7 @@ int main()
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_dev.FromDevice(out_from_dev.mData.data()); out_dev.FromDevice(out_from_dev.mData.data());
pass &= ck::utils::check_err( pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3);
out_from_dev.mData, out.mData, "Error: Incorrect results", 1e-3, 1e-3);
} }
double total_read = current_dim * index_length * 3 * sizeof(EmbType) + double total_read = current_dim * index_length * 3 * sizeof(EmbType) +
......
...@@ -12,13 +12,14 @@ Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1 ...@@ -12,13 +12,14 @@ Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -314,15 +315,15 @@ int main(int argc, char* argv[]) ...@@ -314,15 +315,15 @@ int main(int argc, char* argv[])
std::size_t stride, std::size_t stride,
std::size_t batch_stride, std::size_t batch_stride,
auto layout) { auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), Row>::value) if(std::is_same<decltype(layout), Row>::value)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
std::vector<std::size_t>({batch_stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
std::vector<std::size_t>({batch_stride, 1, stride}));
} }
}; };
...@@ -511,8 +512,7 @@ int main(int argc, char* argv[]) ...@@ -511,8 +512,7 @@ int main(int argc, char* argv[])
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx)); cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
}); });
return ck::utils::check_err(e1_g_m_o_device_result.mData, e1_g_m_o_host_result.mData) ? 0 return ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result) ? 0 : 1;
: 1;
} }
return 0; return 0;
......
add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment