"torchvision/vscode:/vscode.git/clone" did not exist on "de31e4b8bf9b4a7e0668d19059a5ac4760dceee1"
Unverified commit 7bcaf2a7 authored by Adam Osewski, committed by GitHub

Merge branch 'develop' into wavelet_model

parents e59daa22 0345963e
add_example_executable(example_gemm_xdl_relu_quantization_int8 gemm_xdl_relu_quantization_int8.cpp)
\ No newline at end of file
@@ -11,7 +11,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"

static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
@@ -106,7 +106,7 @@ class BatchNormBwdArg
using namespace ck;

-template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
+template <typename XDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_bwd_nhwc_test(bool do_verification,
                         int init_method,
                         bool time_kernel,
@@ -118,13 +118,15 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    constexpr index_t Rank         = 4;
    constexpr index_t NumReduceDim = 3;

+   using ScaleDataType = XDataType;
+
    const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};

    // input data of the batchnorm backward algorithm
-   Tensor<InOutDataType> x(inOutLengths);
-   Tensor<InOutDataType> dy(inOutLengths);
-   Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
+   Tensor<XDataType> x(inOutLengths);
+   Tensor<AccDataType> dy(inOutLengths);
+   Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
    Tensor<AccDataType> savedMean(scaleBiasMeanVarLengths);
    Tensor<AccDataType> savedInvVar(scaleBiasMeanVarLengths);
@@ -132,8 +134,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    Tensor<AccDataType> savedVariance(scaleBiasMeanVarLengths);

    // output data of the batchnorm backward algorithm
-   Tensor<InOutDataType> dx_ref(inOutLengths);
-   Tensor<InOutDataType> dx(inOutLengths);
+   Tensor<AccDataType> dx_ref(inOutLengths);
+   Tensor<AccDataType> dx(inOutLengths);
    Tensor<AccDataType> dscale(scaleBiasMeanVarLengths);
    Tensor<AccDataType> dbias(scaleBiasMeanVarLengths);
@@ -153,7 +155,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
        const float noise_stddev = 0.0001f;

        // input data in normal distribution
-       x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
+       x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);

        // initialize the savedMean to be values with tiny variation to the mean of the x values
        savedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
@@ -182,7 +184,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
        const float x_stddev = 1.0f;

        // input data in normal distribution
-       x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
+       x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
    };

    if(do_verification)
@@ -190,34 +192,34 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
        switch(init_method)
        {
        case 0:
-           dy.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
            break;
        case 1:
-           dy.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
            break;
        case 2:
-           dy.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-2, 2}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
            break;
        default:
-           dy.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.2f, 0.2f}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.5f, 0.5f}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-0.2f, 0.2f}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
        }
    };

    // input data of the batchnorm backward algorithm
-   DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
-   DeviceMem dy_dev(sizeof(InOutDataType) * dy.mDesc.GetElementSpaceSize());
-   DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
+   DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+   DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize());
+   DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
    DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize());
    DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize());

    // output data of the batchnorm backward algorithm
-   DeviceMem dx_dev(sizeof(InOutDataType) * dx.mDesc.GetElementSpaceSize());
+   DeviceMem dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize());
    DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize());
    DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize());
@@ -249,12 +251,12 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

    using DeviceBatchNormBwdInstance =
-       ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType,
-                                                            InOutDataType,
-                                                            InOutDataType,
+       ck::tensor_operation::device::DeviceBatchNormBwdImpl<XDataType,
+                                                            AccDataType,
+                                                            AccDataType,
                                                             AccDataType,
-                                                            AccDataType, // ScaleDataType
-                                                            AccDataType, // BiasDataType
+                                                            ScaleDataType, // ScaleDataType
+                                                            AccDataType,   // DscaleDbiasDataType
                                                             AccDataType, // MeanVarDataType
                                                             PassThroughOp,
                                                             Rank,
@@ -269,8 +271,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
                                                             1, // XSrcVectorSize
                                                             1, // DySrcVectorSize
                                                             1, // DxDstVectorSize
-                                                            1, // ScaleSrcDstVectorSize
-                                                            1, // BiasDstVectorSize
+                                                            1, // ScaleSrcVectorSize
+                                                            1, // DscaleDbiasDstVectorSize
                                                             1>; // MeanVarSrcVectorSize

    auto batchnorm_bwd = DeviceBatchNormBwdInstance{};
@@ -324,7 +326,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    // inputing of x, dy, scale, outputing of dx, dscale, dbias
    num_bytes +=
-       total_length * sizeof(InOutDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
+       total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;

    // outputing of mean, inv-variance
    num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0;
@@ -341,14 +343,16 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    if(do_verification)
    {
        using ReferenceBatchNormBwdInstance =
-           ck::tensor_operation::host::ReferenceBatchNormBwd_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                                    InOutDataType,
-                                                                                    InOutDataType,
-                                                                                    AccDataType,
-                                                                                    AccDataType,
-                                                                                    AccDataType,
-                                                                                    AccDataType,
-                                                                                    PassThroughOp>;
+           ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
+                                                             AccDataType,
+                                                             AccDataType,
+                                                             AccDataType,
+                                                             ScaleDataType, // ScaleDataType
+                                                             AccDataType,
+                                                             AccDataType,
+                                                             PassThroughOp,
+                                                             Rank,
+                                                             NumReduceDim>;

        auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
@@ -390,8 +394,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
        dbias_dev.FromDevice(dbias.data());

        // clang-format off
-       pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 1e-5, 1e-5);
-       pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 1e-5, 2e-4);
+       pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4);
+       pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4);
        pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:");
        // clang-format on
    };
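The reference instance verifies the standard batchnorm backward reduction over the (N, H, W) dimensions. A minimal host-side sketch of that math, assuming fp32 NHWC tensors with the channel as the fastest-varying index (the library's reference class exposes a richer, layout-generic interface):

    #include <cstddef>
    #include <vector>

    // Batchnorm backward with precomputed mean and inverse variance
    // (invVar = 1 / sqrt(variance + epsilon)); dscale and dbias are assumed
    // zero-initialized, and reduce_size = N * H * W.
    void bnorm_bwd_nhwc_ref(const std::vector<float>& x, const std::vector<float>& dy,
                            const std::vector<float>& scale, const std::vector<float>& mean,
                            const std::vector<float>& invVar, std::size_t C,
                            std::vector<float>& dx, std::vector<float>& dscale,
                            std::vector<float>& dbias)
    {
        const float reduce_size = static_cast<float>(x.size() / C);
        for(std::size_t i = 0; i < x.size(); ++i)
        {
            const std::size_t c = i % C; // NHWC: channel varies fastest
            dbias[c] += dy[i];
            dscale[c] += dy[i] * (x[i] - mean[c]) * invVar[c];
        }
        for(std::size_t i = 0; i < x.size(); ++i)
        {
            const std::size_t c = i % C;
            dx[i] = scale[c] * invVar[c] *
                    (dy[i] - (dbias[c] + (x[i] - mean[c]) * invVar[c] * dscale[c]) / reduce_size);
        }
    }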
+add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
using BiasDataType = int32_t;
using RequantScaleDataType = float;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using OutDataType = int8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using ActivationOp = ck::tensor_operation::element_wise::Relu;
using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<ActivationOp>;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename BiasLayout,
typename RequantScaleLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout, RequantScaleLayout>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<BiasDataType, RequantScaleDataType>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 64, 1, 4>,
8>;
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& bias_g_k_desc,
const HostTensorDescriptor& requant_scale_g_k_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<BiasDataType> bias(bias_g_k_desc);
Tensor<RequantScaleDataType> requant_scale(requant_scale_g_k_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "bias: " << bias.mDesc << std::endl;
std::cout << "requant_scale: " << requant_scale.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-128, 127});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-128, 127});
bias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-128, 127});
requant_scale.GenerateTensorValue(GeneratorTensor_2<RequantScaleDataType>{0, 1});
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem requant_scale_device_buf(sizeof(RequantScaleDataType) *
requant_scale.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
requant_scale_device_buf.ToDevice(requant_scale.mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(bias_g_k_desc.GetLengths(), d0_g_n_k_wos_lengths);
copy(bias_g_k_desc.GetStrides(), d0_g_n_k_wos_strides);
copy(requant_scale_g_k_desc.GetLengths(), d1_g_n_k_wos_lengths);
copy(requant_scale_g_k_desc.GetStrides(), d1_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(
in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
{bias_device_buf.GetDeviceBuffer(), requant_scale_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
{d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths},
{d0_g_n_k_wos_strides, d1_g_n_k_wos_strides},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
CShuffleDataType,
InElementOp,
WeiElementOp,
PassThrough>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
c_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
// TODO: implement elementwise operation for host
out_host.ForEach([&](auto&, auto idx) {
out_element_op(out_host(idx), c_host(idx), bias(idx), requant_scale(idx));
});
out_device_buf.FromDevice(out_device.mData.data());
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
}
int main()
{
bool do_verification = true;
bool time_kernel = true;
const ck::index_t ndim_spatial = 2;
ck::utils::conv::ConvParam conv_param{
ndim_spatial, // n_dim
1, // group
4, // batch
64, // output channels
32, // input channels
{3, 3}, // weight HW
{71, 71}, // x HW
{2, 2}, // strides
{1, 1}, // dilations
{1, 1}, // left_pads
{1, 1} // right_pads
};
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using BiasLayout = ck::tensor_layout::convolution::G_K;
using RequantScaleLayout = ck::tensor_layout::convolution::G_K;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
// TODO - make_bias_host_tensor_descriptor_g_n_k_wos_packed()
const auto bias_g_k_desc = HostTensorDescriptor({conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1]},
{
conv_param.K_, // g
0, // n
1, // k
0, // ho
0 // wo
});
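// Note: the zero strides above broadcast the per-channel (G, K) bias across
// the n, ho and wo dimensions of the (G, N, K, Ho, Wo) output view.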
const auto requant_scale_g_k_desc = bias_g_k_desc;
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
std::cout << out_g_n_k_wos_desc << std::endl;
using deviceOp = DeviceGroupedConvNDFwdInstance<ndim_spatial,
InLayout,
WeiLayout,
BiasLayout,
RequantScaleLayout,
OutLayout>;
return run_grouped_conv_fwd<ndim_spatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
deviceOp>(do_verification,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
bias_g_k_desc,
requant_scale_g_k_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
}
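For reference, the fused OutElementOp used above amounts to the following host-side math per output element. This standalone sketch assumes the semantics suggested by the op's name and by the verification lambda (add the int32 bias, apply ReLU, multiply by the per-channel float scale, clamp to int8); the body of Add_Activation_Mul2_Clamp is not part of this commit:

    #include <algorithm>
    #include <cstdint>

    // Host sketch of the int8 per-channel requantization epilogue; the exact
    // rounding behavior of the library op is an assumption here.
    inline std::int8_t requant_bias_relu(std::int32_t acc, std::int32_t bias, float scale)
    {
        float v = static_cast<float>(acc + bias);         // int32 accumulator + bias
        v       = std::max(v, 0.0f);                      // ReLU activation
        v       = v * scale;                              // per-channel requantization
        v       = std::min(std::max(v, -128.0f), 127.0f); // clamp to int8 range
        return static_cast<std::int8_t>(v);
    }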
@@ -11,6 +11,7 @@
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
@@ -163,17 +164,16 @@ bool run_grouped_conv_fwd(bool do_verification,
    // do Conv
    auto conv    = DeviceConvNDFwdInstance{};
    auto invoker = conv.MakeInvoker();
-   auto argument = conv.MakeArgument(
-       in_device_buf.GetDeviceBuffer(),
+   auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
                                      wei_device_buf.GetDeviceBuffer(),
-                                     std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
+                                     {bias_device_buf.GetDeviceBuffer()},
                                      out_device_buf.GetDeviceBuffer(),
                                      a_g_n_c_wis_lengths,
                                      a_g_n_c_wis_strides,
                                      b_g_k_c_xs_lengths,
                                      b_g_k_c_xs_strides,
-                                     std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_lengths}},
-                                     std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_strides}},
+                                     {d0_g_n_k_wos_lengths},
+                                     {d0_g_n_k_wos_strides},
                                      e_g_n_k_wos_lengths,
                                      e_g_n_k_wos_strides,
                                      conv_filter_strides,
@@ -235,8 +235,8 @@ bool run_grouped_conv_fwd(bool do_verification,
        out_device_buf.FromDevice(out_device.mData.data());

-       pass &= ck::utils::check_err(
-           out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+       pass &=
+           ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
    }

    return (pass ? 0 : 1);
@@ -11,6 +11,7 @@
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
@@ -150,14 +151,14 @@ bool run_grouped_conv_fwd(bool do_verification,
    auto invoker  = conv.MakeInvoker();
    auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
                                      wei_device_buf.GetDeviceBuffer(),
-                                     std::array<const void*, 0>{},
+                                     {},
                                      out_device_buf.GetDeviceBuffer(),
                                      a_g_n_c_wis_lengths,
                                      a_g_n_c_wis_strides,
                                      b_g_k_c_xs_lengths,
                                      b_g_k_c_xs_strides,
-                                     std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
-                                     std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
+                                     {},
+                                     {},
                                      e_g_n_k_wos_lengths,
                                      e_g_n_k_wos_strides,
                                      conv_filter_strides,
@@ -213,8 +214,8 @@ bool run_grouped_conv_fwd(bool do_verification,
        out_device_buf.FromDevice(out_device.mData.data());

-       pass &= ck::utils::check_err(
-           out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+       pass &=
+           ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
    }

    return (pass ? 0 : 1);
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
+add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_2d.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using ADataType = F16;
using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
ck::Tuple<BDataType>,
PassThrough,
3, // NumDim_M
1, // NumDim_N
8, // MPerThread
8, // NPerThread
ck::Sequence<8>,  // InScalarPerVectorSeq
ck::Sequence<8>>; // OutScalarPerVectorSeq
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
const std::vector<std::size_t>& shape_nchw,
Functor functor)
{
for(std::size_t n = 0; n < shape_nchw[0]; ++n)
for(std::size_t c = 0; c < shape_nchw[1]; ++c)
for(std::size_t h = 0; h < shape_nchw[2]; ++h)
for(std::size_t w = 0; w < shape_nchw[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
bool time_kernel = true;
const int N = 120;
const int C = 128;
const int H = 32;
const int W = 1024;
/**const int N = 120;
const int H = 32;
const int W = 64;
const int C = 128;**/
std::vector<std::size_t> nchw = {N, C, H, W};
std::vector<std::size_t> nhwc = {N, H, W, C};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
// LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl;
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
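// Both tensors use the logical NHWC lengths {N, H, W, C}; the NCHW source is
// described purely through its strides (n, h, w, c) -> (C*H*W, W, 1, H*W), so a
// strided elementwise copy performs the layout permutation.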
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
std::cout << "A (nchw): " << a.mDesc << std::endl;
std::cout << "B (nhwc): " << b.mDesc << std::endl;
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
// LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl;
Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nchw, PassThrough{});
// LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}
@@ -30,7 +30,7 @@
// check GPU target
#ifdef __HIP_DEVICE_COMPILE__
#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-     defined(__gfx90a__) || defined(__gfx1030__))
+     defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
#error Not supported target
#endif
#endif
@@ -43,6 +43,8 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
+#elif defined(__gfx1100__) // for GPU code
+#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
#endif

// FMA instruction
@@ -67,6 +69,13 @@
#define CK_USE_AMD_MFMA_BF16_1K_OP
#endif

+// WMMA instruction
+#ifndef __HIP_DEVICE_COMPILE__ // for host code
+#define CK_USE_AMD_WMMA
+#elif defined(__gfx1100__) // for GPU code
+#define CK_USE_AMD_WMMA
+#endif
+
// buffer load
#define CK_USE_AMD_BUFFER_LOAD 1
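The new CK_USE_AMD_WMMA macro mirrors the existing MFMA guards: it is defined unconditionally for host compilation and only for gfx1100 in device passes. A sketch of the gating pattern it enables (the guarded code itself is not part of this diff):

    #ifdef CK_USE_AMD_WMMA
    // gfx11 WMMA-based code path
    #else
    // fallback for targets without WMMA support
    #endif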
@@ -5,6 +5,7 @@
#include <cmath>
#include <string>
+#include <sstream>

#include "ck/stream_config.hpp"
@@ -46,6 +47,17 @@ struct BaseOperator
    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
    virtual std::string GetTypeString() const { return ""; }

+   virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
+
+   virtual std::string GetTypeIdHashCode() const
+   {
+       std::ostringstream oss;
+
+       oss << std::hex << typeid(*this).hash_code();
+
+       return oss.str();
+   };
+
    virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
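One possible use of the two new identity helpers is logging which concrete operator instance was selected; this is an illustrative sketch, not code from the commit:

    #include <iostream>

    // 'op' is any concrete device operator held through the base class.
    void log_op_identity(const ck::tensor_operation::device::BaseOperator& op)
    {
        // implementation-defined type name plus the hex hash of its type_info
        std::cout << op.GetTypeIdName() << " (" << op.GetTypeIdHashCode() << ")" << std::endl;
    }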
@@ -13,7 +13,16 @@ namespace ck {
namespace tensor_operation {
namespace device {

-template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
+template <typename XDataType,
+          typename DxDataType,
+          typename DyDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename DscaleDbiasDataType,
+          typename MeanVarDataType,
+          typename DyElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
struct DeviceBatchNormBwd : public BaseOperator
{
    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
@@ -26,7 +35,7 @@ struct DeviceBatchNormBwd : public BaseOperator
        const std::array<int, NumBatchNormReduceDim> reduceDims,
        const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
        const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-       const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+       const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
        const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
        const void* p_x,
        const void* p_dy,
@@ -42,9 +51,26 @@ struct DeviceBatchNormBwd : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

-template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
-using DeviceBatchNormBwdPtr =
-    std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
+template <typename XDataType,
+          typename DxDataType,
+          typename DyDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename DscaleDbiasDataType,
+          typename MeanVarDataType,
+          typename DyElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
+                                                                 DxDataType,
+                                                                 DyDataType,
+                                                                 AccDataType,
+                                                                 ScaleDataType,
+                                                                 DscaleDbiasDataType,
+                                                                 MeanVarDataType,
+                                                                 DyElementwiseOp,
+                                                                 Rank,
+                                                                 NumBatchNormReduceDim>>;

} // namespace device
} // namespace tensor_operation
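With the widened parameter list, a concrete pointer type now spells out every data type. For example, mirroring the NHWC example earlier in this commit (the fp16/fp32 choices are illustrative assumptions):

    using BnormBwdFp16Ptr = ck::tensor_operation::device::DeviceBatchNormBwdPtr<
        ck::half_t,                                      // XDataType
        float,                                           // DxDataType
        float,                                           // DyDataType
        float,                                           // AccDataType
        ck::half_t,                                      // ScaleDataType
        float,                                           // DscaleDbiasDataType
        float,                                           // MeanVarDataType
        ck::tensor_operation::element_wise::PassThrough, // DyElementwiseOp
        4,                                               // Rank (NHWC)
        3>;                                              // NumBatchNormReduceDim (N, H, W)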
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m,
index_t NumDim_n,
index_t MPerThread,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n;
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mn.GetLength(I0);
const auto n = desc_mn.GetLength(I1);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
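// e.g. m = 1000 with num_threads_m = 64 and MPerThread = 8 gives
// loop_step_m = 512, so m is padded by 24 up to 1024.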
const auto desc_mn_pad = transform_tensor_descriptor(
desc_mn,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return desc_mn_pad;
}
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
// merge nd to 2d desc - [s0 * s1 * ...]
if constexpr(NumDim > 2)
{
const auto desc_mn = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
}
else
return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
}
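// With NumDim_m = 3 and NumDim_n = 1 (the NCHW -> NHWC example above),
// MakeDescriptor_MN merges lengths {N, H, W, C} into a 2-D problem of
// size (M, N) = (N * H * W, C) before padding.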
template <index_t TupleSize>
static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 2)
{
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
};
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256),
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
num_threads_m_((gridSize_ * blockSize_) / 16),
num_threads_n_(16)
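// 120 blocks * 256 threads = 30720 threads, viewed as a logical
// 1920 x 16 (m, n) thread grid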
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
in_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(lengths,
inStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
},
Number<NumInput>{});
out_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(lengths,
outStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
InGrid2dDescTuple in_grid_2d_desc_tuple_;
OutGrid2dDescTuple out_grid_2d_desc_tuple_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
index_t gridSize_;
index_t num_threads_m_;
index_t num_threads_n_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.in_grid_2d_desc_tuple_,
arg.out_grid_2d_desc_tuple_,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
arg.num_threads_m_,
arg.num_threads_n_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{
return true;
}
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1))
valid = false;
});
return valid;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
}; // struct DeviceElementwise
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
-             c_element_op_{c_element_op}
+             c_element_op_{c_element_op},
+             kraw_{K}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
                                           b_grid_desc_k0_n_k1_,
@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
+       index_t kraw_;
    };

    // Invoker
@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
+#if 0
            {
                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }
+#endif

            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                            arg.b_grid_desc_k0_n_k1_,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(arg.kraw_ % K1 != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
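The new early-out above guards the K0/K1 split: the instance's K0-M-K1 grid descriptor requires the raw K to factor exactly into K0 * K1. A worked check, with an illustrative K1:

    constexpr ck::index_t K1 = 8; // the instance's K1 (illustrative value)
    // K = 512 -> 512 % 8 == 0, K0 = 64: argument is supported
    // K = 100 -> 100 % 8 == 4: IsSupportedArgument now returns false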
@@ -27,7 +27,7 @@ template <typename XDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
-         typename BiasDataType,
+         typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          index_t Rank,
@@ -42,11 +42,19 @@ template <typename XDataType,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t DxDstVectorSize,
-         index_t ScaleSrcDstVectorSize,
-         index_t BiasDstVectorSize,
+         index_t ScaleSrcVectorSize,
+         index_t DscaleDbiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
-struct DeviceBatchNormBwdImpl
-    : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
+struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
+                                                          DxDataType,
+                                                          DyDataType,
+                                                          AccDataType,
+                                                          ScaleDataType,
+                                                          DscaleDbiasDataType,
+                                                          MeanVarDataType,
+                                                          DyElementwiseOp,
+                                                          Rank,
+                                                          NumBatchNormReduceDim>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                 const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-                const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+                const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
                 const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const DyDataType* p_dy,
@@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl
                 const DyElementwiseOp dy_elementwise_op,
                 double epsilon,
                 DxDataType* p_dx,
-                ScaleDataType* p_dscale,
-                BiasDataType* p_dbias)
+                DscaleDbiasDataType* p_dscale,
+                DscaleDbiasDataType* p_dbias)
            : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
-             bnBiasStrides_(bnBiasStrides),
+             bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              p_dy_(p_dy),
@@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl
                MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
            scale_grid_desc_m =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
-           bias_grid_desc_m =
-               MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
+           dscale_dbias_grid_desc_m =
+               MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
            mean_var_grid_desc_m =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
        }
@@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl
        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
-       std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
+       std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;

        const XDataType* p_x_;
@@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl
        const MeanVarDataType* p_savedInvVar_;
        const DyElementwiseOp dy_elementwise_op_;
        DxDataType* p_dx_;
-       ScaleDataType* p_dscale_;
-       BiasDataType* p_dbias_;
+       DscaleDbiasDataType* p_dscale_;
+       DscaleDbiasDataType* p_dbias_;

        long_index_t invariant_length;
        long_index_t reduce_length;
@@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl
        XYGridDesc_M_K dy_grid_desc_m_k;
        XYGridDesc_M_K dx_grid_desc_m_k;
        ScaleBiasGridDesc_M scale_grid_desc_m;
-       ScaleBiasGridDesc_M bias_grid_desc_m;
+       ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
        MeanVarGridDesc_M mean_var_grid_desc_m;

        void* workspace_mean;
@@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl
        {
            // workspace for the partial reduced result for dscale
            workspace_size +=
-               pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
+               pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;

            // workspace for the partial reduced result for dbias
            workspace_size +=
-               pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
+               pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;

            if(!pArg_->haveSavedMeanInvVar_)
            {
@@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl
            // setup buffer for the partial reduced result for dscale
            pArg_->workspace_reduce_dscale = pArg_->p_workspace_;

-           space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
+           space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
            space_sz = math::integer_least_multiple(space_sz, 64);

            // setup buffer for the partial reduced result for dbias
@@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl
            if(UseMultiblockInK && pArg_->blkGroupSize > 1)
            {
-               space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
+               space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
                space_sz = math::integer_least_multiple(space_sz, 64);

                // setup buffer for welford intermediate mean
...@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl ...@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl
DyDataType, DyDataType,
AccDataType, AccDataType,
ScaleDataType, ScaleDataType,
BiasDataType, DscaleDbiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp, DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
...@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl ...@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl
DxDataType, DxDataType,
AccDataType, AccDataType,
ScaleDataType, ScaleDataType,
BiasDataType, DscaleDbiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp, DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
...@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl ...@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl
XSrcVectorSize, XSrcVectorSize,
DySrcVectorSize, DySrcVectorSize,
DxDstVectorSize, DxDstVectorSize,
ScaleSrcDstVectorSize, ScaleSrcVectorSize,
BiasDstVectorSize, DscaleDbiasDstVectorSize,
MeanVarSrcVectorSize>; MeanVarSrcVectorSize>;
if(UseMultiblockInK && arg.blkGroupSize > 1) if(UseMultiblockInK && arg.blkGroupSize > 1)
...@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl ...@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl
DyDataType, DyDataType,
AccDataType, AccDataType,
ScaleDataType, ScaleDataType,
BiasDataType, DscaleDbiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp, DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
...@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl ...@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl
DyDataType, DyDataType,
DxDataType, DxDataType,
ScaleDataType, ScaleDataType,
BiasDataType, DscaleDbiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp, DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
...@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl ...@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar), : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
arg.p_x_, arg.p_x_,
arg.p_dy_, arg.p_dy_,
static_cast<ScaleDataType*>(arg.workspace_reduce_dscale), static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
static_cast<BiasDataType*>(arg.workspace_reduce_dbias)); static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
avg_time += launch_and_time_kernel( avg_time += launch_and_time_kernel(
stream_config, stream_config,
...@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl ...@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl
dscale_dbias_grid_desc_m_k, dscale_dbias_grid_desc_m_k,
arg.mean_var_grid_desc_m, arg.mean_var_grid_desc_m,
arg.scale_grid_desc_m, arg.scale_grid_desc_m,
arg.bias_grid_desc_m, arg.dscale_dbias_grid_desc_m,
arg.blkGroupSize, arg.blkGroupSize,
arg.reduce_length, arg.reduce_length,
arg.numBlockTileIteration, arg.numBlockTileIteration,
numDscaleDbiasBlockTileIteration, numDscaleDbiasBlockTileIteration,
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale), static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias), static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
arg.haveSavedMeanInvVar_ arg.haveSavedMeanInvVar_
? arg.p_savedMean_ ? arg.p_savedMean_
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean), : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
...@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl ...@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl
DxDataType, DxDataType,
AccDataType, AccDataType,
ScaleDataType, ScaleDataType,
BiasDataType, DscaleDbiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp, DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
...@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl ...@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl
XSrcVectorSize, XSrcVectorSize,
DySrcVectorSize, DySrcVectorSize,
DxDstVectorSize, DxDstVectorSize,
ScaleSrcDstVectorSize, ScaleSrcVectorSize,
BiasDstVectorSize, DscaleDbiasDstVectorSize,
MeanVarSrcVectorSize>; MeanVarSrcVectorSize>;
const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford< const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
...@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl ...@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl
DxDataType, DxDataType,
AccDataType, AccDataType,
ScaleDataType, ScaleDataType,
BiasDataType, DscaleDbiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp, DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
...@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl ...@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl
arg.dy_grid_desc_m_k, arg.dy_grid_desc_m_k,
arg.dx_grid_desc_m_k, arg.dx_grid_desc_m_k,
arg.scale_grid_desc_m, arg.scale_grid_desc_m,
arg.bias_grid_desc_m, arg.dscale_dbias_grid_desc_m,
arg.mean_var_grid_desc_m, arg.mean_var_grid_desc_m,
get_reduce_count_per_thread, get_reduce_count_per_thread,
arg.reduce_length, arg.reduce_length,
...@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl ...@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl
return false; return false;
}; };
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1) if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
return false; return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1) if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
return false; return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0) if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
return false; return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0) if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
return false; return false;
if(pArg_->haveSavedMeanInvVar_) if(pArg_->haveSavedMeanInvVar_)
...@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl ...@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl
const std::array<int, NumBatchNormReduceDim> reduceDims, const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths, const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides, const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides, const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides, const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x, const void* p_x,
const void* p_dy, const void* p_dy,
...@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl ...@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl
reduceDims, reduceDims,
bnScaleBiasMeanVarLengths, bnScaleBiasMeanVarLengths,
bnScaleStrides, bnScaleStrides,
bnBiasStrides, bnDscaleDbiasStrides,
bnMeanVarStrides, bnMeanVarStrides,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy), static_cast<const DyDataType*>(p_dy),
...@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl ...@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl
dy_elementwise_op, dy_elementwise_op,
epsilon, epsilon,
static_cast<DxDataType*>(p_dx), static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale), static_cast<DscaleDbiasDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias)); static_cast<DscaleDbiasDataType*>(p_dbias));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl ...@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XDyDxVectorDim_" << XDyDxVectorDim << ","; str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">"; str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
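A note on the batch-norm backward renames above: dscale/dbias are outputs with their own element type (DscaleDbiasDataType), distinct from the ScaleDataType input gamma. For orientation, the two reductions whose blockwise partial results the workspace buffers hold are the textbook batch-norm gradients, per invariant channel $c$, with $\hat{x}_i = (x_i - \mu_c)/\sqrt{\sigma_c^2 + \epsilon}$ and $N$ the reduce length:

$$
d\gamma_c = \sum_{i=1}^{N} dy_i \, \hat{x}_i, \qquad
d\beta_c = \sum_{i=1}^{N} dy_i, \qquad
dx_i = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}} \left( dy_i - \frac{d\beta_c}{N} - \hat{x}_i \, \frac{d\gamma_c}{N} \right).
$$

In the multiblock path each block group reduces a slice of $N$, which is why the dscale and dbias workspaces above are sized invariant_length * blkGroupSize elements each.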
@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 float ave_time = 0;
 for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
 {
+#if 0
 {
 std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
 << " ) " << std::endl;
 }
+#endif
 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
 arg.b_grid_desc_k0_n_k1_container_[i],
...
@@ -265,7 +265,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
 N01_{N01},
 a_element_op_{a_element_op},
 b_element_op_{b_element_op},
-c_element_op_{c_element_op}
+c_element_op_{c_element_op},
+kraw_{K}
 {
 a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
 b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
@@ -299,6 +300,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
 AElementwiseOperation a_element_op_;
 BElementwiseOperation b_element_op_;
 CElementwiseOperation c_element_op_;
+index_t kraw_;
 };
 // Invoker
@@ -443,6 +445,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
 return false;
 }
+if(arg.kraw_ % K1 != 0)
+{
+    return false;
+}
 return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
 arg.b_grid_desc_k0_n_k1_,
 arg.c_grid_desc_m_n_,
...
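The DeviceGemmXdl change above records the raw (unpadded) K alongside the grid descriptors so that IsSupportedArgument can reject sizes the K0*K1 split cannot represent: K is consumed as [K0, K1] with K1 the contiguous vector-access width, so a raw K that is not a multiple of K1 cannot be tiled. A minimal standalone sketch of that guard (K1 = 8 is an arbitrary illustrative pack size, not a value taken from the diff):

#include <cassert>

constexpr int K1 = 8; // illustrative vector pack size

// Mirrors the kraw_ % K1 check: K must split exactly into K0 * K1 tiles.
bool k_is_supported(int kraw) { return kraw % K1 == 0; }

int main()
{
    assert(k_is_supported(64));  // 64 = 8 * K1 -> representable with K0 = 8
    assert(!k_is_supported(60)); // 60 % 8 != 0 -> instance must be rejected
}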
@@ -422,7 +422,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
 block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
 a_element_op_{a_element_op},
 b_element_op_{b_element_op},
-c_element_op_{c_element_op}
+c_element_op_{c_element_op},
+kraw_{KRaw}
 {
 if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
 b_grid_desc_bk0_n_bk1_,
@@ -448,6 +449,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
 AElementwiseOperation a_element_op_;
 BElementwiseOperation b_element_op_;
 CElementwiseOperation c_element_op_;
+index_t kraw_;
 };
 // Invoker
@@ -577,6 +579,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
 return false;
 }
+if((arg.kraw_ % AK1 != 0 || arg.kraw_ % BK1 != 0) &&
+   !(GemmSpec == GemmSpecialization::MKPadding ||
+     GemmSpec == GemmSpecialization::NKPadding ||
+     GemmSpec == GemmSpecialization::MNKPadding ||
+     GemmSpec == GemmSpecialization::KPadding))
+{
+    return false;
+}
 return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
 arg.b_grid_desc_bk0_n_bk1_,
 arg.c_grid_desc_m_n_,
...
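Unlike the unconditional guard in DeviceGemmXdl, the CShuffle variant only rejects a non-divisible raw K when no K-padding specialization is selected, since padding rounds K up before the AK0/AK1 and BK0/BK1 splits. A hedged sketch of that predicate (the enum mirrors the GemmSpecialization values used in the check; the function name is illustrative):

enum class GemmSpec { Default, KPadding, MKPadding, NKPadding, MNKPadding };

// Reject only when K is not vector-divisible AND no specialization pads K.
bool k_check(int kraw, int ak1, int bk1, GemmSpec spec)
{
    const bool divisible = (kraw % ak1 == 0) && (kraw % bk1 == 0);
    const bool pads_k    = spec == GemmSpec::KPadding || spec == GemmSpec::MKPadding ||
                           spec == GemmSpec::NKPadding || spec == GemmSpec::MNKPadding;
    return divisible || pads_k;
}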
@@ -373,12 +373,20 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
 gemm_desc_kernel_arg_.reserve(group_count_);
+skipped_group_count_ = 0;
 for(std::size_t i = 0; i < gemm_descs.size(); i++)
 {
 const index_t M = gemm_descs[i].M_;
 const index_t N = gemm_descs[i].N_;
 const index_t K = gemm_descs[i].K_;
+if(M == 0)
+{
+    skipped_group_count_++;
+    continue;
+}
 const index_t StrideA = gemm_descs[i].stride_A_;
 const index_t StrideB = gemm_descs[i].stride_B_;
 const index_t StrideC = gemm_descs[i].stride_C_;
@@ -470,6 +478,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
 // private:
 index_t group_count_;
+index_t skipped_group_count_;
 AElementwiseOperation a_element_op_;
 BElementwiseOperation b_element_op_;
 CDEElementwiseOperation c_element_op_;
@@ -581,7 +591,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
 static bool IsSupportedArgument(const Argument& arg)
 {
-if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
+if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
+    arg.skipped_group_count_) != arg.group_count_)
 {
 return false;
 }
...
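The grouped-GEMM change lets callers pass degenerate groups (M == 0) without failing validation: empty groups are skipped and counted at argument-construction time, and IsSupportedArgument then checks that built kernel arguments plus skipped groups account for every requested group. A minimal sketch of that accounting, with a hypothetical GemmDesc trimmed to the fields used here:

#include <cassert>
#include <vector>

struct GemmDesc { int M, N, K; }; // hypothetical, trimmed for illustration

int main()
{
    const std::vector<GemmDesc> descs = {{128, 64, 32}, {0, 64, 32}, {256, 64, 32}};

    int skipped = 0;
    std::vector<GemmDesc> kernel_args;
    for(const auto& d : descs)
    {
        if(d.M == 0) { skipped++; continue; } // empty group: nothing to launch
        kernel_args.push_back(d);
    }

    // the invariant the updated IsSupportedArgument enforces
    assert(static_cast<int>(kernel_args.size()) + skipped == static_cast<int>(descs.size()));
}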
@@ -187,6 +187,22 @@ struct AddRelu
 const float a = x0 + type_convert<float>(x1);
 y = a > 0.0f ? a : 0.0f;
 };
+template <>
+__host__ __device__ constexpr void
+operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
+{
+    const int8_t a = x0 + x1;
+    y = a > 0 ? a : 0;
+};
+template <>
+__host__ __device__ constexpr void
+operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
+{
+    const int8_t a = x0 + x1;
+    y = a > 0 ? a : 0;
+};
 };
 struct AddHardswish
...
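The int8 AddRelu specializations mirror the float path, with one behavioral wrinkle worth noting: the sum is assigned to int8_t before the ReLU, so sums outside [-128, 127] wrap (on the usual two's-complement targets) rather than saturate, which is presumably acceptable for the quantized ranges the relu-quantization example exercises. A host-side model of that semantics:

#include <cassert>
#include <cstdint>

// Models the int8 path added above: narrowing happens *before* the clamp.
int8_t add_relu_i8(int8_t x0, int8_t x1)
{
    const int8_t a = static_cast<int8_t>(x0 + x1); // promotion to int, then narrowing
    return a > 0 ? a : 0;
}

int main()
{
    assert(add_relu_i8(3, 4) == 7);     // in-range sum passes through
    assert(add_relu_i8(-5, 2) == 0);    // negative sum clamps to zero
    assert(add_relu_i8(100, 100) == 0); // 200 wraps to -56 as int8_t, then clamps to 0
}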