kaldi-matrix.h 7.04 KB
Newer Older
moto's avatar
moto committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h

#ifndef KALDI_MATRIX_KALDI_MATRIX_H_
#define KALDI_MATRIX_KALDI_MATRIX_H_

#include <torch/torch.h>
#include "matrix/kaldi-vector.h"
#include "matrix/matrix-common.h"

using namespace torch::indexing;

namespace kaldi {

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L44-L48
/// Kaldi-compatible matrix interface backed by a torch::Tensor.
///
/// All element data lives in `tensor_`; the methods below read or write that
/// tensor in place.
template <typename Real>
class MatrixBase {
 public:
  ////////////////////////////////////////////////////////////////////////////////
  // PyTorch-specific items
  ////////////////////////////////////////////////////////////////////////////////
  /// Underlying storage. The accessors below index it with two indices, so it
  /// is expected to be 2-D with a dtype matching Real (NOTE(review): not
  /// checked for tensors wrapped via the constructor below — confirm callers).
  torch::Tensor tensor_;
  /// Construct MatrixBase which is an interface to an existing torch::Tensor
  /// object.
  MatrixBase(torch::Tensor tensor);

  ////////////////////////////////////////////////////////////////////////////////
  // Kaldi-compatible items
  ////////////////////////////////////////////////////////////////////////////////
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L62-L63
  /// Number of rows (size of dimension 0 of `tensor_`).
  inline MatrixIndexT NumRows() const {
    return static_cast<MatrixIndexT>(tensor_.size(0));
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L65-L66
  /// Number of columns (size of dimension 1 of `tensor_`).
  inline MatrixIndexT NumCols() const {
    return static_cast<MatrixIndexT>(tensor_.size(1));
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L177-L178
  /// Overwrites column `col` in place with the contents of `v`.
  void CopyColFromVec(const VectorBase<Real>& v, const MatrixIndexT col) {
    tensor_.index_put_({Slice(), col}, v.tensor_);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L99-L107
  /// Mutable reference to element (r, c).
  inline Real& operator()(MatrixIndexT r, MatrixIndexT c) {
    // CPU only: TensorAccessor requires a CPU tensor whose dtype is Real.
    return tensor_.accessor<Real, 2>()[r][c];
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L112-L120
  /// Value of element (r, c).
  inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const {
    // Integer indices select exactly one element. Slice(r)/Slice(c) would
    // select the ranges [r, end) x [c, end), making item() fail for every
    // element except the bottom-right one.
    return tensor_.index({r, c}).item().template to<Real>();
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L138-L141
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L859-L898
  /// Assigns `M` (transposed first if trans == kTrans) into this matrix in
  /// place. The source is expected to have this matrix's shape.
  template <typename OtherReal>
  void CopyFromMat(
      const MatrixBase<OtherReal>& M,
      MatrixTransposeType trans = kNoTrans) {
    auto src = M.tensor_;
    if (trans == kTrans)
      src = src.transpose(1, 0);
    tensor_.index_put_({Slice(), Slice()}, src);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L186-L191
  /// Row `i` as a SubVector (constructed from this matrix; see kaldi-vector.h
  /// for whether it is a view — presumably yes, TODO confirm).
  inline const SubVector<Real> Row(MatrixIndexT i) const {
    return SubVector<Real>(*this, i);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L208-L211
  /// Rows [row_offset, row_offset + num_rows) across all columns, as a
  /// SubMatrix.
  inline SubMatrix<Real> RowRange(
      const MatrixIndexT row_offset,
      const MatrixIndexT num_rows) const {
    return SubMatrix<Real>(*this, row_offset, num_rows, 0, NumCols());
  }

 protected:
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L749-L753
  /// Empty 0 x 0 matrix. The tensor is created with Real's dtype; the global
  /// default dtype (float) would make e.g. MatrixBase<double> hold floats, and
  /// the typed accessor above would then fail.
  explicit MatrixBase()
      : tensor_(torch::empty({0, 0}, torch::TensorOptions().dtype<Real>())) {
    KALDI_ASSERT_IS_FLOATING_TYPE(Real);
  }
};

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L781-L784
/// Matrix that owns its data. Constructing or copy-assigning a Matrix from
/// another matrix copies the elements (as Kaldi does) instead of aliasing the
/// source tensor.
template <typename Real>
class Matrix : public MatrixBase<Real> {
 public:
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L786-L787
  /// Empty matrix (whatever MatrixBase's default constructor produces).
  Matrix() : MatrixBase<Real>() {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L789-L793
  /// Creates an r x c matrix initialized according to `resize_type`.
  /// `stride_type` is accepted for Kaldi API compatibility but is ignored.
  Matrix(
      const MatrixIndexT r,
      const MatrixIndexT c,
      MatrixResizeType resize_type = kSetZero,
      MatrixStrideType stride_type = kDefaultStride)
      : MatrixBase<Real>() {
    Resize(r, c, resize_type, stride_type);
  }

  /// Deep copy. The implicitly generated copy constructor would share
  /// `tensor_` (and hence storage) with `M`.
  Matrix(const Matrix<Real>& M)
      : MatrixBase<Real>(CopyTensor(M.tensor_, kNoTrans)) {}

  /// Moving keeps the behavior of the implicitly generated move constructor.
  Matrix(Matrix<Real>&&) = default;

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L808-L811
  /// Copies `M` (transposed if trans == kTrans) into new, contiguous storage.
  explicit Matrix(
      const MatrixBase<Real>& M,
      MatrixTransposeType trans = kNoTrans)
      : MatrixBase<Real>(CopyTensor(M.tensor_, trans)) {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L816-L819
  /// Same as above, additionally converting the elements to Real's dtype.
  template <typename OtherReal>
  explicit Matrix(
      const MatrixBase<OtherReal>& M,
      MatrixTransposeType trans = kNoTrans)
      : MatrixBase<Real>(CopyTensor(M.tensor_, trans)) {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L859-L874
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L817-L857
  /// Resizes the underlying tensor in place to r x c.
  ///   kSetZero   - all elements become zero.
  ///   kUndefined - element values are left as whatever resize_() produces.
  ///   kCopyData  - the overlapping top-left block keeps its old values; the
  ///                rest is zero.
  void Resize(
      const MatrixIndexT r,
      const MatrixIndexT c,
      MatrixResizeType resize_type = kSetZero,
      MatrixStrideType stride_type = kDefaultStride) {
    static_cast<void>(stride_type);  // No tensor equivalent; kept for the API.
    torch::Tensor& tensor = MatrixBase<Real>::tensor_;
    switch (resize_type) {
      case kSetZero:
        tensor.resize_({r, c}).zero_();
        break;
      case kUndefined:
        tensor.resize_({r, c});
        break;
      case kCopyData: {
        // Snapshot the contents first: copying the torch::Tensor handle alone
        // aliases the same TensorImpl, which resize_()/zero_() below would
        // clobber, so only zeros would be copied back.
        const torch::Tensor old = tensor.clone();
        const auto old_rows = old.size(0);
        const auto old_cols = old.size(1);
        tensor.resize_({r, c}).zero_();
        const auto rows = Slice(None, r < old_rows ? r : old_rows);
        const auto cols = Slice(None, c < old_cols ? c : old_cols);
        tensor.index_put_({rows, cols}, old.index({rows, cols}));
        break;
      }
    }
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L876-L883
  /// Resizes this matrix to `other`'s shape if needed, then copies its
  /// elements.
  Matrix<Real>& operator=(const MatrixBase<Real>& other) {
    if (MatrixBase<Real>::NumRows() != other.NumRows() ||
        MatrixBase<Real>::NumCols() != other.NumCols())
      Resize(other.NumRows(), other.NumCols(), kUndefined);
    MatrixBase<Real>::CopyFromMat(other);
    return *this;
  }

  /// Deep copy assignment; the implicitly generated one would alias storage.
  Matrix<Real>& operator=(const Matrix<Real>& other) {
    return *this = static_cast<const MatrixBase<Real>&>(other);
  }

  /// Move assignment keeps the implicitly generated behavior.
  Matrix<Real>& operator=(Matrix<Real>&&) = default;

 private:
  /// Returns newly allocated, contiguous storage holding `src` (transposed
  /// first if trans == kTrans), converted to Real's dtype on src's device.
  static torch::Tensor CopyTensor(
      const torch::Tensor& src,
      MatrixTransposeType trans) {
    const torch::Tensor oriented =
        trans == kNoTrans ? src : src.transpose(1, 0);
    return oriented.to(
        oriented.options().dtype<Real>(),
        /*non_blocking=*/false,
        /*copy=*/true,
        torch::MemoryFormat::Contiguous);
  }
};

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L940-L948
/// View of the rectangular block [ro, ro + r) x [co, co + c) of another
/// matrix. It shares storage with `T`, so writes through it modify `T`.
template <typename Real>
class SubMatrix : public MatrixBase<Real> {
 public:
  SubMatrix(
      const MatrixBase<Real>& T,
      const MatrixIndexT ro, // row offset, 0 <= ro <= T.NumRows()
      const MatrixIndexT r, // number of rows, ro + r <= T.NumRows()
      const MatrixIndexT co, // column offset, 0 <= co <= T.NumCols()
      const MatrixIndexT c) // number of columns, co + c <= T.NumCols()
      // narrow() yields the same view as Slice indexing, but it throws when
      // the requested block does not fit, instead of silently truncating it.
      // Kaldi asserts on such out-of-range requests.
      : MatrixBase<Real>(T.tensor_.narrow(0, ro, r).narrow(1, co, c)) {}
};

// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L1059-L1060
/// Writes the backing tensor to `Out` using PyTorch's tensor printer (not
/// Kaldi's text matrix format) and returns `Out` so calls can be chained.
template <typename Real>
std::ostream& operator<<(std::ostream& Out, const MatrixBase<Real>& M) {
  return Out << M.tensor_;
}

} // namespace kaldi

#endif