Commit 4b306e5b authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into rework_ector_type

parents 5a1b0857 970fa3e9
...@@ -239,14 +239,10 @@ int main(int argc, char* argv[]) ...@@ -239,14 +239,10 @@ int main(int argc, char* argv[])
using ab_data_t = float; using ab_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using c_data_t = float; using c_data_t = float;
#elif 0 #elif 1
using ab_data_t = half_t; using ab_data_t = half_t;
using acc_data_t = float; using acc_data_t = float;
using c_data_t = half_t; using c_data_t = half_t;
#elif 1
using ab_data_t = ushort;
using acc_data_t = float;
using c_data_t = ushort;
#elif 1 #elif 1
using ab_data_t = int8_t; using ab_data_t = int8_t;
using acc_data_t = int32_t; using acc_data_t = int32_t;
......
...@@ -74,4 +74,17 @@ calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDes ...@@ -74,4 +74,17 @@ calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDes
return std::size_t(2) * N * K * Ho * Wo * C * Y * X; return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
} }
template <typename T>
inline auto activ(T v, const ck::ActivTypeEnum_t activ_type)
{
const T alpha = 0.3;
switch(activ_type)
{
case ck::ActivTypeEnum_t::None: return v;
case ck::ActivTypeEnum_t::LeakyRelu: return (v >= 0 ? v : alpha * v);
case ck::ActivTypeEnum_t::Sigmoid: return (1 / (1 + exp(-v)));
default: throw std::runtime_error("unsupported activ type"); break;
}
}
#endif #endif
...@@ -257,6 +257,18 @@ struct Tensor ...@@ -257,6 +257,18 @@ struct Tensor
mDesc.GetLengths()[3])(num_thread); mDesc.GetLengths()[3])(num_thread);
break; break;
} }
case 5: {
auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
(*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
};
make_ParallelTensorFunctor(f,
mDesc.GetLengths()[0],
mDesc.GetLengths()[1],
mDesc.GetLengths()[2],
mDesc.GetLengths()[3],
mDesc.GetLengths()[4])(num_thread);
break;
}
default: throw std::runtime_error("unspported dimension"); default: throw std::runtime_error("unspported dimension");
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment