math.h 746 Bytes
Newer Older
1
2
#pragma once

3
4
#include <libtorchaudio/rnnt/macros.h>
#include <math.h>
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

namespace torchaudio {
namespace rnnt {

namespace math {

// Returns the larger of `x` and `y`.
// When `x > y` is false -- a tie, or either argument is NaN -- `y` is
// returned, so the result is not symmetric in its arguments for NaN inputs.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
  return (x > y) ? x : y;
}

// Returns the smaller of `x` and `y`.
// When `x > y` is false -- a tie, or either argument is NaN -- `x` is
// returned, so the result is not symmetric in its arguments for NaN inputs.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
  return (x > y) ? y : x;
}

// log_sum_exp: computes log(exp(x) + exp(y)) without overflowing exp().
//
// Only the `float` specialization is defined in this header; using `lse`
// with any other DTYPE needs a specialization provided elsewhere, otherwise
// the call has no definition to link against.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);

// float specialization of log_sum_exp.
//
// The larger argument is factored out so the argument passed to expf() is
// always <= 0:  lse(x, y) = hi + log1p(exp(lo - hi)).
//
// -inf (i.e. log(0)) is the identity element of log-sum-exp, so when either
// argument is -inf the other one is returned directly. Without these checks
// lse(-inf, -inf) would evaluate (-inf) - (-inf) = NaN, and that NaN would
// propagate into everything accumulated from it. NaN inputs still yield NaN.
template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
  if (x == -INFINITY) {
    return y;
  }
  if (y == -INFINITY) {
    return x;
  }
  if (y > x) {
    return y + log1pf(expf(x - y));
  } else {
    return x + log1pf(expf(y - x));
  }
}

} // namespace math

} // namespace rnnt
} // namespace torchaudio