# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main entry to train and evaluate DeepSpeech model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
# pylint: disable=g-bad-import-order
from absl import app as absl_app
from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order

import data.dataset as dataset
import decoder
import deep_speech_model
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers

# Default vocabulary file
_VOCABULARY_FILE = os.path.join(
    os.path.dirname(__file__), "data/vocabulary.txt")


# Evaluation metrics
_WER_KEY = "WER"
_CER_KEY = "CER"


def evaluate_model(
    estimator, batch_size, speech_labels, targets, input_fn_eval):
  """Evaluate the model performance using WER anc CER as metrics.

  WER: Word Error Rate
  CER: Character Error Rate

  Args:
    estimator: estimator to evaluate.
    batch_size: size of a mini-batch.
    speech_labels: a string specifying all the characters in the vocabulary.
    targets: a list of lists of integers for the featurized transcripts.
    input_fn_eval: data input function for evaluation.

  Returns:
    Evaluation results containing "WER" and "CER" metrics and the global step.
  """
  # Get predictions
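  # With yield_single_examples=False, predict() yields one result dict per
  # batch rather than one per example.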
  predictions = estimator.predict(
      input_fn=input_fn_eval, yield_single_examples=False)

  y_preds = []
  input_lengths = []
  for p in predictions:
    y_preds.append(p["y_pred"])
    input_lengths.append(p["ctc_input_length"])

  num_of_examples = len(targets)
  total_wer, total_cer = 0, 0
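  # Greedy CTC decoder built over the vocabulary labels.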
  greedy_decoder = decoder.GreedyDecoder(speech_labels)
  for i in range(len(y_preds)):
    # Compute the CER and WER for the current batch,
    # and aggregate to total_cer, total_wer.
    y_pred_tensor = tf.convert_to_tensor(y_preds[i])
    batch_targets = targets[i * batch_size : (i + 1) * batch_size]
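    # ctc_input_length has shape [batch, 1]; squeeze it to a [batch] vector
    # of sequence lengths for the decoder.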
    seq_len = tf.squeeze(input_lengths[i], axis=1)

    # Perform decoding
    _, decoded_output = greedy_decoder.decode(
        y_pred_tensor, seq_len)

    # Compute CER.
    batch_cer = greedy_decoder.batch_cer(decoded_output, batch_targets)
    total_cer += batch_cer
    # Compute WER.
    batch_wer = greedy_decoder.batch_wer(decoded_output, batch_targets)
    total_wer += batch_wer

  # Average the accumulated totals over all evaluation examples.
  total_cer /= num_of_examples
  total_wer /= num_of_examples

  global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
  eval_results = {
      _WER_KEY: total_wer,
      _CER_KEY: total_cer,
      tf.GraphKeys.GLOBAL_STEP: global_step,
  }

  return eval_results


def convert_keras_to_estimator(keras_model, num_gpus):
  """Configure and convert keras model to Estimator.

  Args:
    keras_model: A Keras model object.
    num_gpus: An integer, the number of GPUs.

  Returns:
    estimator: The converted Estimator.
  """
  # The Keras optimizer is not compatible with the distribution strategy;
  # use the TF optimizer instead.
  optimizer = tf.train.MomentumOptimizer(
      learning_rate=flags_obj.learning_rate, momentum=flags_obj.momentum,
      use_nesterov=True)

  # ctc_loss is wrapped as a Lambda layer in the model.
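  # The Lambda layer already outputs the per-example CTC loss, so the
  # compiled loss simply passes the precomputed value (y_pred) through.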
  keras_model.compile(
      optimizer=optimizer, loss={"ctc_loss": lambda y_true, y_pred: y_pred})

  distribution_strategy = distribution_utils.get_distribution_strategy(
      num_gpus)
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy)

  estimator = tf.keras.estimator.model_to_estimator(
      keras_model=keras_model, model_dir=flags_obj.model_dir, config=run_config)

  return estimator


def generate_dataset(data_dir):
  """Generate a speech dataset."""
  audio_conf = dataset.AudioConfig(
      flags_obj.sample_rate, flags_obj.frame_length, flags_obj.frame_step)
  train_data_conf = dataset.DatasetConfig(
      audio_conf,
      data_dir,
      flags_obj.vocabulary_file,
  )
  speech_dataset = dataset.DeepSpeechDataset(train_data_conf)
  return speech_dataset


def run_deep_speech(_):
  """Run deep speech training and eval loop."""
  # Data preprocessing: build the training and evaluation datasets from the
  # specified CSV files.
  tf.logging.info("Data preprocessing...")

  train_speech_dataset = generate_dataset(flags_obj.train_data_dir)
  eval_speech_dataset = generate_dataset(flags_obj.eval_data_dir)

  # Number of label classes. Label string is "[a-z]' -"
  num_classes = len(train_speech_dataset.speech_labels)

  # Input shape of each data example:
  # [time_steps (T), feature_bins(F), channel(C)]
  # Channel is set as 1 by default.
  input_shape = (None, train_speech_dataset.num_feature_bins, 1)

  # Create deep speech model and convert it to Estimator
  tf.logging.info("Creating Estimator from Keras model...")
  keras_model = deep_speech_model.DeepSpeech(
      input_shape, flags_obj.rnn_hidden_layers, flags_obj.rnn_type,
      flags_obj.is_bidirectional, flags_obj.rnn_hidden_size,
      flags_obj.rnn_activation, num_classes, flags_obj.use_bias)

  # Convert to estimator
  num_gpus = flags_core.get_num_gpus(flags_obj)
  estimator = convert_keras_to_estimator(keras_model, num_gpus)

  # Benchmark logging
  run_params = {
      "batch_size": flags_obj.batch_size,
      "train_epochs": flags_obj.train_epochs,
      "rnn_hidden_size": flags_obj.rnn_hidden_size,
      "rnn_hidden_layers": flags_obj.rnn_hidden_layers,
      "rnn_activation": flags_obj.rnn_activation,
      "rnn_type": flags_obj.rnn_type,
      "is_bidirectional": flags_obj.is_bidirectional,
      "use_bias": flags_obj.use_bias
  }

  dataset_name = "LibriSpeech"
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info("deep_speech", dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      batch_size=flags_obj.batch_size)

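  # Split the global batch size across GPUs so each replica processes an
  # equal share of the batch per step.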
  per_device_batch_size = distribution_utils.per_device_batch_size(
      flags_obj.batch_size, num_gpus)

  def input_fn_train():
    return dataset.input_fn(
        per_device_batch_size, train_speech_dataset)

  def input_fn_eval():
    return dataset.input_fn(
        per_device_batch_size, eval_speech_dataset)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
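  # Alternate between a training pass and a full WER/CER evaluation each
  # cycle, until the epoch budget is exhausted or the WER threshold is met.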
  for cycle_index in range(total_training_cycle):
    tf.logging.info("Starting a training cycle: %d/%d",
                    cycle_index + 1, total_training_cycle)

    estimator.train(input_fn=input_fn_train, hooks=train_hooks)

    # Evaluation
    tf.logging.info("Starting to evaluate...")

    eval_results = evaluate_model(
        estimator, flags_obj.batch_size, eval_speech_dataset.speech_labels,
        eval_speech_dataset.labels, input_fn_eval)

    # Log the WER and CER results.
    benchmark_logger.log_evaluation_result(eval_results)
    tf.logging.info(
        "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
            cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))

    # Stop training early if the evaluation WER has met the stop threshold.
    if model_helpers.past_stop_threshold(
        flags_obj.wer_threshold, eval_results[_WER_KEY]):
      break

  # Clear the session explicitly to avoid a session-deletion error.
  tf.keras.backend.clear_session()


def define_deep_speech_flags():
  """Add flags for run_deep_speech."""
  # Add common flags
  flags_core.define_base(
      data_dir=False  # we use train_data_dir and eval_data_dir instead
  )
  flags_core.define_performance(
      num_parallel_calls=False,
      inter_op=False,
      intra_op=False,
      synthetic_data=False,
      max_train_steps=False,
      dtype=False
  )
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)

  flags_core.set_defaults(
      model_dir="/tmp/deep_speech_model/",
      export_dir="/tmp/deep_speech_saved_model/",
      train_epochs=2,
      batch_size=4,
      hooks="")

  # Deep speech flags
  flags.DEFINE_string(
      name="train_data_dir",
      default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean-20.csv",
      help=flags_core.help_wrap("The csv file path of train dataset."))

  flags.DEFINE_string(
      name="eval_data_dir",
      default="/tmp/librispeech_data/test-clean/LibriSpeech/test-clean-20.csv",
      help=flags_core.help_wrap("The csv file path of evaluation dataset."))

  flags.DEFINE_integer(
      name="sample_rate", default=16000,
      help=flags_core.help_wrap("The sample rate for audio."))

  flags.DEFINE_integer(
      name="frame_length", default=25,
      help=flags_core.help_wrap("The frame length for spectrogram."))

  flags.DEFINE_integer(
      name="frame_step", default=10,
      help=flags_core.help_wrap("The frame step."))

  flags.DEFINE_string(
      name="vocabulary_file", default=_VOCABULARY_FILE,
      help=flags_core.help_wrap("The file path of vocabulary file."))

  # RNN related flags
  flags.DEFINE_integer(
      name="rnn_hidden_size", default=256,
      help=flags_core.help_wrap("The hidden size of RNNs."))

  flags.DEFINE_integer(
      name="rnn_hidden_layers", default=3,
      help=flags_core.help_wrap("The number of RNN layers."))

  flags.DEFINE_bool(
      name="use_bias", default=True,
      help=flags_core.help_wrap("Use bias in the last fully-connected layer"))

  flags.DEFINE_bool(
      name="is_bidirectional", default=True,
      help=flags_core.help_wrap("If rnn unit is bidirectional"))

  flags.DEFINE_enum(
      name="rnn_type", default="gru",
      enum_values=deep_speech_model.SUPPORTED_RNNS.keys(),
      case_sensitive=False,
      help=flags_core.help_wrap("Type of RNN cell."))

  flags.DEFINE_enum(
      name="rnn_activation", default="tanh",
      enum_values=["tanh", "relu"], case_sensitive=False,
      help=flags_core.help_wrap("Type of the activation within RNN."))

  # Training related flags
  flags.DEFINE_float(
      name="learning_rate", default=0.0003,
      help=flags_core.help_wrap("The initial learning rate."))

  flags.DEFINE_float(
      name="momentum", default=0.9,
      help=flags_core.help_wrap("Momentum to accelerate SGD optimizer."))

  # Evaluation metrics threshold
  flags.DEFINE_float(
      name="wer_threshold", default=None,
      help=flags_core.help_wrap(
          "If passed, training will stop when the evaluation metric WER is "
          "greater than or equal to wer_threshold. For libri speech dataset "
          "the desired wer_threshold is 0.23 which is the result achieved by "
          "MLPerf implementation."))


def main(_):
  with logger.benchmark_context(flags_obj):
    run_deep_speech(flags_obj)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  define_deep_speech_flags()
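  # flags_obj is read as a module-level global by the functions above.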
  flags_obj = flags.FLAGS
  absl_app.run(main)