Commit c29dc4c5 authored by ltqin's avatar ltqin
Browse files

Merge branch 'develop' into conv_splitk_f32

parents 134af43b fd3d907a
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 0 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
......
#pragma once #pragma once
#include "host_tensor.hpp" #include "host_tensor.hpp"
template <typename AType, typename BType, typename CType> template <typename AType,
typename BType,
typename CType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k, void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
const Tensor<BType>& b_k_n, const Tensor<BType>& b_k_n,
Tensor<CType>& c_m_n) Tensor<CType>& c_m_n,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{ {
auto f_mk_kn_mn = [&](auto m, auto n) { auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a_m_k.mDesc.GetLengths()[1]; const int K = a_m_k.mDesc.GetLengths()[1];
...@@ -13,10 +21,11 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k, ...@@ -13,10 +21,11 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n)); v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
static_cast<const double>(b_element_op(b_k_n(k, n)));
} }
c_m_n(m, n) = v; c_m_n(m, n) = c_element_op(v);
}; };
make_ParallelTensorFunctor(f_mk_kn_mn, make_ParallelTensorFunctor(f_mk_kn_mn,
......
...@@ -8,12 +8,17 @@ ...@@ -8,12 +8,17 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_conv.hpp" #include "device_conv.hpp"
#include "device_conv_instance.hpp" #include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv_instance { namespace device_conv_instance {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
template <> template <>
void add_device_conv_fwd_instance<2, void add_device_conv_fwd_instance<2,
float, float,
...@@ -22,7 +27,7 @@ void add_device_conv_fwd_instance<2, ...@@ -22,7 +27,7 @@ void add_device_conv_fwd_instance<2,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr>&); std::vector<DeviceConvFwdNoOpPtr>&);
template <> template <>
void add_device_conv_fwd_instance<2, void add_device_conv_fwd_instance<2,
...@@ -32,7 +37,7 @@ void add_device_conv_fwd_instance<2, ...@@ -32,7 +37,7 @@ void add_device_conv_fwd_instance<2,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr>&); std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv_instance } // namespace device_conv_instance
} // namespace device } // namespace device
...@@ -133,8 +138,13 @@ void profile_conv(int do_verification, ...@@ -133,8 +138,13 @@ void profile_conv(int do_verification,
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
// add device Conv instances // add device Conv instances
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
ck::tensor_operation::device::device_conv_instance::add_device_conv_fwd_instance<2, ck::tensor_operation::device::device_conv_instance::add_device_conv_fwd_instance<2,
InDataType, InDataType,
...@@ -170,7 +180,10 @@ void profile_conv(int do_verification, ...@@ -170,7 +180,10 @@ void profile_conv(int do_verification,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = conv_ptr->MakeInvokerPointer(); auto invoker_ptr = conv_ptr->MakeInvokerPointer();
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment