Commit 11edd0f0 authored by Chao Liu

update conv fwd profiler

parent 0cb8ba92
@@ -12,63 +12,73 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
ck::tensor_operation::device::ConvParams ck::tensor_operation::device::ConvParams
parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]) parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
{ {
ck::tensor_operation::device::ConvParams params; const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
params.num_dim_spatial_ = num_dim_spatial; std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
params.N_ = std::stoi(argv[arg_idx++]); std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
params.K_ = std::stoi(argv[arg_idx++]); std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
params.C_ = std::stoi(argv[arg_idx++]); std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
params.filter_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
} }
params.input_spatial_lengths_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_strides_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
} }
params.conv_filter_dilations_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
} }
params.input_left_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); input_left_pads[i] = std::stoi(argv[arg_idx++]);
} }
params.input_right_pads_.resize(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i) for(int i = 0; i < num_dim_spatial; ++i)
{ {
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); input_right_pads[i] = std::stoi(argv[arg_idx++]);
} }
return params; return ck::tensor_operation::device::ConvParams{num_dim_spatial,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
} }
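// Usage sketch (illustrative only; the executable name and argument values are hypothetical,
// but the order follows print_helper_msg() below): for a 2D NHWC problem with
// N=128, K=256, C=192, 3x3 filter, 71x71 input, stride 2, dilation 1, pad 1:
//   ./example_conv_fwd 1 1 1 2  128 256 192  3 3  71 71  2 2  1 1  1 1  1 1
// main() below then calls: params = parse_conv_params(num_dim_spatial, /*arg_idx=*/5, argv);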
void print_helper_msg() void print_helper_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=n0, 1=yes)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
@@ -108,7 +118,7 @@ int run_conv_fwd_nhwc(const ck::tensor_operation::device::ConvParams& params,
Tensor<InDataType> input( Tensor<InDataType> input(
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_)); f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_));
Tensor<WeiDataType> weights( Tensor<WeiDataType> weight(
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_)); f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<OutDataType> host_output( Tensor<OutDataType> host_output(
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
@@ -116,7 +126,7 @@ int run_conv_fwd_nhwc(const ck::tensor_operation::device::ConvParams& params,
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
std::cout << "input: " << input.mDesc << std::endl; std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weights: " << weights.mDesc << std::endl; std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << host_output.mDesc << std::endl; std::cout << "output: " << host_output.mDesc << std::endl;
switch(init_method) switch(init_method)
@@ -124,19 +134,19 @@ int run_conv_fwd_nhwc(const ck::tensor_operation::device::ConvParams& params,
case 0: break; case 0: break;
case 1: case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break; break;
default: default:
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
weights.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace());
in_device_buf.ToDevice(input.mData.data()); in_device_buf.ToDevice(input.mData.data());
wei_device_buf.ToDevice(weights.mData.data()); wei_device_buf.ToDevice(weight.mData.data());
// do GEMM // do GEMM
auto conv = DeviceConvNDFwdInstance{}; auto conv = DeviceConvNDFwdInstance{};
@@ -181,7 +191,7 @@ int run_conv_fwd_nhwc(const ck::tensor_operation::device::ConvParams& params,
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input, auto ref_argument = ref_conv.MakeArgument(input,
weights, weight,
host_output, host_output,
params.conv_filter_strides_, params.conv_filter_strides_,
params.conv_filter_dilations_, params.conv_filter_dilations_,
......
@@ -20,6 +20,7 @@ static constexpr auto ConvFwdDefault =
template <ck::index_t NumDimSpatial> template <ck::index_t NumDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl< using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl<
NumDimSpatial, // NumDimSpatial
InDataType, // InDataType, //
WeiDataType, // WeiDataType, //
OutDataType, // OutDataType, //
@@ -28,7 +29,6 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
WeiElementOp, // Weights Elementwise Operation = WeiElementOp, // Weights Elementwise Operation =
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
ConvFwdDefault, // ConvForwardSpecialization ConvFwdDefault, // ConvForwardSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize 256, // BlockSize
128, // MPerBlock 128, // MPerBlock
256, // NPerBlock 256, // NPerBlock
@@ -77,7 +77,8 @@ int main(int argc, char* argv[])
bool time_kernel = false; bool time_kernel = false;
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::tensor_operation::device::ConvParams params; ck::tensor_operation::device::ConvParams params{
2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1) if(argc == 1)
{ {
@@ -96,7 +97,7 @@ int main(int argc, char* argv[])
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); num_dim_spatial = std::stoi(argv[4]);
params = parse_conv_params(num_dim_spatial, argc, argv); params = parse_conv_params(num_dim_spatial, 5, argv);
} }
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
......
@@ -39,7 +39,8 @@ namespace device {
// 3D: // 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
// //
template <typename InDataType, template <ck::index_t NumDimSpatial,
typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename AccDataType,
@@ -47,7 +48,6 @@ template <typename InDataType,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::index_t NumDimSpatial,
ck::index_t BlockSize, ck::index_t BlockSize,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t NPerBlock, ck::index_t NPerBlock,
@@ -83,17 +83,38 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
ck::Tuple<ck::tensor_layout::convolution::KXC, ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>, ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using Base =
DeviceConvFwd<NumDimSpatial,
ck::tuple_element_t<NumDimSpatial - 1, ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC, ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>, ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NumDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>;
{
using DeviceOp = DeviceConvNdFwdNwcKxcNwk_Xdl; using DeviceOp = DeviceConvNdFwdNwcKxcNwk_Xdl;
using ADataType = InDataType; using ADataType = InDataType;
......
@@ -21,6 +21,8 @@ struct TupleElementKey
template <typename Key, typename Data> template <typename Key, typename Data>
struct TupleElementKeyData struct TupleElementKeyData
{ {
using DataType = Data;
#if 0 // workaround compiler complaint about implicitly-deleted default constructor #if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default; __host__ __device__ constexpr TupleElementKeyData() = default;
#else #else
@@ -34,29 +36,40 @@ struct TupleElementKeyData
{ {
} }
Data mData; DataType mData;
}; };
// for read access of tuple element
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr const Data& __host__ __device__ constexpr const Data&
get_tuple_element_data(const TupleElementKeyData<Key, Data>& x) get_tuple_element_data_reference(const TupleElementKeyData<Key, Data>& x)
{ {
return static_cast<const Data&>(x.mData); return static_cast<const Data&>(x.mData);
} }
// for write access of tuple element
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x) __host__ __device__ constexpr Data&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>& x)
{ {
return x.mData; return x.mData;
} }
// TODO: not sure the use of reference is correct // TODO: not sure the use of reference is correct
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x) __host__ __device__ constexpr Data&&
get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
{ {
return static_cast<Data&&>(x.mData); return static_cast<Data&&>(x.mData);
} }
// for inferring type of tuple element
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
return x.mData;
}
template <typename Indices, typename... Xs> template <typename Indices, typename... Xs>
struct TupleImpl; struct TupleImpl;
@@ -87,13 +100,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
template <index_t I> template <index_t I>
__host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
{ {
return get_tuple_element_data<TupleElementKey<I>>(*this); return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
} }
template <index_t I> template <index_t I>
__host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>) __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
{ {
return get_tuple_element_data<TupleElementKey<I>>(*this); return get_tuple_element_data_reference<TupleElementKey<I>>(*this);
} }
}; };
@@ -185,7 +198,8 @@ struct Tuple<>
template <index_t I, typename TTuple> template <index_t I, typename TTuple>
struct tuple_element struct tuple_element
{ {
using type = decltype(TTuple{}.At(Number<I>{})); // type should keep the cv/ref qualifier of original tuple element
using type = decltype(detail::get_tuple_element_data<detail::TupleElementKey<I>>(TTuple{}));
}; };
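// Illustration of the intended effect (a sketch, assuming the tuple_element_t alias that
// follows wraps this trait): the nested type is now the stored element type itself rather
// than a reference added by the accessor, e.g.
//   using T0 = ck::tuple_element_t<0, ck::Tuple<int, float>>; // T0 == int (the old definition yielded a reference type)
//   using T1 = ck::tuple_element_t<1, ck::Tuple<int, float>>; // T1 == float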
template <index_t I, typename TTuple> template <index_t I, typename TTuple>
......
@@ -30,6 +30,8 @@ namespace host {
// operation. // operation.
// @tparam NumDimSpatial Number of spatial dimensions. // @tparam NumDimSpatial Number of spatial dimensions.
// //
// FIXME: only supports NumDimSpatial = 1 to 3 and NCHW/NHWC-style layouts;
// needs to be made more general
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
......
@@ -10,7 +10,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// aliasing, for commonly used type // aliasing, for commonly used data type
using F64 = double; using F64 = double;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
@@ -23,9 +23,24 @@ using F16_F16_TUPLE = ck::Tuple<F16, F16>;
using F32_TUPLE = ck::Tuple<F32>; using F32_TUPLE = ck::Tuple<F32>;
// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// Conv layout
using NWC = ck::tensor_layout::convolution::NWC;
using NHWC = ck::tensor_layout::convolution::NHWC;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
using KXC = ck::tensor_layout::convolution::KXC;
using KYXC = ck::tensor_layout::convolution::KYXC;
using KZYXC = ck::tensor_layout::convolution::KZYXC;
using NWK = ck::tensor_layout::convolution::NWK;
using NHWK = ck::tensor_layout::convolution::NHWK;
using NDHWK = ck::tensor_layout::convolution::NDHWK;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
......
@@ -75,7 +75,7 @@ void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
......
@@ -256,6 +256,14 @@ struct Tensor
return *this; return *this;
} }
void SetZero()
{
for(auto& v : mData)
{
v = T{0};
}
}
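// Usage sketch: the conv profiler elsewhere in this commit clears the host reference
// buffer with this before the reference convolution accumulates into it,
// e.g. host_output.SetZero();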
template <typename F> template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank) void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
{ {
......
@@ -5,8 +5,6 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
)
set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
@@ -14,10 +12,5 @@ set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE
) )
add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
add_library(device_convnd_2d_fwd_instance OBJECT ${DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE})
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_convnd_2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_conv2d_fwd_instance) clang_tidy_check(device_conv2d_fwd_instance)
clang_tidy_check(device_convnd_2d_fwd_instance)
@@ -45,17 +45,17 @@ ConvParams::ConvParams(ck::index_t n_dim,
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t idx_eff = const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] = output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - idx_eff) / (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] + conv_filter_strides_[i] +
1; 1;
} }
} }
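// Worked example with the default 2D parameters below (Hi = Wi = 71, Y = X = 3,
// stride 2, dilation 1, left/right pad 1):
//   XEff = (3 - 1) * 1 + 1 = 3
//   Wo   = (71 + 1 + 1 - 3) / 2 + 1 = 36
// i.e. the default ConvParams describe a [128, 256, 36, 36] (N, K, Ho, Wo) output.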
ConvParams::ConvParams() ConvParams::ConvParams()
: ConvParams::ConvParams(2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {2, 2}, {2, 2}) : ConvParams::ConvParams(2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{ {
} }
......
@@ -5,45 +5,45 @@ include_directories(BEFORE
# ck_profiler # ck_profiler
set(PROFILER_SOURCE set(PROFILER_SOURCE
src/profiler.cpp src/profiler.cpp
src/profile_gemm.cpp # src/profile_gemm.cpp
src/profile_gemm_splitk.cpp # src/profile_gemm_splitk.cpp
src/profile_gemm_bilinear.cpp # src/profile_gemm_bilinear.cpp
src/profile_gemm_bias_add_reduce.cpp # src/profile_gemm_bias_add_reduce.cpp
src/profile_gemm_add_add_fastgelu.cpp # src/profile_gemm_add_add_fastgelu.cpp
src/profile_gemm_reduce.cpp # src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp # src/profile_batched_gemm.cpp
src/profile_batched_gemm_reduce.cpp # src/profile_batched_gemm_reduce.cpp
src/profile_grouped_gemm.cpp # src/profile_grouped_gemm.cpp
src/profile_conv_fwd_bias_relu.cpp src/profile_conv_fwd.cpp
src/profile_conv_fwd_bias_relu_add.cpp # src/profile_conv_fwd_bias_relu.cpp
src/profile_convnd_fwd.cpp # src/profile_conv_fwd_bias_relu_add.cpp
src/profile_convnd_bwd_data.cpp # src/profile_convnd_fwd.cpp
src/profile_conv_bwd_weight.cpp # src/profile_convnd_bwd_data.cpp
src/profile_convnd_bwd_weight.cpp # src/profile_conv_bwd_weight.cpp
src/profile_reduce.cpp # src/profile_convnd_bwd_weight.cpp
src/profile_normalization.cpp # src/profile_reduce.cpp
# src/profile_normalization.cpp
) )
add_executable(ckProfiler ${PROFILER_SOURCE}) add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE utility)
target_link_libraries(ckProfiler PRIVATE conv_util) #target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_bilinear_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bilinear_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) #target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance) #target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) #target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) #target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) #target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance) #target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) #target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) #target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) #target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) #target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_weight_instance) #target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_normalization_instance) #target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance) #target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
@@ -15,14 +15,16 @@
#include "ck/library/tensor_operation_instance/gpu/convolution_forward.hpp" #include "ck/library/tensor_operation_instance/gpu/convolution_forward.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
namespace ck { namespace ck {
namespace profiler { namespace profiler {
// FIXME: only supports NCHW and NHWC layouts; needs to be more general
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
@@ -34,70 +36,126 @@ int profile_conv_fwd_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParams& params) const ck::tensor_operation::device::ConvParams& params)
{ {
bool pass = true; bool pass = true;
auto f_host_tensor_descriptor = // make host tensor descriptor
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { auto f_nhwc_host_tensor_descriptor =
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
static_cast<std::size_t>(c)};
nhwc_lengths.insert(
nhwc_lengths.begin() + 1, spatial_lengths.begin(), spatial_lengths.end());
return HostTensorDescriptor(nhwc_lengths);
};
auto f_nchw_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
static_cast<std::size_t>(c)};
nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());
return HostTensorDescriptor(nchw_lengths);
};
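// For illustration, with the default 2D problem (N = 128, C = 192, spatial lengths {71, 71}):
//   f_nhwc_host_tensor_descriptor(128, 192, {71, 71}) -> lengths {128, 71, 71, 192}
//   f_nchw_host_tensor_descriptor(128, 192, {71, 71}) -> lengths {128, 192, 71, 71}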
HostTensorDescriptor in_desc, wei_desc, out_desc;
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), in_desc =
std::vector<std::size_t>({stride, 1})); f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
} }
else else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), in_desc =
std::vector<std::size_t>({1, stride})); f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
} }
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); // FIXME: properly implement "make host descriptor" for different layout
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc =
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
{
wei_desc =
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
{
out_desc =
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
Tensor<InDataType> input(in_desc);
Tensor<WeiDataType> weight(wei_desc);
Tensor<OutDataType> host_output(out_desc);
Tensor<OutDataType> device_output(out_desc);
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "input: " << input.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; std::cout << "output: " << host_output.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{}; const auto in_element_op = InElementOp{};
const auto b_element_op = BElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto c_element_op = CElementOp{}; const auto out_element_op = OutElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data()); in_device_buf.ToDevice(input.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); wei_device_buf.ToDevice(weight.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout, using DeviceOp = ck::tensor_operation::device::DeviceConvFwd<NumDimSpatial,
BLayout, InLayout,
CLayout, WeiLayout,
ADataType, OutLayout,
BDataType, InDataType,
CDataType, WeiDataType,
AElementOp, OutDataType,
BElementOp, InElementOp,
CElementOp>; WeiElementOp,
OutElementOp>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -105,94 +163,107 @@ int profile_conv_fwd_impl(int do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM // run reference op
if(do_verification) if(do_verification)
{ {
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NumDimSpatial,
BDataType, InLayout,
CDataType, WeiLayout,
AccDataType, OutLayout,
AElementOp, InDataType,
BElementOp, WeiDataType,
CElementOp>; OutDataType,
InElementOp,
auto ref_op = ReferenceGemmInstance{}; WeiElementOp,
auto ref_invoker = ref_op.MakeInvoker(); OutElementOp>{};
auto ref_argument = ref_op.MakeArgument( auto ref_invoker = ref_conv.MakeInvoker();
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); auto ref_argument = ref_conv.MakeArgument(input,
weight,
host_output,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op);
// init host output to zero
host_output.SetZero();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
std::string best_op_name; std::string best_op_name;
float best_ave_time = 0; float best_avg_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
// profile device GEMM instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
M, params.N_,
N, params.K_,
K, params.C_,
StrideA, params.input_spatial_lengths_,
StrideB, params.filter_spatial_lengths_,
StrideC, params.GetOutputSpatialLengths(),
a_element_op, params.conv_filter_strides_,
b_element_op, params.conv_filter_dilations_,
c_element_op); params.input_left_pads_,
params.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init C to zero before profiling next kernel // re-init output to zero before profiling next kernel
c_device_buf.SetZero(); out_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = params.GetFlops();
std::size_t num_btype = params.GetByte<InDataType, WeiDataType, OutDataType>();
std::size_t num_btype = float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / avg_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl; << gb_per_sec << " GB/s, " << op_name << std::endl;
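// The counts come from ConvParams::GetFlops()/GetByte(), which are not shown in this diff;
// the conventional direct-convolution accounting they are expected to follow is
//   flop      ~= 2 * N * K * C * prod(output spatial) * prod(filter spatial)
//   num_btype ~= sizeof(In)  * N * C * prod(input spatial)
//              + sizeof(Wei) * K * C * prod(filter spatial)
//              + sizeof(Out) * N * K * prod(output spatial)
// so, with avg_time in ms, tflops is Tflop/s and gb_per_sec is GB/s.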
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
best_op_name = op_name; best_op_name = op_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
} }
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
pass = pass = pass & ck::utils::check_err(device_output.mData, host_output.mData);
pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",") LogRangeAsType<float>(std::cout << "host_output : ", host_output.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
<< std::endl; << std::endl;
} }
} }
@@ -203,47 +274,11 @@ int profile_conv_fwd_impl(int do_verification,
} }
} }
if constexpr(is_same<CDataType, float>::value) std::cout << "Best configuration parameters:"
{ << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
std::cout << "Best Perf for datatype = f32"; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<CDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
return pass ? 0 : 1; return 0;
} }
} // namespace profiler } // namespace profiler
......
@@ -23,13 +23,13 @@
namespace ck { namespace ck {
namespace profiler { namespace profiler {
template <typename ADataType, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType, typename BDataType,
typename AccDataType, typename AccDataType,
typename CDataType, typename CDataType>
typename ALayout,
typename BLayout,
typename CLayout>
int profile_gemm_impl(int do_verification, int profile_gemm_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
@@ -92,7 +92,6 @@ int profile_gemm_impl(int do_verification,
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout, using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
BLayout, BLayout,
@@ -110,7 +109,7 @@ int profile_gemm_impl(int do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM // Run reference op
if(do_verification) if(do_verification)
{ {
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -131,11 +130,11 @@ int profile_gemm_impl(int do_verification,
} }
std::string best_op_name; std::string best_op_name;
float best_ave_time = 0; float best_avg_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
// profile device GEMM instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = auto argument_ptr =
@@ -161,7 +160,7 @@ int profile_gemm_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
@@ -169,18 +168,18 @@ int profile_gemm_impl(int do_verification,
std::size_t num_btype = std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl; << gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
best_op_name = op_name; best_op_name = op_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
} }
@@ -244,7 +243,7 @@ int profile_gemm_impl(int do_verification,
} }
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_avg_time
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl; << best_op_name << std::endl;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_conv_fwd_impl.hpp"
enum struct ConvLayout
{
NCHW_KYXC_NKHW, // 0
NHWC_KYXC_NHWK, // 1
};
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
static void print_helper_msg()
{
// clang-format off
std::cout << "arg1: tensor operation (conv_fwd: ForwardConvolution)\n"
<< "arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n"
<< "arg3: tensor layout (0: Input[N, C, Hi, Wi] * Weight[K, C, Y, X] = Output[N, K, "
"Ho, Wo]\n"
<< " 1: Input[N, Hi, Wi, C] * Weight[K, Y, X, C] = Output[N, Ho, "
"Wo, K])\n"
<< "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8: N spatial dimensions\n"
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
// clang-format on
}
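// Example invocation matching the argument list above (values are illustrative only and
// mirror the default 2D problem used elsewhere in this commit):
//   ckProfiler conv_fwd 1 1 1 1 0 1 2  128 256 192  3 3  71 71  2 2  1 1  1 1  1 1
//   (fp16 data, NHWC/KYXC/NHWK layout, verify, integer init, no tensor print, time kernel,
//    2 spatial dims, then N K C, Y X, Hi Wi, strides, dilations, left pads, right pads)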
ck::tensor_operation::device::ConvParams
parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stoi(argv[arg_idx++]);
}
return ck::tensor_operation::device::ConvParams{num_dim_spatial,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
int profile_conv_fwd(int argc, char* argv[])
{
// 1 for the executable, 7 for control args (op, data type, layout, verify, init, log, time), 1 for num_dim_spatial
if(argc < 9)
{
print_helper_msg();
exit(1);
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int num_dim_spatial = std::stoi(argv[8]);
// 1 for the executable, 8 for control args (including num_dim_spatial), 3 for N/K/C, and 6 * num_dim_spatial for the conv params
if(argc != 8 + 4 + 6 * num_dim_spatial)
{
print_helper_msg();
exit(1);
}
const auto params = parse_conv_params(num_dim_spatial, 9, argv);
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using INT8 = int8_t;
using NWC = ck::tensor_layout::convolution::NWC;
using NHWC = ck::tensor_layout::convolution::NHWC;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
using KXC = ck::tensor_layout::convolution::KXC;
using KYXC = ck::tensor_layout::convolution::KYXC;
using KZYXC = ck::tensor_layout::convolution::KZYXC;
using NWK = ck::tensor_layout::convolution::NWK;
using NHWK = ck::tensor_layout::convolution::NHWK;
using NDHWK = ck::tensor_layout::convolution::NDHWK;
constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
auto profile = [&](auto num_dim_spatial_tmp,
auto in_type,
auto wei_type,
auto out_type,
auto in_layout,
auto wei_layout,
auto out_layout) {
constexpr ck::index_t NumDimSpatial = num_dim_spatial_tmp.value;
using InDataType = decltype(in_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
bool pass = ck::profiler::profile_conv_fwd_impl<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
do_verification, init_method, do_log, time_kernel, params);
return pass ? 0 : 1;
};
if(num_dim_spatial == 1 && layout == ConvLayout::NHWC_KYXC_NHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I1, NWC{}, KXC{}, NWK{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I1, NWC{}, KXC{}, NWK{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NHWC{}, KYXC{}, NHWK{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
@@ -24,21 +24,27 @@ enum struct GemmDataType
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
}; };
static void print_helper_msg()
{
std::cout << "arg1: tensor operation (gemm: GEMM)\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
<< " 3: A[k, m] * B[n, k] = C[m, n])\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"
<< std::endl;
}
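// Example invocation matching the argument list above (values are illustrative only):
//   ckProfiler gemm 1 0 1 1 0 1  3840 4096 4096  4096 4096 4096
//   (fp16 data, A[m, k] * B[k, n] = C[m, n] layout, verify, integer init, no tensor print,
//    time kernel, then M N K StrideA StrideB StrideC)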
int profile_gemm(int argc, char* argv[]) int profile_gemm(int argc, char* argv[])
{ {
if(argc != 14) if(argc != 14)
{ {
printf("arg1: tensor operation (gemm: GEMM)\n"); print_helper_msg();
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
exit(1); exit(1);
} }
@@ -109,67 +115,67 @@ int profile_gemm(int argc, char* argv[])
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); return profile(Row{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); return profile(Row{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); return profile(Col{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); return profile(Col{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); return profile(Row{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); return profile(Row{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); return profile(Col{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{}); return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{}); return profile(Row{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); return profile(Col{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
} }
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{}); return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Row{}, Row{}); return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Col{}, Row{}); return profile(Row{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Row{}, Row{}); return profile(Col{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
} }
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Col{}, Row{}); return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
} }
else else
{ {
......
@@ -3,24 +3,23 @@
#include <cstring> #include <cstring>
int profile_gemm(int, char*[]); // int profile_gemm(int, char*[]);
int profile_gemm_splitk(int, char*[]); // int profile_gemm_splitk(int, char*[]);
int profile_gemm_bilinear(int, char*[]); // int profile_gemm_bilinear(int, char*[]);
int profile_gemm_add_add_fastgelu(int, char*[]); // int profile_gemm_add_add_fastgelu(int, char*[]);
int profile_gemm_reduce(int, char*[]); // int profile_gemm_reduce(int, char*[]);
int profile_gemm_bias_add_reduce(int, char*[]); // int profile_gemm_bias_add_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]); // int profile_batched_gemm(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]); // int profile_batched_gemm_reduce(int, char*[]);
int profile_grouped_gemm(int, char*[]); // int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]); int profile_conv_fwd(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]); // int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]); // int profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_convnd_fwd(int argc, char* argv[]); // int profile_convnd_bwd_data(int, char*[], int);
int profile_convnd_bwd_data(int, char*[], int); // int profile_conv_bwd_weight(int, char*[]);
int profile_conv_bwd_weight(int, char*[]); // int profile_normalization(int, char*[]);
int profile_normalization(int, char*[]); // int profile_reduce(int, char*[]);
int profile_reduce(int, char*[]); // int profile_convnd_bwd_weight(int, char*[], int);
int profile_convnd_bwd_weight(int, char*[], int);
static void print_helper_message() static void print_helper_message()
{ {
@@ -54,6 +53,7 @@ int main(int argc, char* argv[])
return 0; return 0;
} }
#if 0
if(strcmp(argv[1], "gemm") == 0) if(strcmp(argv[1], "gemm") == 0)
{ {
return profile_gemm(argc, argv); return profile_gemm(argc, argv);
@@ -90,10 +90,12 @@ int main(int argc, char* argv[])
{ {
return profile_grouped_gemm(argc, argv); return profile_grouped_gemm(argc, argv);
} }
else if(strcmp(argv[1], "conv_fwd") == 0) #endif
if(strcmp(argv[1], "conv_fwd") == 0)
{ {
return profile_convnd_fwd(argc, argv); return profile_conv_fwd(argc, argv);
} }
#if 0
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{ {
return profile_conv_fwd_bias_relu(argc, argv); return profile_conv_fwd_bias_relu(argc, argv);
@@ -139,6 +141,7 @@ int main(int argc, char* argv[])
{ {
return profile_normalization(argc, argv); return profile_normalization(argc, argv);
} }
#endif
else else
{ {
print_helper_message(); print_helper_message();
......