// params.h — command-line flags and helpers that build the feature pipeline,
// decode options, and decode resource from those flags.
#ifndef DECODER_PARAMS_H_
#define DECODER_PARAMS_H_

#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "decoder/asr_decoder.h"
#include "decoder/onnx_asr_model.h"
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/flags.h"
#include "utils/string.h"

DEFINE_int32(device_id, 0, "set XPU DeviceID for ASR model");

// TorchAsrModel flags
DEFINE_string(model_path, "", "pytorch exported model path");
// OnnxAsrModel flags
DEFINE_string(onnx_dir, "", "directory where the onnx model is saved");
// XPUAsrModel flags
DEFINE_string(xpu_model_dir, "",
              "directory where the XPU model and weights is saved");
// BPUAsrModel flags
DEFINE_string(bpu_model_dir, "",
              "directory where the HORIZON BPU model is saved");
// OVAsrModel flags
DEFINE_string(openvino_dir, "", "directory where the OV model is saved");
DEFINE_int32(core_number, 1, "Core number of process");

// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");

// TLG fst
DEFINE_string(fst_path, "", "TLG fst path");

// DecodeOptions flags
DEFINE_int32(chunk_size, 16, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, 0.5,
              "ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight, 1.0,
              "rescoring weight when combining ctc score and rescoring score");
// NOTE: adjacent string literals are concatenated by the compiler, so each
// fragment below must end with a space to keep the help text readable.
DEFINE_double(reverse_weight, 0.0,
              "used for bitransformer rescoring. it must be 0.0 if decoder is "
              "a conventional transformer decoder, and only when "
              "reverse_weight > 0.0 does the right-to-left decoder get "
              "calculated and used");
DEFINE_int32(max_active, 7000, "max active states in ctc wfst search");
DEFINE_int32(min_active, 200, "min active states in ctc wfst search");
DEFINE_double(beam, 16.0, "beam in ctc wfst search");
DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_double(blank_skip_thresh, 1.0,
              "blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(blank_scale, 1.0, "blank scale for ctc wfst search");
DEFINE_double(length_penalty, 0.0,
              "length penalty for ctc wfst search, will not "
              "apply on self-loop arc, for balancing the del/ins ratio, "
              "suggest set to -3.0");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");

// SymbolTable flags
DEFINE_string(dict_path, "",
              "dict symbol table path, required when LM is enabled");
DEFINE_string(unit_path, "./units.txt",
              "e2e model unit symbol table, it is used in both "
              "with/without LM scenarios for context/timestamp");

// Context flags
DEFINE_string(context_path, "", "context path, is used to build context graph");
DEFINE_double(context_score, 3.0, "is used to rescore the decoded result");

// PostProcessOptions flags
DEFINE_int32(language_type, 0,
             "remove spaces according to language type: "
             "0x00 = kMandarinEnglish, "
             "0x01 = kIndoEuropean");
DEFINE_bool(lowercase, true, "lowercase final result if needed");

namespace wenet {
// Builds a FeaturePipelineConfig from the command-line flags
// --num_bins and --sample_rate.
std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
  return std::make_shared<FeaturePipelineConfig>(FLAGS_num_bins,
                                                 FLAGS_sample_rate);
}

// Builds DecodeOptions from the command-line flags: chunking and weights on
// the top-level struct, search-space limits on the nested WFST search options,
// and the CTC prefix beam sizes (both derived from --nbest).
std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
  auto opts = std::make_shared<DecodeOptions>();
  opts->chunk_size = FLAGS_chunk_size;
  opts->num_left_chunks = FLAGS_num_left_chunks;
  opts->ctc_weight = FLAGS_ctc_weight;
  opts->reverse_weight = FLAGS_reverse_weight;
  opts->rescoring_weight = FLAGS_rescoring_weight;

  // Alias the nested WFST search options to avoid repeating the long path.
  auto& wfst_opts = opts->ctc_wfst_search_opts;
  wfst_opts.max_active = FLAGS_max_active;
  wfst_opts.min_active = FLAGS_min_active;
  wfst_opts.beam = FLAGS_beam;
  wfst_opts.lattice_beam = FLAGS_lattice_beam;
  wfst_opts.acoustic_scale = FLAGS_acoustic_scale;
  wfst_opts.blank_skip_thresh = FLAGS_blank_skip_thresh;
  wfst_opts.blank_scale = FLAGS_blank_scale;
  wfst_opts.length_penalty = FLAGS_length_penalty;
  wfst_opts.nbest = FLAGS_nbest;

  // Prefix beam search reuses --nbest for both beam widths.
  auto& prefix_opts = opts->ctc_prefix_search_opts;
  prefix_opts.first_beam_size = FLAGS_nbest;
  prefix_opts.second_beam_size = FLAGS_nbest;
  return opts;
}

// Builds the DecodeResource from command-line flags:
//   * ONNX acoustic model (--onnx_dir, required; aborts otherwise),
//   * e2e unit table (--unit_path, required),
//   * optional TLG fst + dict table (--fst_path / --dict_path) for LM decoding,
//   * optional context-biasing graph (--context_path / --context_score),
//   * post-processor (--language_type / --lowercase).
std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
  auto resource = std::make_shared<DecodeResource>();

  // Acoustic model: only the ONNX backend is supported in this build, so an
  // empty --onnx_dir is fatal. Checked up front (guard clause) instead of a
  // trailing else, since LOG(FATAL) aborts the process anyway.
  if (FLAGS_onnx_dir.empty()) {
    LOG(FATAL) << "Please set ONNX model path!!!";
  }
  const int kNumGemmThreads = 1;
  LOG(INFO) << "Reading onnx model ";
  OnnxAsrModel::InitEngineThreads(kNumGemmThreads);
  auto model = std::make_shared<OnnxAsrModel>();
  model->Read(FLAGS_onnx_dir);
  resource->model = model;

  LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
  auto unit_table = std::shared_ptr<fst::SymbolTable>(
      fst::SymbolTable::ReadText(FLAGS_unit_path));
  CHECK(unit_table != nullptr);
  resource->unit_table = unit_table;

  if (!FLAGS_fst_path.empty()) {  // With LM
    CHECK(!FLAGS_dict_path.empty());
    LOG(INFO) << "Reading fst " << FLAGS_fst_path;
    auto fst = std::shared_ptr<fst::Fst<fst::StdArc>>(
        fst::Fst<fst::StdArc>::Read(FLAGS_fst_path));
    CHECK(fst != nullptr);
    resource->fst = fst;

    LOG(INFO) << "Reading symbol table " << FLAGS_dict_path;
    auto symbol_table = std::shared_ptr<fst::SymbolTable>(
        fst::SymbolTable::ReadText(FLAGS_dict_path));
    CHECK(symbol_table != nullptr);
    resource->symbol_table = symbol_table;
  } else {  // Without LM, symbol_table is the same as unit_table
    resource->symbol_table = unit_table;
  }

  if (!FLAGS_context_path.empty()) {
    LOG(INFO) << "Reading context " << FLAGS_context_path;
    std::ifstream infile(FLAGS_context_path);
    // A missing/unreadable file previously produced an empty context list
    // silently; fail loudly instead since the user explicitly asked for it.
    CHECK(infile.is_open()) << "Failed to open context file "
                            << FLAGS_context_path;
    std::vector<std::string> contexts;
    std::string context;
    while (getline(infile, context)) {
      contexts.emplace_back(Trim(context));
    }
    ContextConfig config;
    config.context_score = FLAGS_context_score;
    resource->context_graph = std::make_shared<ContextGraph>(config);
    resource->context_graph->BuildContextGraph(contexts,
                                               resource->symbol_table);
  }

  PostProcessOptions post_process_opts;
  post_process_opts.language_type =
      FLAGS_language_type == 0 ? kMandarinEnglish : kIndoEuropean;
  post_process_opts.lowercase = FLAGS_lowercase;
  resource->post_processor =
      std::make_shared<PostProcessor>(std::move(post_process_opts));
  return resource;
}

}  // namespace wenet

#endif  // DECODER_PARAMS_H_