kaldi-vector.h 12.8 KB
Newer Older
moto's avatar
moto committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h

#ifndef KALDI_MATRIX_KALDI_VECTOR_H_
#define KALDI_MATRIX_KALDI_VECTOR_H_

#include <cstdint>
#include <limits>
#include <tuple>
#include <utility>

#include <torch/torch.h>
#include "matrix/matrix-common.h"

using namespace torch::indexing;

namespace kaldi {

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L36-L40
template <typename Real>
class VectorBase {
 public:
  ////////////////////////////////////////////////////////////////////////////////
  // PyTorch-specific things
  ////////////////////////////////////////////////////////////////////////////////
  /// Backing tensor. Every mutating method below modifies it in place (none
  /// of them rebinds it), so a VectorBase that views another tensor, such as
  /// a SubVector of a Vector or a row of a Matrix, writes through to the
  /// viewed data and keeps the cached `data_` pointer valid.
  /// NOTE(review): expected to be 1-D — operator() uses a 1-D accessor.
  torch::Tensor tensor_;

  /// Construct VectorBase which is an interface to an existing torch::Tensor
  /// object.
  VectorBase(torch::Tensor tensor);

  ////////////////////////////////////////////////////////////////////////////////
  // Kaldi-compatible methods
  ////////////////////////////////////////////////////////////////////////////////
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L42-L43
  /// Sets every element to zero.
  void SetZero() {
    Set(0);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L48-L49
  /// Sets every element to `f` (in place).
  void Set(Real f) {
    tensor_.index_put_({"..."}, f);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L62-L63
  /// Number of elements.
  inline MatrixIndexT Dim() const {
    return tensor_.numel();
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L68-L69
  /// Raw pointer cached in `data_`.
  /// NOTE(review): `data_` is assigned by Vector::Resize and presumably by the
  /// tensor-wrapping constructor defined elsewhere — confirm. It is only
  /// meaningful while `tensor_` keeps its current storage.
  inline Real* Data() {
    return data_;
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L71-L72
  /// Const overload of Data(); same caveats apply.
  inline const Real* Data() const {
    return data_;
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L74-L79
  /// Read-only element access. Goes through the same stride-aware accessor as
  /// the mutable overload rather than the cached `data_` pointer, so it is
  /// correct for non-contiguous views and cannot read a stale pointer.
  inline Real operator()(MatrixIndexT i) const {
    return tensor_.accessor<Real, 1>()[i];
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L81-L86
  /// Mutable element access (no bounds check).
  inline Real& operator()(MatrixIndexT i) {
    return tensor_.accessor<Real, 1>()[i];
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L88-L95
  /// View of `l` elements starting at offset `o`.
  SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l) {
    return SubVector<Real>(*this, o, l);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L97-L105
  /// Const view of `l` elements starting at offset `o`.
  const SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l)
      const {
    return SubVector<Real>(*this, o, l);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L107-L108
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L226-L233
  /// Copies the contents of `v` into this vector; sizes must match.
  void CopyFromVec(const VectorBase<Real>& v) {
    TORCH_INTERNAL_ASSERT(tensor_.sizes() == v.tensor_.sizes());
    tensor_.copy_(v.tensor_);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L137-L139
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L816-L832
  /// Replaces every element smaller than `floor_val` with `floor_val`.
  /// If `floored_count` is non-null, stores how many elements were changed.
  void ApplyFloor(Real floor_val, MatrixIndexT* floored_count = nullptr) {
    auto below = tensor_ < floor_val;
    tensor_.index_put_({below}, floor_val);
    if (floored_count) {
      *floored_count = below.sum().item().template to<MatrixIndexT>();
    }
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L164-L165
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L449-L479
  /// Raises every element to `power` in place and asserts that no NaN was
  /// produced (e.g. a negative value raised to a non-integer power).
  void ApplyPow(Real power) {
    tensor_.pow_(power);
    TORCH_INTERNAL_ASSERT(!tensor_.isnan().sum().item().template to<int32_t>());
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L181-L184
  /// *this += alpha * v.
  template <typename OtherReal>
  void AddVec(const Real alpha, const VectorBase<OtherReal>& v) {
    tensor_ += alpha * v.tensor_;
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L186-L187
  /// *this += alpha * v^2 (element-wise square).
  void AddVec2(const Real alpha, const VectorBase<Real>& v) {
    tensor_ += alpha * (v.tensor_.square());
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L196-L198
  /// *this = beta * *this + alpha * op(M) * v, where op(M) is M or M^T
  /// depending on `trans`.
  void AddMatVec(
      const Real alpha,
      const MatrixBase<Real>& M,
      const MatrixTransposeType trans,
      const VectorBase<Real>& v,
      const Real beta) { // **beta previously defaulted to 0.0**
    auto mat = M.tensor_;
    if (trans == kTrans) {
      mat = mat.transpose(1, 0);
    }
    tensor_.addmv_(mat, v.tensor_, beta, alpha);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L221-L222
  /// Element-wise *this *= v.
  void MulElements(const VectorBase<Real>& v) {
    tensor_ *= v.tensor_;
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L233-L234
  /// Adds the constant `c` to every element.
  void Add(Real c) {
    tensor_ += c;
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L236-L239
  /// *this = beta * *this + alpha * v .* r (element-wise product).
  void AddVecVec(
      Real alpha,
      const VectorBase<Real>& v,
      const VectorBase<Real>& r,
      Real beta) {
    // Form the product before touching tensor_ so that v or r may alias *this.
    auto product = alpha * v.tensor_ * r.tensor_;
    // Update in place instead of rebinding tensor_, so views write through to
    // their parent and data_ keeps pointing at live storage.
    tensor_.mul_(beta).add_(product);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L246-L247
  /// Multiplies every element by `alpha`.
  void Scale(Real alpha) {
    tensor_ *= alpha;
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L305-L306
  /// Smallest element; +infinity for an empty vector.
  Real Min() const {
    if (tensor_.numel()) {
      return tensor_.min().item().template to<Real>();
    }
    return std::numeric_limits<Real>::infinity();
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L308-L310
  /// Smallest element, also writing its position to `*index`.
  /// Asserts the vector is non-empty.
  Real Min(MatrixIndexT* index) const {
    TORCH_INTERNAL_ASSERT(tensor_.numel());
    torch::Tensor value, ind;
    std::tie(value, ind) = tensor_.min(0);
    *index = ind.item().to<MatrixIndexT>();
    return value.item().to<Real>();
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L312-L313
  /// Sum of all elements.
  Real Sum() const {
    return tensor_.sum().item().template to<Real>();
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L320-L321
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L718-L736
  /// *this = beta * *this + alpha * (sum of the rows of M), computed as
  /// M^T * ones(NumRows).
  void AddRowSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
    Vector<Real> ones(M.NumRows());
    ones.Set(1.0);
    this->AddMatVec(alpha, M, kTrans, ones, beta);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L323-L324
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L738-L757
  /// *this = beta * *this + alpha * (sum of the columns of M), computed as
  /// M * ones(NumCols).
  void AddColSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
    Vector<Real> ones(M.NumCols());
    ones.Set(1.0);
    this->AddMatVec(alpha, M, kNoTrans, ones, beta);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L326-L330
  /// *this = beta * *this + alpha * diag(M M^T)   (trans == kNoTrans), or
  /// *this = beta * *this + alpha * diag(M^T M)   (otherwise).
  /// NOTE(review): M.tensor_ is assumed 2-D (rows x cols).
  void AddDiagMat2(
      Real alpha,
      const MatrixBase<Real>& M,
      MatrixTransposeType trans = kNoTrans,
      Real beta = 1.0) {
    // diag(M M^T)[i] = sum_j M(i, j)^2 and diag(M^T M)[j] = sum_i M(i, j)^2,
    // so reduce the squared entries instead of forming the full product.
    const int64_t reduce_dim = (trans == kNoTrans) ? 1 : 0;
    auto diag = M.tensor_.square().sum(reduce_dim);
    TORCH_INTERNAL_ASSERT(diag.numel() == tensor_.numel());
    // In-place update (see AddVecVec); `alpha` scales the diagonal term.
    tensor_.mul_(beta).add_(diag, alpha);
  }

 protected:
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L362-L365
  explicit VectorBase();

  //  https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L378-L379
  /// Cached raw pointer into tensor_'s storage (see Data()).
  Real* data_;
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L382
  KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase);
};

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L385-L390
template <typename Real>
class Vector : public VectorBase<Real> {
 public:
  ////////////////////////////////////////////////////////////////////////////////
  // PyTorch-compatibility things
  ////////////////////////////////////////////////////////////////////////////////
  /// Construct VectorBase which is an interface to an existing torch::Tensor
  /// object.
  Vector(torch::Tensor tensor) : VectorBase<Real>(tensor) {}

  ////////////////////////////////////////////////////////////////////////////////
  // Kaldi-compatible methods
  ////////////////////////////////////////////////////////////////////////////////
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L392-L393
  Vector() : VectorBase<Real>() {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L395-L399
  /// Vector of `s` elements, initialised according to `resize_type`.
  explicit Vector(const MatrixIndexT s, MatrixResizeType resize_type = kSetZero)
      : VectorBase<Real>() {
    Resize(s, resize_type);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L406-L410
  // Note: unlike the original implementation, this is "explicit".
  /// Deep copy of `v` (the tensor is cloned, not shared).
  explicit Vector(const Vector<Real>& v)
      : VectorBase<Real>(v.tensor_.clone()) {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L412-L416
  /// Deep copy of any VectorBase (the tensor is cloned, not shared).
  explicit Vector(const VectorBase<Real>& v)
      : VectorBase<Real>(v.tensor_.clone()) {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L434-L435
  /// Exchanges contents with `*other` without copying element data.
  void Swap(Vector<Real>* other) {
    std::swap(this->tensor_, other->tensor_);
    // The cached raw pointer belongs to the storage, so it must move with it;
    // otherwise each Data() would point into the other vector's storage.
    std::swap(this->data_, other->data_);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L444-L451
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L189-L223
  /// Resizes to `length` elements:
  ///  - kSetZero:   all elements become zero;
  ///  - kUndefined: contents are unspecified;
  ///  - kCopyData:  the first min(old, new) elements are preserved and any new
  ///                trailing elements are zero.
  void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero) {
    auto& t = this->tensor_;
    switch (resize_type) {
      case kSetZero:
        t.resize_({length}).zero_();
        break;
      case kUndefined:
        t.resize_({length});
        break;
      case kCopyData: {
        // resize_() mutates the TensorImpl shared by every handle to it, so a
        // plain handle copy would be resized and zeroed too; take a deep copy
        // of the old contents first.
        auto old = t.clone();
        const int64_t old_numel = old.numel();
        const int64_t keep = length < old_numel ? length : old_numel;
        t.resize_({length}).zero_();
        // Slice(None, keep) is Python's `[:keep]`; a single-argument
        // Slice(keep) would mean `[keep:]`.
        t.index_put_({Slice(None, keep)}, old.index({Slice(None, keep)}));
        break;
      }
    }
    // data_ptr<Real>() causes compiler error; use the untyped pointer instead.
    this->data_ = static_cast<Real*>(t.data_ptr());
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L463-L468
  /// Resizes to match `other` and copies its contents.
  Vector<Real>& operator=(const VectorBase<Real>& other) {
    Resize(other.Dim(), kUndefined);
    this->CopyFromVec(other);
    return *this;
  }
};

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L482-L485
template <typename Real>
class SubVector : public VectorBase<Real> {
 public:
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L487-L499
  /// View of elements [origin, origin + length) of `t`; the slice shares
  /// storage with `t`.
  SubVector(
      const VectorBase<Real>& t,
      const MatrixIndexT origin,
      const MatrixIndexT length)
      : VectorBase<Real>(t.tensor_.index({Slice(origin, origin + length)})) {
    // Tensor slicing silently clamps out-of-range bounds and treats negative
    // offsets as counting from the end. Kaldi requires the range to lie
    // inside `t`, so reject it rather than hand back a shorter/shifted view.
    TORCH_INTERNAL_ASSERT(
        origin >= 0 && length >= 0 && origin + length <= t.Dim());
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L524-L528
  /// View of row `row` of `matrix`.
  SubVector(const MatrixBase<Real>& matrix, MatrixIndexT row)
      : VectorBase<Real>(matrix.tensor_.index({row})) {
    // A negative row would otherwise silently select a row from the end.
    TORCH_INTERNAL_ASSERT(row >= 0 && row < matrix.NumRows());
  }
};

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L540-L543
template <typename Real>
std::ostream& operator<<(std::ostream& out, const VectorBase<Real>& v) {
  // Formatting is delegated entirely to libtorch's tensor printer; the
  // stream is handed back for chaining.
  return out << v.tensor_;
}

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L573-L575
template <typename Real>
Real VecVec(const VectorBase<Real>& v1, const VectorBase<Real>& v2) {
  return torch::dot(v1.tensor_, v2.tensor_).item().template to<Real>();
}

} // namespace kaldi

#endif