"src/diffusers/pipelines/allegro/pipeline_allegro.py" did not exist on "a5720e9e3124753c85b2260dec5f39d75ce18245"
workspace.h 5.87 KB
Newer Older
1
2
3
4
5
#pragma once

#include <cstring>
#include <vector>

6
7
8
#include <libtorchaudio/rnnt/options.h>

#include <c10/util/Logging.h>
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

namespace torchaudio {
namespace rnnt {

// Since CUDA has strict memory alignment, it's better to keep allocated memory
// blocks separate for different data types.

// DtypeWorkspace holds a "view" of the workspace for:
//     1. softmax denominators (in log form), size = B * max_T * max_U
//     2. log probability pairs for blank and target, size = B * max_T * max_U * 2
//     3. alphas, size = B * max_T * max_U
//     4. betas, size = B * max_T * max_U
template <typename DTYPE>
class DtypeWorkspace {
 public:
  DtypeWorkspace() : options_(), size_(0), data_(nullptr) {}
  DtypeWorkspace(const Options& options, DTYPE* data, int size)
      : DtypeWorkspace() {
    Reset(options, data, size);
  }
  ~DtypeWorkspace() {}

  static int ComputeSizeFromOptions(const Options& options) {
32
    TORCH_CHECK_NE(options.device_, UNDEFINED);
33
34
35
36
37
38
39
40
    return ComputeSizeForDenominators(options) +
        ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) +
        ComputeSizeForBetas(options);
  }

  void Free();
  void Reset(const Options& options, DTYPE* data, int size) {
    int needed_size = ComputeSizeFromOptions(options);
41
    TORCH_CHECK_LE(needed_size, size);
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    options_ = options;
    data_ = data;
    size_ = size;
  }
  int Size() const {
    return size_;
  }

  DTYPE* GetPointerToDenominators() const {
    return data_;
  }
  DTYPE* GetPointerToLogProbs() const {
    return GetPointerToDenominators() + ComputeSizeForDenominators(options_);
  }
  DTYPE* GetPointerToAlphas() const {
    return GetPointerToLogProbs() + ComputeSizeForLogProbs(options_);
  }
  DTYPE* GetPointerToBetas() const {
    return GetPointerToAlphas() + ComputeSizeForAlphas(options_);
  }

 private:
  static int ComputeSizeForDenominators(const Options& options) { // B * T * U
    return options.BTU();
  }

  static int ComputeSizeForLogProbs(const Options& options) { // B * T * U * 2
    return options.BTU() * 2;
  }

  static int ComputeSizeForAlphas(const Options& options) { // B * T * U
    return options.BTU();
  }

  static int ComputeSizeForBetas(const Options& options) { // B * T * U
    return options.BTU();
  }

  Options options_;
  int size_; // number of elements in allocated memory.
  DTYPE* data_; // pointer to the allocated memory.
};

// IntWorkspace holds a "view" of workspace for:
//     1. alpha counters, size = B * max_U
//     2. beta counters, size = B * max_U
class IntWorkspace {
 public:
  IntWorkspace() : options_(), size_(0), data_(nullptr) {}
  IntWorkspace(const Options& options, int* data, int size) : IntWorkspace() {
    Reset(options, data, size);
  }
  ~IntWorkspace() {}

  static int ComputeSizeFromOptions(const Options& options) {
    return ComputeSizeForAlphaCounters(options) +
        ComputeSizeForBetaCounters(options);
  }

  void Reset(const Options& options, int* data, int size) {
    int needed_size = ComputeSizeFromOptions(options);
103
    TORCH_CHECK_LE(needed_size, size);
104
105
106
107
108
109
110
111
112
113
    options_ = options;
    data_ = data;
    size_ = size;
    ResetAlphaBetaCounters();
  }
  int Size() const {
    return size_;
  }

  int* GetPointerToAlphaCounters() const {
114
    TORCH_CHECK_EQ(options_.device_, GPU);
115
116
117
    return data_;
  }
  int* GetPointerToBetaCounters() const {
118
    TORCH_CHECK_EQ(options_.device_, GPU);
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
    return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_);
  }

 private:
  inline void ResetAlphaBetaCounters() {
#ifdef USE_CUDA
    if (data_ != nullptr && options_.device_ == GPU) {
      cudaMemset(
          GetPointerToAlphaCounters(),
          0,
          ComputeSizeForAlphaCounters(options_) * sizeof(int));
      cudaMemset(
          GetPointerToBetaCounters(),
          0,
          ComputeSizeForBetaCounters(options_) * sizeof(int));
    }
#endif // USE_CUDA
  }

  static int ComputeSizeForAlphaCounters(const Options& options) { // B * U
139
#ifdef USE_CUDA
140
141
142
143
144
145
146
147
148
149
    if (options.device_ == GPU) {
      return options.BU();
    } else {
      return 0;
    }
#else
    return 0;
#endif // USE_CUDA
  }
  static int ComputeSizeForBetaCounters(const Options& options) { // B * U
150
#ifdef USE_CUDA
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
    if (options.device_ == GPU) {
      return options.BU();
    } else {
      return 0;
    }
#else
    return 0;
#endif // USE_CUDA
  }

  Options options_;
  int size_; // number of elements in allocated memory.
  int* data_; // pointer to the allocated memory.
};

// Workspace<DTYPE> holds:
//     1. DtypeWorkspace<DTYPE>
//     2. IntWorkspace
template <typename DTYPE>
class Workspace {
 public:
  // Empty workspace: default options and two empty sub-views.
  Workspace() : options_(), dtype_workspace_(), int_workspace_() {}

  // Binds both sub-views in one step. Each sub-view constructor runs its own
  // Reset(), so the capacity checks match those performed by Reset() below.
  Workspace(
      const Options& options,
      DTYPE* dtype_data,
      int dtype_size,
      int* int_data,
      int int_size)
      : options_(options),
        dtype_workspace_(options, dtype_data, dtype_size),
        int_workspace_(options, int_data, int_size) {}

  ~Workspace() = default;

  // Stores `options`, then rebinds the DTYPE view followed by the int view.
  void Reset(
      const Options& options,
      DTYPE* dtype_data,
      int dtype_size,
      int* int_data,
      int int_size) {
    options_ = options;
    dtype_workspace_.Reset(options_, dtype_data, dtype_size);
    int_workspace_.Reset(options_, int_data, int_size);
  }

  // Options most recently passed to the constructor or Reset().
  const Options& GetOptions() const {
    return options_;
  }

  // --- Forwarders into the DTYPE view ---
  DTYPE* GetPointerToDenominators() const {
    return dtype_workspace_.GetPointerToDenominators();
  }

  DTYPE* GetPointerToLogProbs() const {
    return dtype_workspace_.GetPointerToLogProbs();
  }

  DTYPE* GetPointerToAlphas() const {
    return dtype_workspace_.GetPointerToAlphas();
  }

  DTYPE* GetPointerToBetas() const {
    return dtype_workspace_.GetPointerToBetas();
  }

  // --- Forwarders into the int view ---
  int* GetPointerToAlphaCounters() const {
    return int_workspace_.GetPointerToAlphaCounters();
  }

  int* GetPointerToBetaCounters() const {
    return int_workspace_.GetPointerToBetaCounters();
  }

 private:
  Options options_;
  DtypeWorkspace<DTYPE> dtype_workspace_;
  IntWorkspace int_workspace_;
};

} // namespace rnnt
} // namespace torchaudio