Commit c29dc4c5 authored by ltqin
Browse files

Merge branch 'develop' into conv_splitk_f32

parents 134af43b fd3d907a
......@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 0
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
......
#pragma once
#include "host_tensor.hpp"
template <typename AType, typename BType, typename CType>
template <typename AType,
typename BType,
typename CType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
const Tensor<BType>& b_k_n,
Tensor<CType>& c_m_n)
Tensor<CType>& c_m_n,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a_m_k.mDesc.GetLengths()[1];
......@@ -13,10 +21,11 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
static_cast<const double>(b_element_op(b_k_n(k, n)));
}
c_m_n(m, n) = v;
c_m_n(m, n) = c_element_op(v);
};
make_ParallelTensorFunctor(f_mk_kn_mn,
......
......@@ -8,12 +8,17 @@
#include "device_tensor.hpp"
#include "device_conv.hpp"
#include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv_instance {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
template <>
void add_device_conv_fwd_instance<2,
float,
......@@ -22,7 +27,7 @@ void add_device_conv_fwd_instance<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>(
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr>&);
std::vector<DeviceConvFwdNoOpPtr>&);
template <>
void add_device_conv_fwd_instance<2,
......@@ -32,7 +37,7 @@ void add_device_conv_fwd_instance<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>(
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr>&);
std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv_instance
} // namespace device
......@@ -133,8 +138,13 @@ void profile_conv(int do_verification,
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
// add device Conv instances
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr> conv_ptrs;
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
ck::tensor_operation::device::device_conv_instance::add_device_conv_fwd_instance<2,
InDataType,
......@@ -170,7 +180,10 @@ void profile_conv(int do_verification,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment