Commit 12585e57 authored by Chao Liu
Browse files

refactor

parent 90acba1d
...@@ -142,22 +142,105 @@ int run_conv_fwd(bool do_verification, ...@@ -142,22 +142,105 @@ int run_conv_fwd(bool do_verification,
in_device_buf.ToDevice(in.mData.data()); in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
// tensor descriptors in NCHW/KCYX/NKHW dimensional order
HostTensorDescriptor in_n_c_wis_desc = in_desc;
HostTensorDescriptor wei_k_c_xs_desc = wei_desc;
HostTensorDescriptor out_n_k_wos_desc = out_desc;
// input
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(in_n_c_wis_desc.GetLengths(), a_n_c_wis_lengths);
copy(in_n_c_wis_desc.GetStrides(), a_n_c_wis_strides);
copy(wei_k_c_xs_desc.GetLengths(), b_k_c_xs_lengths);
copy(wei_k_c_xs_desc.GetStrides(), b_k_c_xs_strides);
copy(out_n_k_wos_desc.GetLengths(), e_n_k_wos_lengths);
copy(out_n_k_wos_desc.GetStrides(), e_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do GEMM // do GEMM
auto conv = DeviceConvNDFwdInstance{}; auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), wei_device_buf.GetDeviceBuffer(),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), std::array<const void*, 0>{},
conv_param.N_, out_device_buf.GetDeviceBuffer(),
conv_param.K_, a_n_c_wis_lengths,
conv_param.C_, a_n_c_wis_strides,
conv_param.input_spatial_lengths_, b_k_c_xs_lengths,
conv_param.filter_spatial_lengths_, b_k_c_xs_strides,
conv_param.GetOutputSpatialLengths(), std::array<std::array<ck::index_t, NDimSpatial + 2>, 0>{{}},
conv_param.conv_filter_strides_, std::array<std::array<ck::index_t, NDimSpatial + 2>, 0>{{}},
conv_param.conv_filter_dilations_, e_n_k_wos_lengths,
conv_param.input_left_pads_, e_n_k_wos_strides,
conv_param.input_right_pads_, conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
......
...@@ -20,7 +20,7 @@ namespace device { ...@@ -20,7 +20,7 @@ namespace device {
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
// Assume: // Assume:
// D0, D1, ... and E have the same layout // D0, D1, ... and E have the same layout
template <ck::index_t NDimSpatial, template <index_t NDimSpatial,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename DsLayout, typename DsLayout,
...@@ -36,23 +36,26 @@ struct DeviceConvFwdMultipleD : public BaseOperator ...@@ -36,23 +36,26 @@ struct DeviceConvFwdMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
MakeArgumentPointer(const ADataType* p_a, const void* p_a,
const BDataType* p_b, const void* p_b,
EDataType* p_e, std::array<const void*, NumDTensor> p_ds,
ck::index_t N, void* p_e,
ck::index_t K, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
ck::index_t C, const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
std::vector<ck::index_t> input_spatial_lengths, const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
std::vector<ck::index_t> filter_spatial_lengths, const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
std::vector<ck::index_t> output_spatial_lengths, const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
std::vector<ck::index_t> conv_filter_strides, const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides,
std::vector<ck::index_t> conv_filter_dilations, const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
std::vector<ck::index_t> input_left_pads, const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
std::vector<ck::index_t> input_right_pads, const std::array<index_t, NDimSpatial>& conv_filter_strides,
AElementwiseOperation a_element_op, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
BElementwiseOperation b_element_op, const std::array<index_t, NDimSpatial>& input_left_pads,
CDEElementwiseOperation cde_element_op) = 0; const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment