Commit ba816e69 authored by Chao Liu's avatar Chao Liu
Browse files

updating reference conv

parent d41a59a9
...@@ -91,30 +91,29 @@ template <ck::index_t NDimSpatial, ...@@ -91,30 +91,29 @@ template <ck::index_t NDimSpatial,
typename OutElementOp, typename OutElementOp,
typename DeviceConvNDFwdInstance, typename DeviceConvNDFwdInstance,
typename ReferenceConvNDFwdInstance> typename ReferenceConvNDFwdInstance>
int run_conv_fwd(const ck::tensor_operation::device::ConvParams& params, int run_conv_fwd_nhwc(const ck::tensor_operation::device::ConvParams& params,
bool do_verification, bool do_verification,
int init_method, int init_method,
bool time_kernel) bool time_kernel)
{ {
auto f_nchw_host_tensor_descriptor = auto f_nhwc_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) { [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n), std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
static_cast<std::size_t>(c)}; static_cast<std::size_t>(c)};
nhwc_lengths.insert( nhwc_lengths.insert(
nhwc_lengths.begin() + 1, spatial_lengths.begin(), spatial_lengths.end()); nhwc_lengths.begin() + 1, spatial_lengths.begin(), spatial_lengths.end());
return transpose_host_tensor_descriptor_given_new2old( return HostTensorDescriptor(nhwc_lengths);
HostTensorDescriptor(nhwc_lengths), std::vector<std::size_t>({0, 3, 1, 2}));
}; };
Tensor<InDataType> input( Tensor<InDataType> input(
f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_)); f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_));
Tensor<InDataType> weights( Tensor<WeiDataType> weights(
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_)); f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<InDataType> host_output( Tensor<OutDataType> host_output(
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
Tensor<InDataType> device_output( Tensor<OutDataType> device_output(
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
std::cout << "input: " << input.mDesc << std::endl; std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weights: " << weights.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl;
......
...@@ -25,7 +25,8 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -25,7 +25,8 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
OutDataType, // OutDataType, //
AccDataType, // AccDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation WeiElementOp, // Weights Elementwise Operation =
// ck::tensor_layout::convolution::NKHW,
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
ConvFwdDefault, // ConvForwardSpecialization ConvFwdDefault, // ConvForwardSpecialization
NumDimSpatial, // NumDimSpatial NumDimSpatial, // NumDimSpatial
...@@ -56,13 +57,16 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -56,13 +57,16 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
1>; // CThreadTransferDstScalarPerVector 1>; // CThreadTransferDstScalarPerVector
template <ck::index_t NumDimSpatial> template <ck::index_t NumDimSpatial>
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType, using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp>;
NumDimSpatial>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -97,44 +101,44 @@ int main(int argc, char* argv[]) ...@@ -97,44 +101,44 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_fwd<1, return run_conv_fwd_nhwc<1,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvNDFwdInstance<1>, DeviceConvNDFwdInstance<1>,
ReferenceConvNDFwdInstance<1>>( ReferenceConvNDFwdInstance<1>>(
params, do_verification, init_method, time_kernel); params, do_verification, init_method, time_kernel);
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_fwd<2, return run_conv_fwd_nhwc<2,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvNDFwdInstance<2>, DeviceConvNDFwdInstance<2>,
ReferenceConvNDFwdInstance<2>>( ReferenceConvNDFwdInstance<2>>(
params, do_verification, init_method, time_kernel); params, do_verification, init_method, time_kernel);
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_fwd<3, return run_conv_fwd_nhwc<3,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceConvNDFwdInstance<3>, DeviceConvNDFwdInstance<3>,
ReferenceConvNDFwdInstance<3>>( ReferenceConvNDFwdInstance<3>>(
params, do_verification, init_method, time_kernel); params, do_verification, init_method, time_kernel);
} }
......
...@@ -45,29 +45,6 @@ struct DeviceConvFwd : public BaseOperator ...@@ -45,29 +45,6 @@ struct DeviceConvFwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
#if 0
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvFwdPtr = std::unique_ptr<DeviceConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>>;
#endif
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -30,13 +30,16 @@ namespace host { ...@@ -30,13 +30,16 @@ namespace host {
// operation. // operation.
// @tparam NumDimSpatial Number of spatial dimensions. // @tparam NumDimSpatial Number of spatial dimensions.
// //
template <typename InDataType, template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false> typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator struct ReferenceConvFwd : public device::BaseOperator
{ {
...@@ -86,21 +89,80 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -86,21 +89,80 @@ struct ReferenceConvFwd : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor in_desc = arg.input_.mDesc;
HostTensorDescriptor wei_desc = arg.weight_.mDesc;
HostTensorDescriptor oout_desc = arg.output_.mDesc;
// input
if constexpr(is_same_v<InLayout,ck::tensor_layout::convolution::NWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
input_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<InLayout,ck::tensor_layout::convolution::NHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
input_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<InLayout,ck::tensor_layout::convolution::NDHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
input_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
weight_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(NumDimSpatial == 2 &&
WeiLayout == ck::tensor_layout::convolution::KYXC)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
// output
if constexpr(NumDimSpatial == 1 && OutLayout == ck::tensor_layout::convolution::NWK)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
output_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(NumDimSpatial == 2 &&
OutLayout == ck::tensor_layout::convolution::NHWK)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
output_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(NumDimSpatial == 3 &&
OutLayout == ck::tensor_layout::convolution::NDHWK)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
output_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
auto f_ncw = [&](auto n, auto k, auto wo) { auto f_ncw = [&](auto n, auto k, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
{ {
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) for(std::size_t x = 0; x < wei_desc.GetLengths()[2]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) + ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[2])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -122,9 +184,9 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -122,9 +184,9 @@ struct ReferenceConvFwd : public device::BaseOperator
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
arg.output_.mDesc.GetLengths()[0], out_desc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], out_desc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2])( out_desc.GetLengths()[2])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -134,26 +196,24 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -134,26 +196,24 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
{ {
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) for(std::size_t y = 0; y < wei_desc.GetLengths()[2]; ++y)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) + ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < wei_desc.GetLengths()[3]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) < in_desc.GetLengths()[2] &&
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[3])
arg.input_.mDesc.GetLengths()[3])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -175,10 +235,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -175,10 +235,10 @@ struct ReferenceConvFwd : public device::BaseOperator
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
arg.output_.mDesc.GetLengths()[0], out_desc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], out_desc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], out_desc.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3])( out_desc.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -188,21 +248,21 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -188,21 +248,21 @@ struct ReferenceConvFwd : public device::BaseOperator
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0; float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
{ {
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) for(std::size_t z = 0; z < wei_desc.GetLengths()[2]; ++z)
{ {
auto di = auto di =
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) + ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y) for(std::size_t y = 0; y < wei_desc.GetLengths()[3]; ++y)
{ {
auto hi = auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) + ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) - ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) for(std::size_t x = 0; x < wei_desc.GetLengths()[4]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * ck::type_convert<ck::long_index_t>(wo *
...@@ -212,13 +272,12 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -212,13 +272,12 @@ struct ReferenceConvFwd : public device::BaseOperator
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] && in_desc.GetLengths()[2] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] && in_desc.GetLengths()[3] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[4])
arg.input_.mDesc.GetLengths()[4])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
...@@ -243,11 +302,11 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -243,11 +302,11 @@ struct ReferenceConvFwd : public device::BaseOperator
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
arg.output_.mDesc.GetLengths()[0], out_desc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1], out_desc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2], out_desc.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3], out_desc.GetLengths()[3],
arg.output_.mDesc.GetLengths()[4])( out_desc.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -96,18 +96,18 @@ template <ck::index_t NumDimSpatial, ...@@ -96,18 +96,18 @@ template <ck::index_t NumDimSpatial,
typename OutLayout, typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvFwd< struct DeviceOperationInstanceFactory<
NumDimSpatial, ck::tensor_operation::device::DeviceConvFwd<NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout, OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>> ck::tensor_operation::element_wise::PassThrough>>
{ {
using DeviceOp = DeviceConvFwd<NumDimSpatial, using DeviceOp = DeviceConvFwd<NumDimSpatial,
InLayout, InLayout,
......
...@@ -73,7 +73,7 @@ auto construct_f_unpack_args(F, T args) ...@@ -73,7 +73,7 @@ auto construct_f_unpack_args(F, T args)
struct HostTensorDescriptor struct HostTensorDescriptor
{ {
HostTensorDescriptor() = delete; HostTensorDescriptor() = default;
template <typename X> template <typename X>
HostTensorDescriptor(const std::vector<X>& lens); HostTensorDescriptor(const std::vector<X>& lens);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/convolution_forward.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
// @brief Profile device convolution-forward instances and (optionally) verify
//        them against a host reference, printing per-instance performance and
//        the best result found.
//
// NOTE(review): the body below appears to be copied verbatim from the GEMM
// profiler. It references identifiers that are never declared here or in the
// template parameter list (M, N, K, StrideA/StrideB/StrideC, ADataType,
// BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout), queries
// DeviceGemm instances rather than conv-forward instances, and never uses
// `params` or any of the conv template parameters. As written it cannot
// compile; presumably it is a work-in-progress placeholder to be rewritten in
// terms of NumDimSpatial / In-Wei-OutLayout / In-Wei-OutDataType and `params`
// — TODO confirm and finish the conversion.
//
// @tparam NumDimSpatial  number of spatial dimensions (currently unused)
// @tparam InLayout/WeiLayout/OutLayout    tensor layouts (currently unused)
// @tparam InDataType/WeiDataType/OutDataType  element types (currently unused)
// @param do_verification  nonzero -> run host reference and compare outputs
// @param init_method      0: leave buffers as-is, 1: small random integers,
//                         otherwise: random floats
// @param do_log           when verifying, dump tensor contents to stdout
// @param time_kernel      forwarded in StreamConfig to time kernel execution
// @param params           convolution problem description (currently unused)
// @return 0 when all supported instances pass verification, 1 otherwise
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
int profile_conv_fwd_impl(int do_verification,
                          int init_method,
                          bool do_log,
                          bool time_kernel,
                          const ck::utils::conv::ConvParams& params)
{
    bool pass = true;

    // Build a 2-D host descriptor in row- or column-major order (GEMM-style;
    // NOTE(review): a conv profiler would be expected to build N[D]HWC
    // descriptors instead — leftover from the GEMM profiler).
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

    // NOTE(review): M, N, K, Stride*, *DataType and *Layout below are not
    // declared anywhere in this function or its template parameters.
    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;

    // Input initialization: integer values keep the verification comparison
    // exact; float values exercise rounding behaviour.
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

    // NOTE(review): this queries GEMM instances; a conv-forward profiler
    // should query DeviceConvFwd instances — TODO confirm intended op type.
    using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
                                                              BLayout,
                                                              CLayout,
                                                              ADataType,
                                                              BDataType,
                                                              CDataType,
                                                              AElementOp,
                                                              BElementOp,
                                                              CElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    // Run reference GEMM
    // (host result computed once; every device instance is compared to it)
    if(do_verification)
    {
        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                CDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                BElementOp,
                                                                                CElementOp>;

        auto ref_op       = ReferenceGemmInstance{};
        auto ref_invoker  = ref_op.MakeInvoker();
        auto ref_argument = ref_op.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);
    }

    std::string best_op_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device GEMM instances
    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr =
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
                                        a_element_op,
                                        b_element_op,
                                        c_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        // Skip instances whose tuning parameters cannot handle this problem.
        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            // re-init C to zero before profiling next kernel
            c_device_buf.SetZero();

            std::string op_name = op_ptr->GetTypeString();

            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

            // 2*M*N*K FLOPs for a GEMM; bytes = one read of A and B plus one
            // write of C (no C read, since beta is implicitly zero here).
            std::size_t flop = std::size_t(2) * M * N * K;

            std::size_t num_btype =
                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            // Track the fastest instance (by TFlops) for the summary line.
            if(tflops > best_tflops)
            {
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_m_n_device_result.mData.data());

                // Accumulate pass/fail across all instances; one failure
                // makes the whole profile run return nonzero.
                pass =
                    pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
        }
    }

    // Summary: report data type and layouts of the best-performing instance.
    if constexpr(is_same<CDataType, float>::value)
    {
        std::cout << "Best Perf for datatype = f32";
    }
    else if constexpr(is_same<CDataType, half_t>::value)
    {
        std::cout << "Best Perf for datatype = f16";
    }
    else if constexpr(is_same<CDataType, bhalf_t>::value)
    {
        std::cout << "Best Perf for datatype = bf16";
    }
    else if constexpr(is_same<CDataType, int8_t>::value)
    {
        std::cout << "Best Perf for datatype = int8";
    }

    if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " ALayout = RowMajor";
    }
    else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " ALayout = ColumnMajor";
    }

    if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " BLayout = RowMajor";
    }
    else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " BLayout = ColumnMajor";
    }

    std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
              << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
              << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
              << best_op_name << std::endl;

    return pass ? 0 : 1;
}
} // namespace profiler
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment