Commit 00af2988 authored by Po-Yen, Chen

Merge branch 'develop' into feature/restruct-ckprofiler

parents 9a2607d6 ad541ad6
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations)
add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations)
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_operations)
add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
using BiasDataType = int32_t;
using RequantScaleDataType = float;
using OutDataType = int8_t;
using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using BiasLayout = ck::tensor_layout::convolution::G_K;
using RequantScaleLayout = ck::tensor_layout::convolution::G_K;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ActivationOp = ck::tensor_operation::element_wise::Relu;
using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<ActivationOp>;
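// Per-channel quantization epilogue: per output element,
// out = clamp(relu(acc + bias[k]) * requant_scale[k]) saturated to the int8 range.
// "Mul2" denotes multiplication with the second D tensor (the per-channel scale),
// as opposed to the scalar "Mul" used in the per-layer variants below.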
static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 4;
static constexpr ck::index_t K = 64;
static constexpr ck::index_t C = 32;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Hi = 71;
static constexpr ck::index_t Wi = 71;
static constexpr ck::index_t Ho = 36;
static constexpr ck::index_t Wo = 36;
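// Output spatial sizes follow from the usual conv formula:
// Ho = (Hi + pad_l + pad_r - dilation * (Y - 1) - 1) / stride + 1 = (71 + 2 - 3) / 2 + 1 = 36,
// and likewise Wo = 36.

// Minimal RAII wrapper around hipMalloc/hipFree; no error checking, for example use only.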
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
std::array<ck::index_t, 5> in_lengths{G, N, C, Hi, Wi};
std::array<ck::index_t, 5> in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C};
std::array<ck::index_t, 5> weight_lengths{G, K, C, Y, X};
std::array<ck::index_t, 5> weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C};
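// Bias and requant scale are per-output-channel (G_K) tensors broadcast over N, Ho and Wo:
// only the G and K strides are non-zero, so every (n, ho, wo) position reads the same K values.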
std::array<ck::index_t, 5> bias_lengths{G, N, K, Ho, Wo};
std::array<ck::index_t, 5> bias_strides{K, 0, 1, 0, 0};
std::array<ck::index_t, 5> requant_scale_lengths{G, N, K, Ho, Wo};
std::array<ck::index_t, 5> requant_scale_strides{K, 0, 1, 0, 0};
std::array<ck::index_t, 5> out_lengths{G, N, K, Ho, Wo};
std::array<ck::index_t, 5> out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K};
std::array<ck::index_t, 2> in_left_pad{1, 1};
std::array<ck::index_t, 2> in_right_pad{1, 1};
std::array<ck::index_t, 2> conv_strides{2, 2};
std::array<ck::index_t, 2> conv_dilations{1, 1};
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
SimpleDeviceMem bias(sizeof(BiasDataType) * G * K);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout, RequantScaleLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<BiasDataType, RequantScaleDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
weight_lengths,
weight_strides,
{bias_lengths, requant_scale_lengths},
{bias_strides, requant_scale_strides},
out_lengths,
out_strides,
conv_strides,
conv_dilations,
in_left_pad,
in_right_pad,
PassThrough{},
PassThrough{},
OutElementOp{ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
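// 2 ops (one multiply, one add) per MAC; one MAC per (n, k, c, ho, wo, y, x) tuple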
std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X;
std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C +
G * sizeof(WeiDataType) * K * Y * X * C +
G * sizeof(OutDataType) * N * Ho * Wo * K;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
if(best_op_id >= 0)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr =
op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
weight_lengths,
weight_strides,
{bias_lengths, requant_scale_lengths},
{bias_strides, requant_scale_strides},
out_lengths,
out_strides,
conv_strides,
conv_dilations,
in_left_pad,
in_right_pad,
PassThrough{},
PassThrough{},
OutElementOp{ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
@@ -6,7 +6,7 @@
 #include <vector>
 #include "ck/ck.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp"
+#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
using RequantScaleDataType = float;
using OutDataType = int8_t;
using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using RequantScaleLayout = ck::tensor_layout::convolution::G_K;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ActivationOp = PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<ActivationOp>;
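// Per-channel quantization without bias: since ActivationOp is PassThrough, the epilogue
// reduces to out = clamp(acc * requant_scale[k]) saturated to int8.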
static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 4;
static constexpr ck::index_t K = 64;
static constexpr ck::index_t C = 32;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Hi = 71;
static constexpr ck::index_t Wi = 71;
static constexpr ck::index_t Ho = 36;
static constexpr ck::index_t Wo = 36;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
std::array<ck::index_t, 5> in_lengths{G, N, C, Hi, Wi};
std::array<ck::index_t, 5> in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C};
std::array<ck::index_t, 5> weight_lengths{G, K, C, Y, X};
std::array<ck::index_t, 5> weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C};
std::array<ck::index_t, 5> requant_scale_lengths{G, N, K, Ho, Wo};
std::array<ck::index_t, 5> requant_scale_strides{K, 0, 1, 0, 0};
std::array<ck::index_t, 5> out_lengths{G, N, K, Ho, Wo};
std::array<ck::index_t, 5> out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K};
std::array<ck::index_t, 2> in_left_pad{1, 1};
std::array<ck::index_t, 2> in_right_pad{1, 1};
std::array<ck::index_t, 2> conv_strides{2, 2};
std::array<ck::index_t, 2> conv_dilations{1, 1};
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
using DeviceOp =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<RequantScaleLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<RequantScaleDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{requant_scale.GetDeviceBuffer()},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
weight_lengths,
weight_strides,
{requant_scale_lengths},
{requant_scale_strides},
out_lengths,
out_strides,
conv_strides,
conv_dilations,
in_left_pad,
in_right_pad,
PassThrough{},
PassThrough{},
OutElementOp{ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X;
std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C +
G * sizeof(WeiDataType) * K * Y * X * C +
G * sizeof(OutDataType) * N * Ho * Wo * K;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
if(best_op_id >= 0)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{requant_scale.GetDeviceBuffer()},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
weight_lengths,
weight_strides,
{requant_scale_lengths},
{requant_scale_strides},
out_lengths,
out_strides,
conv_strides,
conv_dilations,
in_left_pad,
in_right_pad,
PassThrough{},
PassThrough{},
OutElementOp{ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
@@ -6,7 +6,7 @@
 #include <vector>
 #include "ck/ck.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp"
+#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp)
target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations)
target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp"
using XDataType = ck::half_t;
using DxDataType = float;
using DyDataType = float;
using AccDataType = float;
using ScaleDataType = ck::half_t;
using DscaleDbiasDataType = float;
using MeanVarDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 4;
constexpr int NumBatchNormReduceDim = 3;
const double epsilon = std::numeric_limits<float>::epsilon();
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
std::array<ck::index_t, Rank> xyLengths{16, 8, 128, 256};
std::array<ck::index_t, Rank> xyStrides{8 * 128 * 256, 128 * 256, 256, 1};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarLengths{256};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarStrides{1};
std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
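// NHWC: reduce over N, H, W (dims 0, 1, 2); the remaining invariant dim is C = 256,
// so scale/bias/mean/variance are all length-256 per-channel vectors.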
ck::index_t numXYElement =
std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies<ck::index_t>());
ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
1,
std::multiplies<ck::index_t>());
SimpleDeviceMem x(sizeof(XDataType) * numXYElement);
SimpleDeviceMem dy(sizeof(DyDataType) * numXYElement);
SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem dx(sizeof(DxDataType) * numXYElement);
SimpleDeviceMem dscale(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem dbias(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement);
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
PassThrough,
Rank,
NumBatchNormReduceDim>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
dy.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
epsilon,
PassThrough{},
dx.GetDeviceBuffer(),
dscale.GetDeviceBuffer(),
dbias.GetDeviceBuffer());
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
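// Some instances reduce N*H*W in multiple blocks and need scratch memory for the
// partial results; query the size and attach a workspace before running.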
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_bytes =
numXYElement * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) +
numScaleBiasMeanVarElement *
(sizeof(ScaleDataType) + sizeof(DscaleDbiasDataType) * 2 +
sizeof(MeanVarDataType) * 2);
float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
if(found)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best instance
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
dy.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
epsilon,
PassThrough{},
dx.GetDeviceBuffer(),
dscale.GetDeviceBuffer(),
dbias.GetDeviceBuffer());
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ActivationOp = ck::tensor_operation::element_wise::Relu;
using CDEElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<ActivationOp>;
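// Per-layer (per-tensor) quantization epilogue: e = clamp(relu(acc + bias[n]) * requant_scale)
// saturated to int8, with a single float scale for the whole tensor ("Mul" rather than "Mul2").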
using ADataType = I8;
using BDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using BiasDataType = I32;
using DsDataType = ck::Tuple<BiasDataType>;
using EDataType = I8;
using ALayout = Row;
using BLayout = Col;
using BiasLayout = Row;
using DsLayout = ck::Tuple<BiasLayout>;
using ELayout = Row;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
PassThrough, // AElementwiseOperation,
PassThrough, // BElementwiseOperation,
CDEElementOp, // CDEElementwiseOperation,
GemmDefault, // GemmSpecialization GemmSpec,
1, // NumGemmKPrefetchStage,
256, // BlockSize,
256, // MPerBlock,
128, // NPerBlock,
64, // KPerBlock,
16, // AK1,
16, // BK1,
32, // MPerXDL,
32, // NPerXDL,
4, // MXdlPerWave,
2, // NXdlPerWave,
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1,
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
S<1, 0, 2>, // ABlockTransferSrcAccessOrder,
2, // index_t ABlockTransferSrcVectorDim,
16, // index_t ABlockTransferSrcScalarPerVector,
16, // index_t ABlockTransferDstScalarPerVector_AK1,
1, // bool ABlockLdsExtraM,
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder,
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder,
2, // index_t BBlockTransferSrcVectorDim,
8, // index_t BBlockTransferSrcScalarPerVector,
8, // index_t BBlockTransferDstScalarPerVector_BK1,
1, // bool BBlockLdsExtraN,
1, // index_t CShuffleMXdlPerWavePerShuffle,
1, // index_t CShuffleNXdlPerWavePerShuffle,
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
int main()
{
bool do_verification = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideBias = 0;
ck::index_t StrideE = 1024;
float requant_scale = 0.03f;
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1_uz}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1_uz, stride}));
}
};
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "bias_n: " << bias_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-128, 127});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-128, 127});
bias_n.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-128, 127});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data());
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{bias_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
{StrideBias},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
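// Apply the same quantization epilogue on the host: requantize the int32 reference
// accumulator with bias, ReLU, scale and clamp to produce the int8 reference output.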
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), bias_n(n));
}
}
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
@@ -9,7 +9,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/device_memory.hpp"
@@ -22,50 +22,59 @@
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
+using I8 = int8_t;
+using I32 = int32_t;
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ActivationOp = ck::tensor_operation::element_wise::Relu;
-using CElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
-using ADataType = int8_t;
-using BDataType = int8_t;
-using CDataType = int8_t;
-using AccDataType = int32_t;
-using CShuffleDataType = float;
-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
+using ActivationOp = PassThrough;
+using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
+using ADataType = I8;
+using BDataType = I8;
+using AccDataType = I32;
+using CShuffleDataType = I32;
+using DsDataType = ck::Tuple<>;
+using EDataType = I8;
+using ALayout = Row;
+using BLayout = Col;
+using DsLayout = ck::Tuple<>;
+using ELayout = Row;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 // clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
-ALayout, // typename ALayout,
-BLayout, // typename BLayout,
-CLayout, // typename CLayout,
-ADataType, // typename ADataType,
-BDataType, // typename BDataType,
-CDataType, // typename CDataType,
-AccDataType, // typename GemmAccDataType,
-CShuffleDataType, // typename CShuffleDataType,
-PassThrough, // typename AElementwiseOperation,
-PassThrough, // typename BElementwiseOperation,
-CElementOp, // typename CElementwiseOperation,
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<
+ALayout,
+BLayout,
+DsLayout,
+ELayout,
+ADataType,
+BDataType,
+AccDataType,
+CShuffleDataType,
+DsDataType,
+EDataType,
+PassThrough, // AElementwiseOperation,
+PassThrough, // BElementwiseOperation,
+CDEElementOp, // CDEElementwiseOperation,
 GemmDefault, // GemmSpecialization GemmSpec,
-1, // index_t NumGemmKPrefetchStage,
-256, // index_t BlockSize,
-256, // index_t MPerBlock,
-128, // index_t NPerBlock,
-64, // index_t KPerBlock,
-16, // index_t AK1,
-16, // index_t BK1,
-32, // index_t MPerXDL,
-32, // index_t NPerXDL,
-4, // index_t MXdlPerWave,
-2, // index_t NXdlPerWave,
-S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
-S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder,
-S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder,
+1, // NumGemmKPrefetchStage,
+256, // BlockSize,
+256, // MPerBlock,
+128, // NPerBlock,
+64, // KPerBlock,
+16, // AK1,
+16, // BK1,
+32, // MPerXDL,
+32, // NPerXDL,
+4, // MXdlPerWave,
+2, // NXdlPerWave,
+S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1,
+S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
+S<1, 0, 2>, // ABlockTransferSrcAccessOrder,
 2, // index_t ABlockTransferSrcVectorDim,
 16, // index_t ABlockTransferSrcScalarPerVector,
 16, // index_t ABlockTransferDstScalarPerVector_AK1,
@@ -84,53 +93,23 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
-ReferenceGemm<ADataType, BDataType, CDataType, float, PassThrough, PassThrough, CElementOp>;
+ReferenceGemm<ADataType, BDataType, EDataType, float, PassThrough, PassThrough, CDEElementOp>;
-int main(int argc, char* argv[])
+int main()
 {
 bool do_verification = true;
-int init_method = 1;
 bool time_kernel = false;
 // GEMM shape
-ck::index_t M = 3840;
-ck::index_t N = 4096;
-ck::index_t K = 4096;
-ck::index_t StrideA = 4096;
-ck::index_t StrideB = 4096;
-ck::index_t StrideC = 4096;
-float quant_multiplier = 0.03;
-if(argc == 4)
-{
-do_verification = std::stoi(argv[1]);
-init_method = std::stoi(argv[2]);
-time_kernel = std::stoi(argv[3]);
-}
-else if(argc == 10)
-{
-do_verification = std::stoi(argv[1]);
-init_method = std::stoi(argv[2]);
-time_kernel = std::stoi(argv[3]);
-M = std::stoi(argv[4]);
-N = std::stoi(argv[5]);
-K = std::stoi(argv[6]);
-StrideA = std::stoi(argv[7]);
-StrideB = std::stoi(argv[8]);
-StrideC = std::stoi(argv[9]);
-}
-else
-{
-printf("arg1: verification (0=no, 1=yes)\n");
-printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-printf("arg3: time kernel (0=no, 1=yes)\n");
-printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
-exit(0);
-}
+ck::index_t M = 1024;
+ck::index_t N = 1024;
+ck::index_t K = 1024;
+ck::index_t StrideA = 1024;
+ck::index_t StrideB = 1024;
+ck::index_t StrideE = 1024;
+float requant_scale = 0.03;
 auto f_host_tensor_descriptor =
 [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -138,61 +117,56 @@ int main(int argc, char* argv[])
 if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
 {
-return HostTensorDescriptor({row, col}, {stride, 1_uz});
+return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                            std::vector<std::size_t>({stride, 1_uz}));
 }
 else
 {
-return HostTensorDescriptor({row, col}, {1_uz, stride});
+return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                            std::vector<std::size_t>({1_uz, stride}));
 }
 };
 Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
 Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
 std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
 std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
-std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
-switch(init_method)
-{
-case 0: break;
-case 1:
-a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
-break;
-default:
-a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-}
+std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-128, 127});
+b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-128, 127});
-DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
-DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
-DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
+DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
+DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
+DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
-a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+a_device_buf.ToDevice(a_m_k.mData.data());
+b_device_buf.ToDevice(b_k_n.mData.data());
 auto a_element_op = PassThrough{};
 auto b_element_op = PassThrough{};
-auto c_element_op = CElementOp{quant_multiplier, ActivationOp{}};
+auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
 // do GEMM
 auto gemm = DeviceGemmInstance{};
 auto invoker = gemm.MakeInvoker();
-auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
-                                  static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
-                                  static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
+                                  b_device_buf.GetDeviceBuffer(),
+                                  {},
+                                  e_device_buf.GetDeviceBuffer(),
 M,
 N,
 K,
 StrideA,
 StrideB,
-StrideC,
+{},
+StrideE,
 a_element_op,
 b_element_op,
-c_element_op);
+cde_element_op);
 if(!gemm.IsSupportedArgument(argument))
 {
@@ -205,7 +179,7 @@ int main(int argc, char* argv[])
 std::size_t flop = std::size_t(2) * M * N * K;
 std::size_t num_btype =
-sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
 float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -214,7 +188,7 @@ int main(int argc, char* argv[])
 std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
           << gemm.GetTypeString() << std::endl;
-c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
+e_device_buf.FromDevice(e_m_n_device_result.mData.data());
 if(do_verification)
 {
@@ -222,11 +196,11 @@ int main(int argc, char* argv[])
 auto ref_invoker = ref_gemm.MakeInvoker();
 auto ref_argument = ref_gemm.MakeArgument(
-a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op);
 ref_invoker.Run(ref_argument);
-return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result) ? 0 : 1;
+return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
 }
 return 0;
add_example_executable(example_gemm_xdl_relu_quantization_int8 gemm_xdl_relu_quantization_int8.cpp)
@@ -11,7 +11,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
 static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
@@ -106,7 +106,7 @@ class BatchNormBwdArg
 using namespace ck;
-template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
+template <typename XDataType, typename AccDataType, bool UseMultiblockInK>
 bool bnorm_bwd_nhwc_test(bool do_verification,
                          int init_method,
                          bool time_kernel,
@@ -118,13 +118,15 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 constexpr index_t Rank = 4;
 constexpr index_t NumReduceDim = 3;
+using ScaleDataType = XDataType;
 const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
 // input data of the batchnorm backward algorithm
-Tensor<InOutDataType> x(inOutLengths);
-Tensor<InOutDataType> dy(inOutLengths);
-Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
+Tensor<XDataType> x(inOutLengths);
+Tensor<AccDataType> dy(inOutLengths);
+Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
 Tensor<AccDataType> savedMean(scaleBiasMeanVarLengths);
 Tensor<AccDataType> savedInvVar(scaleBiasMeanVarLengths);
@@ -132,8 +134,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 Tensor<AccDataType> savedVariance(scaleBiasMeanVarLengths);
 // output data of the batchnorm backward algorithm
-Tensor<InOutDataType> dx_ref(inOutLengths);
-Tensor<InOutDataType> dx(inOutLengths);
+Tensor<AccDataType> dx_ref(inOutLengths);
+Tensor<AccDataType> dx(inOutLengths);
 Tensor<AccDataType> dscale(scaleBiasMeanVarLengths);
 Tensor<AccDataType> dbias(scaleBiasMeanVarLengths);
@@ -153,7 +155,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 const float noise_stddev = 0.0001f;
 // input data in normal distribution
-x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
+x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
 // initialize the savedMean to be values with tiny variation to the mean of the x values
 savedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
@@ -182,7 +184,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 const float x_stddev = 1.0f;
 // input data in normal distribution
-x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
+x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
 };
 if(do_verification)
@@ -190,34 +192,34 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 switch(init_method)
 {
 case 0:
-dy.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
-bnScale.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
+dy.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
+bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
 break;
 case 1:
-dy.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
-bnScale.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
+dy.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
+bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
 break;
 case 2:
-dy.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
-bnScale.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
+dy.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-2, 2}, num_thread);
+bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
 break;
 default:
-dy.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.2f, 0.2f}, num_thread);
-bnScale.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.5f, 0.5f}, num_thread);
+dy.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-0.2f, 0.2f}, num_thread);
+bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
 }
 };
 // input data of the batchnorm backward algorithm
-DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
-DeviceMem dy_dev(sizeof(InOutDataType) * dy.mDesc.GetElementSpaceSize());
-DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
+DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize());
+DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
 DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize());
 DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize());
 // output data of the batchnorm backward algorithm
-DeviceMem dx_dev(sizeof(InOutDataType) * dx.mDesc.GetElementSpaceSize());
+DeviceMem dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize());
 DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize());
 DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize());
@@ -249,13 +251,13 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
 using DeviceBatchNormBwdInstance =
-ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType,
-                                                     InOutDataType,
-                                                     InOutDataType,
+ck::tensor_operation::device::DeviceBatchNormBwdImpl<XDataType,
+                                                     AccDataType,
+                                                     AccDataType,
                                                      AccDataType,
-                                                     AccDataType, // ScaleDataType
-                                                     AccDataType, // BiasDataType
+                                                     ScaleDataType, // ScaleDataType
+                                                     AccDataType, // DscaleDbiasDataType
                                                      AccDataType, // MeanVarDataType
                                                      PassThroughOp,
                                                      Rank,
                                                      NumReduceDim,
@@ -269,8 +271,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 1, // XSrcVectorSize
 1, // DySrcVectorSize
 1, // DxDstVectorSize
-1, // ScaleSrcDstVectorSize
-1, // BiasDstVectorSize
+1, // ScaleSrcVectorSize
+1, // DscaleDbiasDstVectorSize
 1>; // MeanVarSrcVectorSize
 auto batchnorm_bwd = DeviceBatchNormBwdInstance{};
@@ -324,7 +326,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 // inputing of x, dy, scale, outputing of dx, dscale, dbias
 num_bytes +=
-total_length * sizeof(InOutDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
+total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
 // outputing of mean, inv-variance
 num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0;
@@ -341,14 +343,16 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 if(do_verification)
 {
 using ReferenceBatchNormBwdInstance =
-ck::tensor_operation::host::ReferenceBatchNormBwd_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                         InOutDataType,
-                                                                         InOutDataType,
-                                                                         AccDataType,
-                                                                         AccDataType,
-                                                                         AccDataType,
-                                                                         AccDataType,
-                                                                         PassThroughOp>;
+ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
                                                  AccDataType,
                                                  AccDataType,
                                                  AccDataType,
+                                                  ScaleDataType, // ScaleDataType
                                                  AccDataType,
                                                  AccDataType,
+                                                  PassThroughOp,
+                                                  Rank,
+                                                  NumReduceDim>;
 auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
@@ -390,8 +394,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 dbias_dev.FromDevice(dbias.data());
 // clang-format off
-pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 1e-5, 1e-5);
-pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 1e-5, 2e-4);
+pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4);
+pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4);
 pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:");
 // clang-format on
 };
add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
using InDataType = int8_t;
using WeiDataType = int8_t;
using BiasDataType = int32_t;
using RequantScaleDataType = float;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using OutDataType = int8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using ActivationOp = ck::tensor_operation::element_wise::Relu;
using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<ActivationOp>;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
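// MNKPadding: the implicit-GEMM sizes derived from the conv problem are not guaranteed to be
// multiples of the tile sizes below, so M, N and K are all padded.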
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename BiasLayout,
typename RequantScaleLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout, RequantScaleLayout>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<BiasDataType, RequantScaleDataType>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 64, 1, 4>,
8>;
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& bias_g_k_desc,
const HostTensorDescriptor& requant_scale_g_k_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<BiasDataType> bias(bias_g_k_desc);
Tensor<RequantScaleDataType> requant_scale(requant_scale_g_k_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "bias: " << bias.mDesc << std::endl;
std::cout << "requant_scale: " << requant_scale.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-128, 127});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-128, 127});
bias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-128, 127});
requant_scale.GenerateTensorValue(GeneratorTensor_3<RequantScaleDataType>{0.0, 1.0});
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem requant_scale_device_buf(sizeof(RequantScaleDataType) *
requant_scale.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
requant_scale_device_buf.ToDevice(requant_scale.mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(bias_g_k_desc.GetLengths(), d0_g_n_k_wos_lengths);
copy(bias_g_k_desc.GetStrides(), d0_g_n_k_wos_strides);
copy(requant_scale_g_k_desc.GetLengths(), d1_g_n_k_wos_lengths);
copy(requant_scale_g_k_desc.GetStrides(), d1_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(
in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
{bias_device_buf.GetDeviceBuffer(), requant_scale_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
{d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths},
{d0_g_n_k_wos_strides, d1_g_n_k_wos_strides},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
CShuffleDataType,
InElementOp,
WeiElementOp,
PassThrough>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
c_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
// TODO: implement elementwise operation for host
out_host.ForEach([&](auto&, auto idx) {
out_element_op(out_host(idx), c_host(idx), bias(idx), requant_scale(idx));
});
out_device_buf.FromDevice(out_device.mData.data());
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
}
int main()
{
bool do_verification = true;
bool time_kernel = true;
const ck::index_t ndim_spatial = 2;
ck::utils::conv::ConvParam conv_param{
ndim_spatial, // n_dim
1, // group
4, // batch
64, // output channels
32, // input channels
{3, 3}, // weight HW
{71, 71}, // input HW
{2, 2}, // strides
{1, 1}, // dilations
{1, 1}, // left_pads
{1, 1} // right_pads
};
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{ActivationOp{}};
using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using BiasLayout = ck::tensor_layout::convolution::G_K;
using RequantScaleLayout = ck::tensor_layout::convolution::G_K;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
// TODO - make_bias_host_tensor_descriptor_g_n_k_wos_packed()
const auto bias_g_k_desc = HostTensorDescriptor({conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.output_spatial_lengths_[0],
conv_param.output_spatial_lengths_[1]},
{
conv_param.K_, // g
0, // n
1, // k
0, // ho
0 // wo
});
const auto requant_scale_g_k_desc = bias_g_k_desc;
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
std::cout << out_g_n_k_wos_desc << std::endl;
using deviceOp = DeviceGroupedConvNDFwdInstance<ndim_spatial,
InLayout,
WeiLayout,
BiasLayout,
RequantScaleLayout,
OutLayout>;
return run_grouped_conv_fwd<ndim_spatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
deviceOp>(do_verification,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
bias_g_k_desc,
requant_scale_g_k_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
}
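For readers tracing the verification loop above: the whole per-channel quantized epilogue reduces to one scalar formula per output element. The sketch below restates it in plain C++ with no CK dependencies; quantize_perchannel and its parameter names are illustrative only, and CK's type_convert may round differently from a plain cast.

#include <algorithm>
#include <cstdint>

// Minimal sketch of Add_Activation_Mul2_Clamp<Relu> semantics:
//   y = int8(clamp(requant_scale * relu(acc + bias), -128, 127))
inline std::int8_t quantize_perchannel(std::int32_t acc, std::int32_t bias, float requant_scale)
{
    float y = static_cast<float>(acc + bias);                         // widen the int32 accumulator
    y       = std::max(y, 0.0f);                                      // ReLU
    y       = std::min(std::max(requant_scale * y, -128.0f), 127.0f); // requantize and clamp
    return static_cast<std::int8_t>(y); // plain cast; the exact rounding mode is an assumption
}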
@@ -11,6 +11,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
@@ -163,26 +164,25 @@ bool run_grouped_conv_fwd(bool do_verification,
     // do Conv
     auto conv    = DeviceConvNDFwdInstance{};
     auto invoker = conv.MakeInvoker();
-    auto argument = conv.MakeArgument(
-        in_device_buf.GetDeviceBuffer(),
-        wei_device_buf.GetDeviceBuffer(),
-        std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
-        out_device_buf.GetDeviceBuffer(),
-        a_g_n_c_wis_lengths,
-        a_g_n_c_wis_strides,
-        b_g_k_c_xs_lengths,
-        b_g_k_c_xs_strides,
-        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_lengths}},
-        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_strides}},
-        e_g_n_k_wos_lengths,
-        e_g_n_k_wos_strides,
-        conv_filter_strides,
-        conv_filter_dilations,
-        input_left_pads,
-        input_right_pads,
-        in_element_op,
-        wei_element_op,
-        out_element_op);
+    auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
+                                      wei_device_buf.GetDeviceBuffer(),
+                                      {bias_device_buf.GetDeviceBuffer()},
+                                      out_device_buf.GetDeviceBuffer(),
+                                      a_g_n_c_wis_lengths,
+                                      a_g_n_c_wis_strides,
+                                      b_g_k_c_xs_lengths,
+                                      b_g_k_c_xs_strides,
+                                      {d0_g_n_k_wos_lengths},
+                                      {d0_g_n_k_wos_strides},
+                                      e_g_n_k_wos_lengths,
+                                      e_g_n_k_wos_strides,
+                                      conv_filter_strides,
+                                      conv_filter_dilations,
+                                      input_left_pads,
+                                      input_right_pads,
+                                      in_element_op,
+                                      wei_element_op,
+                                      out_element_op);
     if(!conv.IsSupportedArgument(argument))
     {
@@ -235,8 +235,8 @@ bool run_grouped_conv_fwd(bool do_verification,
         out_device_buf.FromDevice(out_device.mData.data());
-        pass &= ck::utils::check_err(
-            out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+        pass &=
+            ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
     }
     return (pass ? 0 : 1);
...
@@ -11,6 +11,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
@@ -150,14 +151,14 @@ bool run_grouped_conv_fwd(bool do_verification,
     auto invoker = conv.MakeInvoker();
     auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
                                       wei_device_buf.GetDeviceBuffer(),
-                                      std::array<const void*, 0>{},
+                                      {},
                                       out_device_buf.GetDeviceBuffer(),
                                       a_g_n_c_wis_lengths,
                                       a_g_n_c_wis_strides,
                                       b_g_k_c_xs_lengths,
                                       b_g_k_c_xs_strides,
-                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
-                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
+                                      {},
+                                      {},
                                       e_g_n_k_wos_lengths,
                                       e_g_n_k_wos_strides,
                                       conv_filter_strides,
@@ -213,8 +214,8 @@ bool run_grouped_conv_fwd(bool do_verification,
         out_device_buf.FromDevice(out_device.mData.data());
-        pass &= ck::utils::check_err(
-            out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+        pass &=
+            ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
     }
     return (pass ? 0 : 1);
...
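Aside on the MakeArgument change in both conv client hunks above: once the D-tensor pointer and length/stride parameters are plain std::array values, callers can hand over braced initializer lists and let each argument's array type come from the parameter, which is what lets the verbose std::array spellings disappear. A self-contained illustration of the same idiom (f and its parameter types are hypothetical stand-ins, not CK API):

#include <array>
#include <cstddef>

// A stand-in for a MakeArgument-style signature taking arrays by value.
void f(std::array<const void*, 1> ptrs, std::array<std::size_t, 2> lens)
{
    (void)ptrs;
    (void)lens;
}

int main()
{
    int x = 0;
    // Braces aggregate-initialize the parameter's array type directly...
    f({&x}, {3, 4});
    // ...equivalent to the explicit spelling the diff removes:
    f(std::array<const void*, 1>{&x}, std::array<std::size_t, 2>{3, 4});
    return 0;
}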
@@ -13,7 +13,16 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
+template <typename XDataType,
+          typename DxDataType,
+          typename DyDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename DscaleDbiasDataType,
+          typename MeanVarDataType,
+          typename DyElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
 struct DeviceBatchNormBwd : public BaseOperator
 {
     static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
@@ -26,7 +35,7 @@ struct DeviceBatchNormBwd : public BaseOperator
         const std::array<int, NumBatchNormReduceDim> reduceDims,
         const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
         const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-        const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+        const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
         const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
         const void* p_x,
         const void* p_dy,
@@ -42,9 +51,26 @@ struct DeviceBatchNormBwd : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
-template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
-using DeviceBatchNormBwdPtr =
-    std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
+template <typename XDataType,
+          typename DxDataType,
+          typename DyDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename DscaleDbiasDataType,
+          typename MeanVarDataType,
+          typename DyElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
+                                                                 DxDataType,
+                                                                 DyDataType,
+                                                                 AccDataType,
+                                                                 ScaleDataType,
+                                                                 DscaleDbiasDataType,
+                                                                 MeanVarDataType,
+                                                                 DyElementwiseOp,
+                                                                 Rank,
+                                                                 NumBatchNormReduceDim>>;
 } // namespace device
 } // namespace tensor_operation
...
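With the data types lifted into the interface, a type-erased batchnorm-backward pointer now spells out every type at the point of use. A hypothetical instantiation of the new alias (all-float types, rank-4 NHWC input reducing over N/H/W; these choices are illustrative, not taken from this commit):

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using BnBwdPtr =
    ck::tensor_operation::device::DeviceBatchNormBwdPtr<float,       // XDataType
                                                        float,       // DxDataType
                                                        float,       // DyDataType
                                                        float,       // AccDataType
                                                        float,       // ScaleDataType
                                                        float,       // DscaleDbiasDataType
                                                        float,       // MeanVarDataType
                                                        PassThrough, // DyElementwiseOp
                                                        4,           // Rank
                                                        3>;          // NumBatchNormReduceDim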
@@ -27,7 +27,7 @@ template <typename XDataType,
           typename DyDataType,
           typename AccDataType,
           typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
           typename MeanVarDataType,
           typename DyElementwiseOp,
           index_t Rank,
@@ -42,11 +42,19 @@ template <typename XDataType,
           index_t XSrcVectorSize,
           index_t DySrcVectorSize,
           index_t DxDstVectorSize,
-          index_t ScaleSrcDstVectorSize,
-          index_t BiasDstVectorSize,
+          index_t ScaleSrcVectorSize,
+          index_t DscaleDbiasDstVectorSize,
           index_t MeanVarSrcVectorSize>
-struct DeviceBatchNormBwdImpl
-    : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
+struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
+                                                          DxDataType,
+                                                          DyDataType,
+                                                          AccDataType,
+                                                          ScaleDataType,
+                                                          DscaleDbiasDataType,
+                                                          MeanVarDataType,
+                                                          DyElementwiseOp,
+                                                          Rank,
+                                                          NumBatchNormReduceDim>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl
             const std::array<int, NumBatchNormReduceDim> reduceDims,
             const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
             const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-            const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+            const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
             const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
             const XDataType* p_x,
             const DyDataType* p_dy,
@@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl
             const DyElementwiseOp dy_elementwise_op,
             double epsilon,
             DxDataType* p_dx,
-            ScaleDataType* p_dscale,
-            BiasDataType* p_dbias)
+            DscaleDbiasDataType* p_dscale,
+            DscaleDbiasDataType* p_dbias)
             : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
               bnScaleStrides_(bnScaleStrides),
-              bnBiasStrides_(bnBiasStrides),
+              bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
               bnMeanVarStrides_(bnMeanVarStrides),
               p_x_(p_x),
               p_dy_(p_dy),
@@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl
                 MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
             scale_grid_desc_m =
                 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
-            bias_grid_desc_m =
-                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
+            dscale_dbias_grid_desc_m =
+                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
             mean_var_grid_desc_m =
                 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
         }
@@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl
         std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
         std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
-        std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
+        std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
         std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
         const XDataType* p_x_;
@@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl
         const MeanVarDataType* p_savedInvVar_;
         const DyElementwiseOp dy_elementwise_op_;
         DxDataType* p_dx_;
-        ScaleDataType* p_dscale_;
-        BiasDataType* p_dbias_;
+        DscaleDbiasDataType* p_dscale_;
+        DscaleDbiasDataType* p_dbias_;
         long_index_t invariant_length;
         long_index_t reduce_length;
@@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl
         XYGridDesc_M_K dy_grid_desc_m_k;
         XYGridDesc_M_K dx_grid_desc_m_k;
         ScaleBiasGridDesc_M scale_grid_desc_m;
-        ScaleBiasGridDesc_M bias_grid_desc_m;
+        ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
         MeanVarGridDesc_M mean_var_grid_desc_m;
         void* workspace_mean;
@@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl
         {
             // workspace for the partial reduced result for dscale
             workspace_size +=
-                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
+                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
             // workspace for the partial reduced result for dbias
             workspace_size +=
-                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
+                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
             if(!pArg_->haveSavedMeanInvVar_)
             {
@@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl
             // setup buffer for the partial reduced result for dscale
             pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
-            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
+            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
             space_sz = math::integer_least_multiple(space_sz, 64);
             // setup buffer for the partial reduced result for dbias
@@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl
             if(UseMultiblockInK && pArg_->blkGroupSize > 1)
             {
-                space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
+                space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
                 space_sz = math::integer_least_multiple(space_sz, 64);
                 // setup buffer for welford intermediate mean
@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl
                     DyDataType,
                     AccDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl
                     DxDataType,
                     AccDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl
                     XSrcVectorSize,
                     DySrcVectorSize,
                     DxDstVectorSize,
-                    ScaleSrcDstVectorSize,
-                    BiasDstVectorSize,
+                    ScaleSrcVectorSize,
+                    DscaleDbiasDstVectorSize,
                     MeanVarSrcVectorSize>;
                 if(UseMultiblockInK && arg.blkGroupSize > 1)
@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl
                         DyDataType,
                         AccDataType,
                         ScaleDataType,
-                        BiasDataType,
+                        DscaleDbiasDataType,
                         MeanVarDataType,
                         DyElementwiseOp,
                         XYGridDesc_M_K,
@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl
                         DyDataType,
                         DxDataType,
                         ScaleDataType,
-                        BiasDataType,
+                        DscaleDbiasDataType,
                         MeanVarDataType,
                         DyElementwiseOp,
                         XYGridDesc_M_K,
@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl
                             : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
                         arg.p_x_,
                         arg.p_dy_,
-                        static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
-                        static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
+                        static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
+                        static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
                     avg_time += launch_and_time_kernel(
                         stream_config,
@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl
                         dscale_dbias_grid_desc_m_k,
                         arg.mean_var_grid_desc_m,
                         arg.scale_grid_desc_m,
-                        arg.bias_grid_desc_m,
+                        arg.dscale_dbias_grid_desc_m,
                         arg.blkGroupSize,
                         arg.reduce_length,
                         arg.numBlockTileIteration,
                         numDscaleDbiasBlockTileIteration,
-                        static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
-                        static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
+                        static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
+                        static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
                         arg.haveSavedMeanInvVar_
                             ? arg.p_savedMean_
                             : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl
                     DxDataType,
                     AccDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl
                     XSrcVectorSize,
                     DySrcVectorSize,
                     DxDstVectorSize,
-                    ScaleSrcDstVectorSize,
-                    BiasDstVectorSize,
+                    ScaleSrcVectorSize,
+                    DscaleDbiasDstVectorSize,
                     MeanVarSrcVectorSize>;
                 const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl
                     DxDataType,
                     AccDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl
                     arg.dy_grid_desc_m_k,
                     arg.dx_grid_desc_m_k,
                     arg.scale_grid_desc_m,
-                    arg.bias_grid_desc_m,
+                    arg.dscale_dbias_grid_desc_m,
                     arg.mean_var_grid_desc_m,
                     get_reduce_count_per_thread,
                     arg.reduce_length,
@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl
                 return false;
         };
-        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1)
+        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
             return false;
-        if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1)
+        if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
             return false;
-        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0)
+        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
            return false;
-        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0)
+        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
            return false;
        if(pArg_->haveSavedMeanInvVar_)
@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl
        const std::array<int, NumBatchNormReduceDim> reduceDims,
        const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
        const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-       const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+       const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
        const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
        const void* p_x,
        const void* p_dy,
@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl
            reduceDims,
            bnScaleBiasMeanVarLengths,
            bnScaleStrides,
-           bnBiasStrides,
+           bnDscaleDbiasStrides,
            bnMeanVarStrides,
            static_cast<const XDataType*>(p_x),
            static_cast<const DyDataType*>(p_dy),
@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl
            dy_elementwise_op,
            epsilon,
            static_cast<DxDataType*>(p_dx),
-           static_cast<ScaleDataType*>(p_dscale),
-           static_cast<BiasDataType*>(p_dbias));
+           static_cast<DscaleDbiasDataType*>(p_dscale),
+           static_cast<DscaleDbiasDataType*>(p_dbias));
    };
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
-       str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
+       str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
        // clang-format on
        return str.str();
...
@@ -10,8 +10,8 @@ namespace element_wise {
 template <typename Activation>
 struct Activation_Mul_Clamp
 {
-    Activation_Mul_Clamp(float multiplier, Activation activationOp)
-        : multiplier_(multiplier), activationOp_(activationOp)
+    Activation_Mul_Clamp(float requantScale, Activation activationOp)
+        : requantScale_(requantScale), activationOp_(activationOp)
     {
     }
@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp
     {
         float x_fp32 = ck::type_convert<float>(x);
         activationOp_(x_fp32, x_fp32);
-        float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
+        float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
         y = ck::type_convert<int8_t>(y_fp32);
     }
@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp
         // We might type_convert to int8 after lambda in someplace
         float x_fp32 = ck::type_convert<float>(x);
         activationOp_(x_fp32, x_fp32);
-        y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
+        y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
     }
-    float multiplier_;
+    float requantScale_;
     Activation activationOp_;
 };
+
+// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
+// relu, leaky relu ...etc
+template <typename Activation>
+struct Activation_Mul2_Clamp
+{
+    Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
+
+    __host__ __device__ constexpr void
+    operator()(int8_t& y, const int32_t& x, const float& requantScale) const
+    {
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int8_t>(y_fp32);
+    }
+
+    Activation activationOp_;
+};
@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp
 template <typename Activation>
 struct Add_Activation_Mul_Clamp
 {
-    Add_Activation_Mul_Clamp(float multiplier, Activation activationOp)
-        : multiplier_(multiplier), activationOp_(activationOp)
+    Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
+        : requantScale_(requantScale), activationOp_(activationOp)
     {
     }
     __host__ __device__ constexpr void
-    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
+    operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
     {
-        float y_fp32 = ck::type_convert<float>(x1 + x2);
+        float y_fp32 = ck::type_convert<float>(x + bias);
         activationOp_(y_fp32, y_fp32);
-        y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
         y = ck::type_convert<int8_t>(y_fp32);
     }
-    float multiplier_;
+    float requantScale_;
     Activation activationOp_;
 };
+
+// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
+// relu, leaky relu ...etc
+template <typename Activation>
+struct Add_Activation_Mul2_Clamp
+{
+    Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
+
+    __host__ __device__ constexpr void
+    operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
+    {
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int8_t>(y_fp32);
+    }
+
+    Activation activationOp_;
+};
@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp
 template <typename Activation>
 struct Add_Mul_Activation_Mul_Clamp
 {
-    Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp)
-        : multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp)
+    Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
+        : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
     {
     }
     __host__ __device__ constexpr void
-    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
+    operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
     {
-        float y_fp32 = ck::type_convert<float>(x1 + x2);
-        y_fp32       = multiplier1_ * y_fp32;
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        y_fp32       = requantScale1_ * y_fp32;
         activationOp_(y_fp32, y_fp32);
-        y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f);
+        y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
         y = ck::type_convert<int8_t>(y_fp32);
     }
-    float multiplier1_;
-    float multiplier2_;
+    float requantScale1_;
+    float requantScale2_;
     Activation activationOp_;
 };
...
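The naming scheme in this file: Mul variants bake a single per-layer requantization scale into the functor at construction time, while the new Mul2 variants take the scale as an extra runtime operand so it can be streamed per output channel. A minimal host-side sketch of the calling difference, assuming the definitions above (the operators are __host__ __device__, so host calls are valid; the numeric values are illustrative):

#include <cstdint>
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

void requant_examples()
{
    using Relu = ck::tensor_operation::element_wise::Relu;

    std::int8_t y0 = 0;
    std::int8_t y1 = 0;

    // Per-layer: one scale for the whole tensor, fixed at construction.
    ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu> perlayer{0.02f, Relu{}};
    perlayer(y0, /*x=*/1000, /*bias=*/24); // clamp(0.02f * relu(1024)) = 20.48f -> int8 20

    // Per-channel: the scale arrives with each element, e.g. from a G_K requant tensor.
    ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Relu> perchannel{Relu{}};
    perchannel(y1, /*x=*/1000, /*bias=*/24, /*requantScale=*/0.02f); // same math, scale per channel
}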
@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
          typename DyDataType,
          typename DxDataType,
          typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
    long_index_t reduce_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_dscale_dbias_k_block_tile_iteration,
-    const ScaleDataType* const __restrict__ p_reduce_dscale,
-    const BiasDataType* const __restrict__ p_reduce_dbias,
+    const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
+    const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
    const MeanVarDataType* const __restrict__ p_mean,
    const MeanVarDataType* const __restrict__ p_inv_var,
    const XDataType* const __restrict__ p_x,
@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
    const ScaleDataType* const __restrict__ p_scale,
    const DyElementwiseOp dy_elementwise_op,
    DxDataType* const __restrict__ p_dx,
-    ScaleDataType* const __restrict__ p_dscale,
-    BiasDataType* const __restrict__ p_dbias)
+    DscaleDbiasDataType* const __restrict__ p_dscale,
+    DscaleDbiasDataType* const __restrict__ p_dbias)
 {
     GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
                                                          dy_grid_desc_m_k,
@@ -76,7 +76,7 @@ template <typename XDataType,
          typename DxDataType,
          typename AccDataType,
          typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
@@ -92,8 +92,8 @@ template <typename XDataType,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t DxDstVectorSize,
-          index_t ScaleSrcDstVectorSize,
-          index_t BiasDstVectorSize,
+          index_t ScaleSrcVectorSize,
+          index_t DscaleDbiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
 struct GridwiseReduceSecondHalfBatchNormBackwardFinal
 {
@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
        const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
        const MeanVarGridDesc_M& mean_var_grid_desc_m,
        const ScaleBiasGridDesc_M& scale_grid_desc_m,
-        const ScaleBiasGridDesc_M& bias_grid_desc_m,
+        const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
        index_t blkgroup_size,
        long_index_t reduce_size,
        index_t num_xy_k_block_tile_iteration,
        index_t num_dscale_dbias_k_block_tile_iteration,
-        const ScaleDataType* const __restrict__ p_reduce_dscale,
-        const BiasDataType* const __restrict__ p_reduce_dbias,
+        const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
+        const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
        const MeanVarDataType* const __restrict__ p_mean,
        const MeanVarDataType* const __restrict__ p_inv_var,
        const XDataType* const __restrict__ p_x,
@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
        const ScaleDataType* const __restrict__ p_scale,
        const DyElementwiseOp dy_elementwise_op,
        DxDataType* const __restrict__ p_dx,
-        ScaleDataType* const __restrict__ p_dscale,
-        BiasDataType* const __restrict__ p_dbias)
+        DscaleDbiasDataType* const __restrict__ p_dscale,
+        DscaleDbiasDataType* const __restrict__ p_dbias)
    {
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -222,24 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
        // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
        // clang-format on
-        auto threadwise_dscale_load_m_k =
-            ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
-                                             AccDataType,
-                                             DscaleDbiasGridDesc_M_K,
-                                             decltype(thread_buffer_desc_m_1),
-                                             ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
-                                             1,
-                                             1,
-                                             1,
-                                             true>(
-                dscale_dbias_grid_desc_m_k,
-                make_multi_index(blkgroup_id * M_BlockTileSize +
-                                     thread_m_cluster_id * MThreadSliceSize,
-                                 thread_k_cluster_id * 1));
-
-        auto threadwise_dbias_load_m_k =
-            ThreadwiseTensorSliceTransfer_v2<BiasDataType,
+        auto threadwise_dscale_dbias_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
                                              AccDataType,
                                              DscaleDbiasGridDesc_M_K,
                                              decltype(thread_buffer_desc_m_1),
@@ -254,38 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * 1));
-        auto threadwise_dscale_store_m =
+        auto threadwise_dscale_dbias_store_m =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               ScaleDataType,
+                                               DscaleDbiasDataType,
                                               decltype(thread_buffer_desc_m),
                                               ScaleBiasGridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
-                                               ScaleSrcDstVectorSize,
+                                               DscaleDbiasDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
-                scale_grid_desc_m,
-                make_multi_index(blkgroup_id * M_BlockTileSize +
-                                 thread_m_cluster_id * MThreadSliceSize),
-                PassThroughOp{});
-
-        auto threadwise_dbias_store_m =
-            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               BiasDataType,
-                                               decltype(thread_buffer_desc_m),
-                                               ScaleBiasGridDesc_M,
-                                               PassThroughOp,
-                                               ThreadBufferLengths_M,
-                                               Sequence<0>,
-                                               0,
-                                               BiasDstVectorSize,
-                                               InMemoryDataOperationEnum::Set,
-                                               1,
-                                               true>(
-                bias_grid_desc_m,
+                dscale_dbias_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});
@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
            p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
        auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_dscale, scale_grid_desc_m.GetElementSpaceSize());
+            p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
        auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_dbias, bias_grid_desc_m.GetElementSpaceSize());
+            p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
        constexpr auto dscale_dbias_thread_copy_step_m_k =
            make_multi_index(0, KThreadClusterSize * 1);
@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
        for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
            ++reducedTiles)
        {
-            threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
-                                           reduce_dscale_global_buf,
-                                           thread_buffer_desc_m_1,
-                                           make_tuple(I0, I0),
-                                           reduce_dscale_thread_buf);
+            threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+                                                 reduce_dscale_global_buf,
+                                                 thread_buffer_desc_m_1,
+                                                 make_tuple(I0, I0),
+                                                 reduce_dscale_thread_buf);
-            threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
-                                          reduce_dbias_global_buf,
-                                          thread_buffer_desc_m_1,
-                                          make_tuple(I0, I0),
-                                          reduce_dbias_thread_buf);
+            threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+                                                 reduce_dbias_global_buf,
+                                                 thread_buffer_desc_m_1,
+                                                 make_tuple(I0, I0),
+                                                 reduce_dbias_thread_buf);
            ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
            ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
-            threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
-                                                          dscale_dbias_thread_copy_step_m_k);
-            threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
-                                                         dscale_dbias_thread_copy_step_m_k);
+            threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
+                                                                dscale_dbias_thread_copy_step_m_k);
        }
@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
            BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
        });
-        threadwise_dscale_store_m.Run(thread_buffer_desc_m,
-                                      make_tuple(I0),
-                                      dscale_thread_buf,
-                                      scale_grid_desc_m,
-                                      dscale_global_buf);
+        threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
+                                            make_tuple(I0),
+                                            dscale_thread_buf,
+                                            dscale_dbias_grid_desc_m,
+                                            dscale_global_buf);
-        threadwise_dbias_store_m.Run(thread_buffer_desc_m,
-                                     make_tuple(I0),
-                                     dbias_thread_buf,
-                                     bias_grid_desc_m,
-                                     dbias_global_buf);
+        threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
+                                            make_tuple(I0),
+                                            dbias_thread_buf,
+                                            dscale_dbias_grid_desc_m,
+                                            dbias_global_buf);
        // clang-format off
        // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
-                                               ScaleSrcDstVectorSize,
+                                               ScaleSrcVectorSize,
                                               1,
                                               true>(
                scale_grid_desc_m,
...
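For reference while reading the renames above: the two kernel steps compute the standard batchnorm backward reductions, exactly as the Step 1/Step 2 comments in the code state. With $N$ the reduce length, $\mu_c$ the saved mean, $\sigma_c^{-1}$ the saved inverse standard deviation, and $\gamma_c$ the scale at invariant index $c$:

$$\mathrm{dbias}_c = \sum_{i=1}^{N} dy_{i,c}, \qquad \mathrm{dscale}_c = \sum_{i=1}^{N} dy_{i,c}\,(x_{i,c}-\mu_c)\,\sigma_c^{-1}$$

$$dx_{i,c} = \frac{1}{N}\,\sigma_c^{-1}\,\gamma_c\,\bigl(N\,dy_{i,c} - \mathrm{dbias}_c - \mathrm{dscale}_c\,(x_{i,c}-\mu_c)\,\sigma_c^{-1}\bigr)$$

Both dscale and dbias are one value per invariant index, which is why this commit can fold ScaleDataType/BiasDataType into a single DscaleDbiasDataType and share one descriptor and one threadwise transfer for the pair.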