// options.h
#pragma once

//#include <iostream>
#include <ostream>

#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif // USE_CUDA

#include <libtorchaudio/rnnt/macros.h>
#include <libtorchaudio/rnnt/types.h>

namespace torchaudio {
namespace rnnt {

typedef struct Options {
  // the device to compute transducer loss.
  device_t device_;
18
#ifdef USE_CUDA
19
  // the stream to launch kernels in when using GPU.
20
  cudaStream_t stream_;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#endif
  // The maximum number of threads that can be used.
  int numThreads_;

  // the index for "blank".
  int blank_;
  // whether to backtrack the best path.
  bool backtrack_;
  // gradient clamp value.
  float clamp_;

  // batch size = B.
  int batchSize_;

  // Number of hypos per sample = H
  int nHypos_;

  // the maximum length of src encodings = max_T.
  int maxSrcLen_;
  // the maximum length of tgt encodings = max_U.
  int maxTgtLen_;
  // num_targets = D.
  int numTargets_;

45
46
47
48
49
50
  // if set to true, inputs are logits and gradients are
  // fused with logsoftmax gradients.
  // if set to false, log_softmax is computed outside of loss
  // True by default
  bool fusedLogSmax_;

51
52
53
54
55
56
57
58
59
60
  Options()
      : device_(UNDEFINED),
        numThreads_(0),
        blank_(-1),
        backtrack_(false),
        clamp_(-1), // negative for disabling clamping by default.
        batchSize_(0),
        nHypos_(1),
        maxSrcLen_(0),
        maxTgtLen_(0),
61
62
        numTargets_(0),
        fusedLogSmax_(true) {}
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84

  int BU() const {
    return batchSize_ * maxTgtLen_ * nHypos_;
  }

  int BTU() const {
    return batchSize_ * maxSrcLen_ * maxTgtLen_ * nHypos_;
  }

  friend std::ostream& operator<<(std::ostream& os, const Options& options) {
    os << "Options("
       << "batchSize_=" << options.batchSize_ << ", "
       << "maxSrcLen_=" << options.maxSrcLen_ << ", "
       << "maxTgtLen_=" << options.maxTgtLen_ << ", "
       << "numTargets_=" << options.numTargets_ << ")";

    return os;
  }
} Options;

} // namespace rnnt
} // namespace torchaudio