ctc_helper.h 2.41 KB
Newer Older
lishen's avatar
lishen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#pragma once

#include <limits>
#include <algorithm>
#include <cmath>

#include "hostdevice.h"

namespace ctc_helper {

    // Constant exported by this header; it is not referenced inside this
    // namespace. NOTE(review): its meaning is defined by callers elsewhere.
    // Written as a float literal so no double -> float conversion is implied.
    static const float threshold = 1e-1f;

    // Negative infinity of type T. log_plus below treats this value as its
    // identity element (the log of probability zero).
    template<typename T>
    HOSTDEVICE
    T neg_inf() { return -T(INFINITY); }

    // Ceiling integer division: smallest integer >= x / y (host only).
    // Preconditions (not checked): y > 0 and x + y - 1 does not overflow int.
    inline int div_up(int x, int y) {
        return (x + y - 1) / y;
    }

    // Binary max functor. Ties and unordered (NaN) comparisons return x.
    template<typename Arg, typename Res = Arg>
    struct maximum {
        HOSTDEVICE
        Res operator()(const Arg &x, const Arg &y) const {
            return x < y ? y : x;
        }
    };

    // Binary min functor. Ties and unordered (NaN) comparisons return y.
    template<typename Arg, typename Res = Arg>
    struct minimum {
        HOSTDEVICE
        Res operator()(const Arg &x, const Arg &y) const {
            return x < y ? x : y;
        }
    };

    // Binary addition functor: x + y.
    template<typename Arg, typename Res = Arg>
    struct add {
        HOSTDEVICE
        Res operator()(const Arg &x, const Arg &y) const {
            return x + y;
        }
    };

    // Unary functor returning its argument converted to Res.
    template<typename Arg, typename Res = Arg>
    struct identity {
        HOSTDEVICE Res operator()(const Arg &x) const {
            return Res(x);
        }
    };

    // Unary functor returning -x converted to Res.
    template<typename Arg, typename Res = Arg>
    struct negate {
        HOSTDEVICE Res operator()(const Arg &x) const {
            return Res(-x);
        }
    };

    // Unary functor returning exp(x) (std::exp) converted to Res.
    template<typename Arg, typename Res = Arg>
    struct exponential {
        HOSTDEVICE Res operator()(const Arg &x) const { return std::exp(x); }
    };

    // Numerically stable log-sum-exp of two values:
    //     log_plus()(a, b) == log(exp(a) + exp(b))
    // computed as max(a, b) + log1p(exp(-|a - b|)), so the argument of exp()
    // is never positive and cannot overflow.
    // A -infinity operand (log of zero) is the identity: the other operand is
    // returned (converted to Res). This also avoids (-inf) - (-inf) = NaN.
    // NOTE(review): two +infinity operands still produce NaN (inf - inf).
    template<typename Arg1, typename Arg2 = Arg1, typename Res=Arg1>
    struct log_plus {
        typedef Res result_type;

        // const-qualified like the other functors in this header, so the
        // functor can be invoked through a const object or const reference.
        HOSTDEVICE
        Res operator()(const Arg1 &p1, const Arg2 &p2) const {
            if (p1 == neg_inf<Arg1>())
                return p2;
            if (p2 == neg_inf<Arg2>())
                return p1;
            Res result = log1p(exp(-fabs(p1 - p2))) + maximum<Res>()(p1, p2);
            return result;
        }
    };

// Earlier cutoff-based variant of log_plus, kept for reference only (not compiled).
//template<typename Arg1, typename Arg2 = Arg1, typename Res=Arg1>
//struct log_plus {
//    HOSTDEVICE
//    Res operator()(const Arg1& p1, const Arg2& p2) {
//        Res p12_max = maximum<Res>()(p1, p2);
//        Res p12_min = minimum<Res>()(p1, p2);
//        Res p12_diff = p12_min-p12_max;
//        Res NEGATIVE_CUTOFF_VAL = -(Res)100000;
//
//        Res result = p12_diff <= NEGATIVE_CUTOFF_VAL ? maximum<Res>()(p12_max, NEGATIVE_CUTOFF_VAL)
//                                        : maximum<Res>()(p12_max + log(exp(p12_diff) + 1), NEGATIVE_CUTOFF_VAL);
//
//
//        return result;
//    }
//};

} // namespace ctc_helper