#pragma once

#include <algorithm>
#include <cmath>
#include <limits>

// HOSTDEVICE normally comes from the project's hostdevice.h. The guard below
// keeps this header parseable (tooling, stand-alone host builds) when that
// header is not on the include path; if it is present it is used as before.
#if defined(__has_include)
#  if __has_include("hostdevice.h")
#    include "hostdevice.h"
#  endif
#else
#  include "hostdevice.h"
#endif

// Fallback only: applies when hostdevice.h was not found or did not define it.
#ifndef HOSTDEVICE
#  if defined(__CUDACC__)
#    define HOSTDEVICE __host__ __device__
#  else
#    define HOSTDEVICE
#  endif
#endif

/// Small arithmetic helpers and function objects used by the CTC code.
namespace ctc_helper {

// NOTE(review): not referenced in this header; its meaning is defined by the
// code that uses it -- confirm before changing the value.
static const float threshold = 1e-1;

/// @brief Negative infinity of type T, built from the <cmath> INFINITY macro.
template <typename T>
HOSTDEVICE
T neg_inf() { return -T(INFINITY); }

/// @brief Integer division rounded up: ceil(x / y) for x >= 0 and y > 0.
/// @param x Dividend; intended to be non-negative (for negative x the result
///          is not always the ceiling, because `/` truncates toward zero).
/// @param y Divisor; must be non-zero.
/// @note Host-only (not marked HOSTDEVICE). Overflows if x + y - 1 > INT_MAX.
inline int div_up(int x, int y) {
    return (x + y - 1) / y;
}

/// @brief Binary functor returning the larger of x and y (x if they compare equal).
template <typename Arg, typename Res = Arg>
struct maximum {
    HOSTDEVICE
    Res operator()(const Arg &x, const Arg &y) const {
        return x < y ? y : x;
    }
};

/// @brief Binary functor returning the smaller of x and y (y if they compare equal).
template <typename Arg, typename Res = Arg>
struct minimum {
    HOSTDEVICE
    Res operator()(const Arg &x, const Arg &y) const {
        return x < y ? x : y;
    }
};

/// @brief Binary functor returning x + y.
template <typename Arg, typename Res = Arg>
struct add {
    HOSTDEVICE
    Res operator()(const Arg &x, const Arg &y) const {
        return x + y;
    }
};

/// @brief Unary functor returning x converted to Res.
template <typename Arg, typename Res = Arg>
struct identity {
    HOSTDEVICE
    Res operator()(const Arg &x) const { return Res(x); }
};

/// @brief Unary functor returning -x converted to Res.
template <typename Arg, typename Res = Arg>
struct negate {
    HOSTDEVICE
    Res operator()(const Arg &x) const { return Res(-x); }
};

/// @brief Unary functor returning std::exp(x).
template <typename Arg, typename Res = Arg>
struct exponential {
    HOSTDEVICE
    Res operator()(const Arg &x) const { return std::exp(x); }
};

/// @brief Log-space addition: returns log(exp(p1) + exp(p2)).
///
/// Evaluated as max(p1, p2) + log1p(exp(-|p1 - p2|)), so the argument of exp
/// is never positive and cannot overflow. A -infinity operand (the log of
/// zero) acts as the identity, and the other operand is returned unchanged.
template <typename Arg1, typename Arg2 = Arg1, typename Res = Arg1>
struct log_plus {
    typedef Res result_type;

    HOSTDEVICE
    Res operator()(const Arg1 &p1, const Arg2 &p2) const {
        // Short-circuit -inf operands: when both are -inf, p1 - p2 would be
        // (-inf) - (-inf) = NaN and poison the result.
        if (p1 == neg_inf<Arg1>())
            return p2;
        if (p2 == neg_inf<Arg2>())
            return p1;
        Res result = log1p(exp(-fabs(p1 - p2))) + maximum<Res>()(p1, p2);
        return result;
    }
};

// Disabled alternative implementation of log_plus. Instead of special-casing
// -infinity, it floors every result at NEGATIVE_CUTOFF_VAL (-100000):
//
// template <typename Arg1, typename Arg2 = Arg1, typename Res = Arg1>
// struct log_plus {
//     HOSTDEVICE
//     Res operator()(const Arg1 &p1, const Arg2 &p2) {
//         Res p12_max = maximum<Res>()(p1, p2);
//         Res p12_min = minimum<Res>()(p1, p2);
//         Res p12_diff = p12_min - p12_max;
//         Res NEGATIVE_CUTOFF_VAL = -(Res)100000;
//
//         Res result = p12_diff <= NEGATIVE_CUTOFF_VAL
//             ? maximum<Res>()(p12_max, NEGATIVE_CUTOFF_VAL)
//             : maximum<Res>()(p12_max + log(exp(p12_diff) + 1), NEGATIVE_CUTOFF_VAL);
//
//         return result;
//     }
// };

}  // namespace ctc_helper