kernel_utils.h 1.34 KB
Newer Older
1
2
3
4
#pragma once

#include <cassert>

5
#include <libtorchaudio/rnnt/gpu/math.cuh>
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

namespace torchaudio {
namespace rnnt {

// Returns true iff `val` lies in the closed interval [start, end]; both
// bounds are inclusive. Callable from host and device code.
inline HOST_AND_DEVICE bool in_range(
    int start,
    int end, // inclusive
    int val) {
  const bool below_start = val < start;
  const bool past_end = val > end;
  return !(below_start || past_end);
}

// Index constants selecting the "skip" (0) and "emit" (1) entries of a
// log-probs buffer.
// NOTE(review): presumably these index a trailing dimension of size 2 that
// holds the skip/blank and emit/label log-probs per (t, u) cell — confirm
// against the kernels that include this header.
#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1

// Row-major flat-index helper for a 2-D array whose inner (second)
// dimension has extent `size2`:
//   (index1, index2) -> index1 * size2 + index2
//
// The extent is stored by value, so an indexer built from a temporary
// (e.g. `Indexer2D idx(n + 1);`) cannot dangle, and the struct stays
// copy-assignable. `operator()` is const so it works through const indexers.
//
// NOTE(review): arithmetic is in 32-bit `int`; callers must keep the flat
// index below INT_MAX (signed overflow is undefined behavior).
struct Indexer2D {
  int size2_;

  FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) const {
    return index1 * size2_ + index2;
  }
};

// Row-major flat-index helper for a 3-D array with inner extents
// `size2` (second dim) and `size3` (third dim):
//   (index1, index2, index3) -> (index1 * size2 + index2) * size3 + index3
//
// Extents are stored by value, so an indexer built from temporaries cannot
// dangle, and the struct stays copy-assignable. `operator()` is const so it
// works through const indexers.
//
// NOTE(review): arithmetic is in 32-bit `int`; callers must keep the flat
// index below INT_MAX (signed overflow is undefined behavior).
struct Indexer3D {
  int size2_;
  int size3_;

  FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
      : size2_(size2), size3_(size3) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3) const {
    return (index1 * size2_ + index2) * size3_ + index3;
  }
};

// Row-major flat-index helper for a 4-D array with inner extents
// `size2`, `size3` and `size4`:
//   (i1, i2, i3, i4) -> ((i1 * size2 + i2) * size3 + i3) * size4 + i4
//
// Extents are stored by value, so an indexer built from temporaries cannot
// dangle, and the struct stays copy-assignable. Both members are
// FORCE_INLINE, matching Indexer2D/Indexer3D. `operator()` is const so it
// works through const indexers.
//
// NOTE(review): arithmetic is in 32-bit `int`. With four extents the flat
// index grows quickly; callers must ensure the full product stays below
// INT_MAX, otherwise the result overflows (undefined behavior).
struct Indexer4D {
  int size2_;
  int size3_;
  int size4_;

  FORCE_INLINE HOST_AND_DEVICE Indexer4D(
      const int& size2,
      const int& size3,
      const int& size4)
      : size2_(size2), size3_(size3), size4_(size4) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3,
      int index4) const {
    return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
  }
};

} // namespace rnnt
} // namespace torchaudio