"examples/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "f29b3f8d3859b8249f15e4835c1f485e4c841ffc"
Commit 8c85a3e4 authored by Chao Liu
Browse files

update cpu verification

parent 9b363cf5
#include <boost/range/adaptor/transformed.hpp>
#include <cassert> #include <cassert>
#include "host_tensor.hpp" #include "host_tensor.hpp"
...@@ -26,8 +25,12 @@ std::size_t HostTensorDescriptor::GetElementSize() const ...@@ -26,8 +25,12 @@ std::size_t HostTensorDescriptor::GetElementSize() const
// Returns 1 + sum over all dimensions of (mLens[i] - 1) * mStrides[i].
// This span is a side-by-side diff: on each line the left part is the previous
// implementation (boost::adaptors::transformed + std::inner_product) and the right part is
// its replacement (an explicit accumulation loop). Both evaluate the same expression.
// NOTE(review): the replacement loop compares `int i` with `mLens.size()`, which is unsigned
// (mLens is a std::vector<std::size_t>); a std::size_t index would avoid the
// signed/unsigned comparison.
// NOTE(review): `mLens[i] - 1` is unsigned arithmetic, so a zero length would wrap around;
// presumably lengths are always >= 1 - confirm against callers.
std::size_t HostTensorDescriptor::GetElementSpace() const std::size_t HostTensorDescriptor::GetElementSpace() const
{ {
auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; }); std::size_t space = 1;
return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1; for(int i = 0; i < mLens.size(); ++i)
{
space += (mLens[i] - 1) * mStrides[i];
}
return space;
} }
const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; } const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; }
......
...@@ -30,6 +30,84 @@ void add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances( ...@@ -30,6 +30,84 @@ void add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances(
namespace ck { namespace ck {
namespace profiler { namespace profiler {
/**
 * @brief CPU reference for a 2-D forward convolution followed by bias add and ReLU.
 *
 * Tensors are addressed through HostTensorDescriptor objects built from lengths only:
 *   - input  : lengths {N, Hi, Wi, C}, indexed (n, hi, wi, c)
 *   - weight : lengths {K, Y, X, C},   indexed (k, y, x, c)
 *   - output : lengths {N, Ho, Wo, K}, indexed (n, ho, wo, k)
 *   - bias   : lengths {K},            indexed (k)
 * NOTE(review): presumably a lengths-only descriptor yields a packed layout, i.e. the
 * buffers are contiguous NHWC / KYXC / NHWK - confirm against HostTensorDescriptor.
 *
 * The same Stride, Dilation and Pad value is applied to both spatial dimensions. Input
 * positions that fall outside [0, Hi) x [0, Wi) are skipped, so they contribute zero.
 * Products are accumulated in double precision, then the bias is added, negative results
 * are clamped to zero, and the value is stored to output_ptr.
 *
 * @param in_ptr     Input buffer (read only).
 * @param weight_ptr Filter buffer (read only).
 * @param output_ptr Output buffer; every (n, ho, wo, k) element is overwritten.
 * @param bias_ptr   Per-output-channel bias buffer (read only).
 * @param N          Batch size.
 * @param K          Number of output channels.
 * @param C          Number of input channels.
 * @param Y          Filter height.
 * @param X          Filter width.
 * @param Hi         Input height.
 * @param Wi         Input width.
 * @param Ho         Output height.
 * @param Wo         Output width.
 * @param Stride     Convolution stride, used for both height and width.
 * @param Dilation   Filter dilation, used for both height and width.
 * @param Pad        Left/top padding, used for both height and width.
 */
void cpu_conv_bias_relu(ck::half_t* in_ptr,
                        ck::half_t* weight_ptr,
                        ck::half_t* output_ptr,
                        ck::half_t* bias_ptr,
                        const ck::index_t N,
                        const ck::index_t K,
                        const ck::index_t C,
                        const ck::index_t Y,
                        const ck::index_t X,
                        const ck::index_t Hi,
                        const ck::index_t Wi,
                        const ck::index_t Ho,
                        const ck::index_t Wo,
                        const ck::index_t Stride,
                        const ck::index_t Dilation,
                        const ck::index_t Pad)
{
    const auto in_desc =
        HostTensorDescriptor(std::vector<std::size_t>{static_cast<std::size_t>(N),
                                                      static_cast<std::size_t>(Hi),
                                                      static_cast<std::size_t>(Wi),
                                                      static_cast<std::size_t>(C)});

    const auto wei_desc =
        HostTensorDescriptor(std::vector<std::size_t>{static_cast<std::size_t>(K),
                                                      static_cast<std::size_t>(Y),
                                                      static_cast<std::size_t>(X),
                                                      static_cast<std::size_t>(C)});

    const auto out_desc =
        HostTensorDescriptor(std::vector<std::size_t>{static_cast<std::size_t>(N),
                                                      static_cast<std::size_t>(Ho),
                                                      static_cast<std::size_t>(Wo),
                                                      static_cast<std::size_t>(K)});

    const auto bias_desc =
        HostTensorDescriptor(std::vector<std::size_t>{static_cast<std::size_t>(K)});

    // Computes every output element of one output channel k.
    auto f_k = [&](auto k) {
        for(int n = 0; n < N; ++n)
        {
            for(int ho = 0; ho < Ho; ++ho)
            {
                for(int wo = 0; wo < Wo; ++wo)
                {
                    // Accumulate in double so the reference is not limited by half precision.
                    double acc = 0;

                    for(int c = 0; c < C; ++c)
                    {
                        for(int y = 0; y < Y; ++y)
                        {
                            const int hi = ho * Stride + y * Dilation - Pad;

                            // A row inside the padding contributes nothing for any x, so
                            // skip the whole inner loop instead of re-testing hi per tap.
                            if(hi < 0 || hi >= Hi)
                            {
                                continue;
                            }

                            for(int x = 0; x < X; ++x)
                            {
                                const int wi = wo * Stride + x * Dilation - Pad;

                                if(wi < 0 || wi >= Wi)
                                {
                                    continue;
                                }

                                const double in_val =
                                    in_ptr[in_desc.GetOffsetFromMultiIndex(n, hi, wi, c)];
                                const double wei_val =
                                    weight_ptr[wei_desc.GetOffsetFromMultiIndex(k, y, x, c)];

                                acc += in_val * wei_val;
                            }
                        }
                    }

                    acc += bias_ptr[bias_desc.GetOffsetFromMultiIndex(k)];

                    // ReLU.
                    acc = acc > 0 ? acc : 0;

                    output_ptr[out_desc.GetOffsetFromMultiIndex(n, ho, wo, k)] = acc;
                }
            }
        }
    };

    // std::thread::hardware_concurrency() is allowed to return 0 when the value is not
    // computable; never forward a worker count of 0 to the parallel functor.
    const unsigned hw_threads  = std::thread::hardware_concurrency();
    const unsigned num_threads = hw_threads != 0 ? hw_threads : 1u;

    make_ParallelTensorFunctor(f_k, K)(num_threads);
}
template <typename TIn, template <typename TIn,
typename TWei, typename TWei,
typename TOut, typename TOut,
...@@ -164,6 +242,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -164,6 +242,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
#if 0
host_reference_calculation(in_n_c_hi_wi, host_reference_calculation(in_n_c_hi_wi,
wei_k_c_y_x, wei_k_c_y_x,
out_n_k_ho_wo_host_result, out_n_k_ho_wo_host_result,
...@@ -175,6 +254,24 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -175,6 +254,24 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{});
#else
cpu_conv_bias_relu(in_n_c_hi_wi.mData.data(),
wei_k_c_y_x.mData.data(),
out_n_k_ho_wo_host_result.mData.data(),
bias_k.mData.data(),
N,
K,
C,
Y,
X,
Hi,
Wi,
Ho,
Wo,
conv_filter_strides[0],
conv_filter_dilations[0],
input_left_pads[0]);
#endif
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment