register.h 3.22 KB
Newer Older
SWHL's avatar
SWHL committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Classes for registering derived FST for generic reading.

#ifndef FST_REGISTER_H_
#define FST_REGISTER_H_

#include <string>
#include <type_traits>


#include <fst/compat.h>
#include <fst/generic-register.h>
#include <fst/util.h>


#include <fst/types.h>
#include <fst/log.h>

namespace fst {

template <class Arc>
class Fst;

struct FstReadOptions;

// This class represents a single entry in a FstRegister
template <class Arc>
struct FstRegisterEntry {
  using Reader = Fst<Arc> *(*)(std::istream &istrm, const FstReadOptions &opts);
  using Converter = Fst<Arc> *(*)(const Fst<Arc> &fst);

  Reader reader;
  Converter converter;

  explicit FstRegisterEntry(Reader reader = nullptr,
                            Converter converter = nullptr)
      : reader(reader), converter(converter) {}
};

// This class maintains the correspondence between a string describing
// an FST type, and its reader and converter.
template <class Arc>
class FstRegister
    : public GenericRegister<string, FstRegisterEntry<Arc>, FstRegister<Arc>> {
 public:
  using Reader = typename FstRegisterEntry<Arc>::Reader;
  using Converter = typename FstRegisterEntry<Arc>::Converter;

  const Reader GetReader(const string &type) const {
    return this->GetEntry(type).reader;
  }

  const Converter GetConverter(const string &type) const {
    return this->GetEntry(type).converter;
  }

 protected:
  string ConvertKeyToSoFilename(const string &key) const override {
    string legal_type(key);
    ConvertToLegalCSymbol(&legal_type);
    return legal_type + "-fst.so";
  }
};

// This class registers an FST type for generic reading and creating.
// The type must have a default constructor and a copy constructor from
// Fst<Arc>.
template <class FST>
class FstRegisterer : public GenericRegisterer<FstRegister<typename FST::Arc>> {
 public:
  using Arc = typename FST::Arc;
  using Entry = typename FstRegister<Arc>::Entry;
  using Reader = typename FstRegister<Arc>::Reader;

  FstRegisterer()
      : GenericRegisterer<FstRegister<typename FST::Arc>>(FST().Type(),
                                                          BuildEntry()) {}

 private:
  static Fst<Arc> *ReadGeneric(
      std::istream &strm, const FstReadOptions &opts) {
    static_assert(std::is_base_of<Fst<Arc>, FST>::value,
                  "FST class does not inherit from Fst<Arc>");
    return FST::Read(strm, opts);
  }

  static Entry BuildEntry() {
    return Entry(&ReadGeneric, &FstRegisterer<FST>::Convert);
  }

  static Fst<Arc> *Convert(const Fst<Arc> &fst) { return new FST(fst); }
};

// Convenience macro to generate static FstRegisterer instance.
#define REGISTER_FST(FST, Arc) \
  static fst::FstRegisterer<FST<Arc>> FST##_##Arc##_registerer

// Converts an FST to the specified type.
template <class Arc>
Fst<Arc> *Convert(const Fst<Arc> &fst, const string &fst_type) {
  auto *reg = FstRegister<Arc>::GetRegister();
  const auto converter = reg->GetConverter(fst_type);
  if (!converter) {
    FSTERROR() << "Fst::Convert: Unknown FST type " << fst_type << " (arc type "
               << Arc::Type() << ")";
    return nullptr;
  }
  return converter(fst);
}

}  // namespace fst

#endif  // FST_REGISTER_H_