Commit 0cb8ba92 authored by Chao Liu
Browse files

update reference conv

parent ba816e69
...@@ -26,7 +26,6 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -26,7 +26,6 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
AccDataType, // AccDataType, //
InElementOp, // Input Elementwise Operation InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation = WeiElementOp, // Weights Elementwise Operation =
// ck::tensor_layout::convolution::NKHW,
OutElementOp, // Output Elementwise Operation OutElementOp, // Output Elementwise Operation
ConvFwdDefault, // ConvForwardSpecialization ConvFwdDefault, // ConvForwardSpecialization
NumDimSpatial, // NumDimSpatial NumDimSpatial, // NumDimSpatial
...@@ -57,16 +56,17 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc ...@@ -57,16 +56,17 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
1>; // CThreadTransferDstScalarPerVector 1>; // CThreadTransferDstScalarPerVector
template <ck::index_t NumDimSpatial> template <ck::index_t NumDimSpatial>
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NumDimSpatial, using ReferenceConvNDFwdInstance =
InLayout, ck::tensor_operation::host::ReferenceConvFwd<NumDimSpatial,
WeiLayout, ck::tensor_layout::convolution::NHWC,
OutLayout, ck::tensor_layout::convolution::KYXC,
InDataType, ck::tensor_layout::convolution::NHWK,
WeiDataType, InDataType,
OutDataType, WeiDataType,
InElementOp, OutDataType,
WeiElementOp, InElementOp,
OutElementOp>; WeiElementOp,
OutElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
......
...@@ -260,16 +260,16 @@ int main(int argc, char* argv[]) ...@@ -260,16 +260,16 @@ int main(int argc, char* argv[])
e_ms_ns_lengths = {M0, M1, N0, N1}; e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = { e_ms_ns_strides = {
std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};
scale = std::stof(argv[26]); scale = std::stof(argv[22]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n");
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
......
...@@ -87,65 +87,63 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -87,65 +87,63 @@ struct ReferenceConvFwd : public device::BaseOperator
{ {
using Argument = ReferenceConvFwd::Argument; using Argument = ReferenceConvFwd::Argument;
// FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
// tensor descriptor in NCHW/KXYC/NKHW dimensional order // tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor in_desc = arg.input_.mDesc; HostTensorDescriptor in_desc = arg.input_.mDesc;
HostTensorDescriptor wei_desc = arg.weight_.mDesc; HostTensorDescriptor wei_desc = arg.weight_.mDesc;
HostTensorDescriptor oout_desc = arg.output_.mDesc; HostTensorDescriptor out_desc = arg.output_.mDesc;
// input // input
if constexpr(is_same_v<InLayout,ck::tensor_layout::convolution::NWC>) if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{ {
in_desc = transpose_host_tensor_descriptor_given_new2old( in_desc = transpose_host_tensor_descriptor_given_new2old(
input_.mDesc, std::vector<std::size_t>{0, 2, 1}); arg.input_.mDesc, std::vector<std::size_t>{0, 2, 1});
} }
else if constexpr(is_same_v<InLayout,ck::tensor_layout::convolution::NHWC>) else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{ {
in_desc = transpose_host_tensor_descriptor_given_new2old( in_desc = transpose_host_tensor_descriptor_given_new2old(
input_.mDesc, std::vector<std::size_t>{0, 3, 1, 2}); arg.input_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
} }
else if constexpr(is_same_v<InLayout,ck::tensor_layout::convolution::NDHWC>) else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{ {
in_desc = transpose_host_tensor_descriptor_given_new2old( in_desc = transpose_host_tensor_descriptor_given_new2old(
input_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3}); arg.input_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
} }
// weight // weight
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>) if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{ {
wei_desc = transpose_host_tensor_descriptor_given_new2old( wei_desc = transpose_host_tensor_descriptor_given_new2old(
weight_.mDesc, std::vector<std::size_t>{0, 2, 1}); arg.weight_.mDesc, std::vector<std::size_t>{0, 2, 1});
} }
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>) else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{ {
wei_desc = transpose_host_tensor_descriptor_given_new2old( wei_desc = transpose_host_tensor_descriptor_given_new2old(
weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2}); arg.weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
} }
else if constexpr(NumDimSpatial == 2 && else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
WeiLayout == ck::tensor_layout::convolution::KYXC)
{ {
wei_desc = transpose_host_tensor_descriptor_given_new2old( wei_desc = transpose_host_tensor_descriptor_given_new2old(
weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2}); arg.weight_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
} }
// output // output
if constexpr(NumDimSpatial == 1 && OutLayout == ck::tensor_layout::convolution::NWK) if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{ {
out_desc = transpose_host_tensor_descriptor_given_new2old( out_desc = transpose_host_tensor_descriptor_given_new2old(
output_.mDesc, std::vector<std::size_t>{0, 2, 1}); arg.output_.mDesc, std::vector<std::size_t>{0, 2, 1});
} }
else if constexpr(NumDimSpatial == 2 && else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
OutLayout == ck::tensor_layout::convolution::NHWK)
{ {
out_desc = transpose_host_tensor_descriptor_given_new2old( out_desc = transpose_host_tensor_descriptor_given_new2old(
output_.mDesc, std::vector<std::size_t>{0, 3, 1, 2}); arg.output_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
} }
else if constexpr(NumDimSpatial == 3 && else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
OutLayout == ck::tensor_layout::convolution::NDHWK)
{ {
out_desc = transpose_host_tensor_descriptor_given_new2old( out_desc = transpose_host_tensor_descriptor_given_new2old(
output_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3}); arg.output_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
} }
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
...@@ -161,16 +159,26 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -161,16 +159,26 @@ struct ReferenceConvFwd : public device::BaseOperator
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) + ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[2])
{ {
float v_in; float v_in;
float v_wei; float v_wei;
arg.in_element_op_(v_in, // FIXME hacky
ck::type_convert<float>(arg.input_(n, c, wi))); arg.in_element_op_(
arg.wei_element_op_(v_wei, v_in,
ck::type_convert<float>(arg.weight_(k, c, x))); ck::type_convert<float>(
arg.input_
.mData[in_desc.GetOffsetFromMultiIndex(n, c, wi)]));
// FIXME hacky
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(
arg.weight_
.mData[wei_desc.GetOffsetFromMultiIndex(k, c, x)]));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
...@@ -180,7 +188,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -180,7 +188,10 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
// FIXME hacky
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, wo})] =
ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
...@@ -204,12 +215,14 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -204,12 +215,14 @@ struct ReferenceConvFwd : public device::BaseOperator
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) + ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) - ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < wei_desc.GetLengths()[3]; ++x) for(std::size_t x = 0; x < wei_desc.GetLengths()[3]; ++x)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) + ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) - ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < in_desc.GetLengths()[2] && ck::type_convert<std::size_t>(hi) < in_desc.GetLengths()[2] &&
wi >= 0 && wi >= 0 &&
...@@ -218,10 +231,20 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -218,10 +231,20 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_in; float v_in;
float v_wei; float v_wei;
// FIXME hacky
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi))); v_in,
ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, hi, wi)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x))); v_wei,
ck::type_convert<float>(
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
k, c, y, x)]));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -231,7 +254,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -231,7 +254,10 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
// FIXME hacky
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, ho, wo})] =
ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
...@@ -282,12 +308,20 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -282,12 +308,20 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_in; float v_in;
float v_wei; float v_wei;
// FIXME hacky
arg.in_element_op_( arg.in_element_op_(
v_in, v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi))); ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, di, hi, wi)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<float>(arg.weight_(k, c, z, y, x))); ck::type_convert<float>(
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
k, c, z, y, x)]));
v_acc += v_in * v_wei; v_acc += v_in * v_wei;
} }
} }
...@@ -298,7 +332,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -298,7 +332,10 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out; float v_out;
arg.out_element_op_(v_out, v_acc); arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
// FIXME hacky
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, d_o, ho, wo})] =
ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment