ModelZoo / ResNet50_tensorflow / Commits / 1635e561

Unverified commit 1635e561, authored Jul 09, 2018 by Yanhui Liang, committed by GitHub on Jul 09, 2018.
Add eval and parallel dataset (#4651)
parent c8c45fdb

Showing 4 changed files, with 411 additions and 136 deletions:
research/deep_speech/data/dataset.py (+101, -68)
research/deep_speech/data/featurizer.py (+20, -52)
research/deep_speech/decoder.py (+206, -0)
research/deep_speech/deep_speech.py (+84, -16)
research/deep_speech/data/dataset.py (view file @ 1635e561)

@@ -17,14 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import multiprocessing
 import numpy as np
 import scipy.io.wavfile as wavfile
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
-# pylint: disable=g-bad-import-order
-from data.featurizer import AudioFeaturizer
-from data.featurizer import TextFeaturizer
+import data.featurizer as featurizer  # pylint: disable=g-bad-import-order
 
 
 class AudioConfig(object):

@@ -44,7 +45,7 @@ class AudioConfig(object):
       frame_length: an integer for the length of a spectrogram frame, in ms.
       frame_step: an integer for the frame stride, in ms.
       fft_length: an integer for the number of fft bins.
-      normalize: a boolean for whether to apply normalization on the audio tensor.
+      normalize: a boolean for whether to apply normalization on the audio feature.
       spect_type: a string for the type of spectrogram to be extracted.
     """

@@ -78,90 +79,122 @@ class DatasetConfig(object):
     self.vocab_file_path = vocab_file_path
 
 
+def _normalize_audio_feature(audio_feature):
+  """Perform mean and variance normalization on the spectrogram feature.
+
+  Args:
+    audio_feature: a numpy array for the spectrogram feature.
+
+  Returns:
+    a numpy array of the normalized spectrogram.
+  """
+  mean = np.mean(audio_feature, axis=0)
+  var = np.var(audio_feature, axis=0)
+  normalized = (audio_feature - mean) / (np.sqrt(var) + 1e-6)
+  return normalized
+
+
+def _preprocess_audio(audio_file_path, audio_sample_rate, audio_featurizer,
+                      normalize):
+  """Load the audio file in memory and compute spectrogram feature."""
+  tf.logging.info(
+      "Extracting spectrogram feature for {}".format(audio_file_path))
+  sample_rate, data = wavfile.read(audio_file_path)
+  assert sample_rate == audio_sample_rate
+  if data.dtype not in [np.float32, np.float64]:
+    data = data.astype(np.float32) / np.iinfo(data.dtype).max
+  feature = featurizer.compute_spectrogram_feature(
+      data, audio_featurizer.frame_length, audio_featurizer.frame_step,
+      audio_featurizer.fft_length)
+  if normalize:
+    feature = _normalize_audio_feature(feature)
+  return feature
+
+
+def _preprocess_transcript(transcript, token_to_index):
+  """Process transcript as label features."""
+  return featurizer.compute_label_feature(transcript, token_to_index)
+
+
+def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
+  """Generate a list of waveform, transcript pair.
+
+  Each dataset file contains three columns: "wav_filename", "wav_filesize",
+  and "transcript". This function parses the csv file and stores each example
+  by the increasing order of audio length (indicated by wav_filesize).
+  As the waveforms are ordered in increasing length, audio samples in a
+  mini-batch have similar length.
+
+  Args:
+    dataset_config: an instance of DatasetConfig.
+    audio_featurizer: an instance of AudioFeaturizer.
+    token_to_index: the mapping from character to its index.
+
+  Returns:
+    features and labels array processed from the audio/text input.
+  """
+  file_path = dataset_config.data_path
+  sample_rate = dataset_config.audio_config.sample_rate
+  normalize = dataset_config.audio_config.normalize
+
+  with tf.gfile.Open(file_path, "r") as f:
+    lines = f.read().splitlines()
+  lines = [line.split("\t") for line in lines]
+  # Skip the csv header.
+  lines = lines[1:]
+  # Sort input data by the length of waveform.
+  lines.sort(key=lambda item: int(item[1]))
+
+  # Use multiprocessing for feature/label extraction.
+  num_cores = multiprocessing.cpu_count()
+  pool = multiprocessing.Pool(processes=num_cores)
+  features = pool.map(
+      functools.partial(
+          _preprocess_audio, audio_sample_rate=sample_rate,
+          audio_featurizer=audio_featurizer, normalize=normalize),
+      [line[0] for line in lines])
+  labels = pool.map(
+      functools.partial(_preprocess_transcript, token_to_index=token_to_index),
+      [line[2] for line in lines])
+  pool.terminate()
+  return features, labels
+
+
 class DeepSpeechDataset(object):
   """Dataset class for training/evaluation of DeepSpeech model."""
 
   def __init__(self, dataset_config):
-    """Initialize the class.
-
-    Each dataset file contains three columns: "wav_filename", "wav_filesize",
-    and "transcript". This function parses the csv file and stores each example
-    by the increasing order of audio length (indicated by wav_filesize).
+    """Initialize the DeepSpeechDataset class.
 
     Args:
       dataset_config: DatasetConfig object.
     """
     self.config = dataset_config
     # Instantiate audio feature extractor.
-    self.audio_featurizer = AudioFeaturizer(
+    self.audio_featurizer = featurizer.AudioFeaturizer(
         sample_rate=self.config.audio_config.sample_rate,
         frame_length=self.config.audio_config.frame_length,
         frame_step=self.config.audio_config.frame_step,
-        fft_length=self.config.audio_config.fft_length,
-        spect_type=self.config.audio_config.spect_type)
+        fft_length=self.config.audio_config.fft_length)
     # Instantiate text feature extractor.
-    self.text_featurizer = TextFeaturizer(
+    self.text_featurizer = featurizer.TextFeaturizer(
         vocab_file=self.config.vocab_file_path)
 
     self.speech_labels = self.text_featurizer.speech_labels
-    self.features, self.labels = self._preprocess_data(self.config.data_path)
+    self.features, self.labels = _preprocess_data(
+        self.config, self.audio_featurizer, self.text_featurizer.token_to_idx)
 
     self.num_feature_bins = (
         self.features[0].shape[1] if len(self.features) else None)
 
-  def _preprocess_data(self, file_path):
-    """Generate a list of waveform, transcript pair.
-
-    Note that the waveforms are ordered in increasing length, so that audio
-    samples in a mini-batch have similar length.
-
-    Args:
-      file_path: a string specifying the csv file path for a data set.
-
-    Returns:
-      features and labels array processed from the audio/text input.
-    """
-    with tf.gfile.Open(file_path, "r") as f:
-      lines = f.read().splitlines()
-    lines = [line.split("\t") for line in lines]
-    # Skip the csv header.
-    lines = lines[1:]
-    # Sort input data by the length of waveform.
-    lines.sort(key=lambda item: int(item[1]))
-    features = [self._preprocess_audio(line[0]) for line in lines]
-    labels = [self._preprocess_transcript(line[2]) for line in lines]
-    return features, labels
-
-  def _normalize_audio_tensor(self, audio_tensor):
-    """Perform mean and variance normalization on the spectrogram tensor.
-
-    Args:
-      audio_tensor: a tensor for the spectrogram feature.
-
-    Returns:
-      a tensor for the normalized spectrogram.
-    """
-    mean, var = tf.nn.moments(audio_tensor, axes=[0])
-    normalized = (audio_tensor - mean) / (tf.sqrt(var) + 1e-6)
-    return normalized
-
-  def _preprocess_audio(self, audio_file_path):
-    """Load the audio file in memory."""
-    tf.logging.info(
-        "Extracting spectrogram feature for {}".format(audio_file_path))
-    sample_rate, data = wavfile.read(audio_file_path)
-    assert sample_rate == self.config.audio_config.sample_rate
-    if data.dtype not in [np.float32, np.float64]:
-      data = data.astype(np.float32) / np.iinfo(data.dtype).max
-    feature = self.audio_featurizer.featurize(data)
-    if self.config.audio_config.normalize:
-      feature = self._normalize_audio_tensor(feature)
-    return tf.Session().run(feature)  # return a numpy array rather than a tensor
-
-  def _preprocess_transcript(self, transcript):
-    return self.text_featurizer.featurize(transcript)
 
 
 def input_fn(batch_size, deep_speech_dataset, repeat=1):
   """Input function for model training and evaluation.
...
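Editor's note: the dataset change above moves per-example preprocessing out of the DeepSpeechDataset class into module-level functions, most likely because multiprocessing can only pickle top-level functions, not bound methods; fixed arguments are then bound with functools.partial and pool.map fans the examples out across cores. Below is a minimal, self-contained sketch of that same pattern on toy arrays; the _normalize helper and the toy data are hypothetical stand-ins for _preprocess_audio and the real spectrograms.

import functools
import multiprocessing

import numpy as np


def _normalize(feature, eps):
  """Mean/variance-normalize one example (stand-in for _preprocess_audio)."""
  mean = np.mean(feature, axis=0)
  var = np.var(feature, axis=0)
  return (feature - mean) / (np.sqrt(var) + eps)


if __name__ == "__main__":
  examples = [np.random.rand(5, 3) for _ in range(8)]  # toy "spectrograms"
  # Bind the fixed argument, then map the per-example argument across cores.
  pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
  normalized = pool.map(functools.partial(_normalize, eps=1e-6), examples)
  pool.terminate()
  print(len(normalized), normalized[0].shape)  # 8 (5, 3)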
research/deep_speech/data/featurizer.py (view file @ 1635e561)

@@ -18,9 +18,21 @@ from __future__ import division
 from __future__ import print_function
 
 import codecs
-import functools
 import numpy as np
-import tensorflow as tf
+from scipy import signal
+
+
+def compute_spectrogram_feature(waveform, frame_length, frame_step, fft_length):
+  """Compute the spectrograms for the input waveform."""
+  _, _, stft = signal.stft(
+      waveform, nperseg=frame_length, noverlap=frame_step, nfft=fft_length)
+  # Perform transpose to set its shape as [time_steps, feature_num_bins]
+  spectrogram = np.transpose(np.absolute(stft), (1, 0))
+  return spectrogram
 
 
 class AudioFeaturizer(object):

@@ -30,10 +42,7 @@ class AudioFeaturizer(object):
                sample_rate=16000,
                frame_length=25,
                frame_step=10,
-               fft_length=None,
-               window_fn=functools.partial(
-                   tf.contrib.signal.hann_window, periodic=True),
-               spect_type="linear"):
+               fft_length=None):
     """Initialize the audio featurizer class according to the configs.
 
     Args:

@@ -41,53 +50,18 @@ class AudioFeaturizer(object):
       frame_length: an integer for the length of a spectrogram frame, in ms.
       frame_step: an integer for the frame stride, in ms.
       fft_length: an integer for the number of fft bins.
-      window_fn: windowing function.
-      spect_type: a string for the type of spectrogram to be extracted.
-        Currently only support 'linear', otherwise will raise a value error.
-
-    Raises:
-      ValueError: In case of invalid arguments for `spect_type`.
     """
-    if spect_type != "linear":
-      raise ValueError("Unsupported spectrogram type: %s" % spect_type)
-    self.window_fn = window_fn
     self.frame_length = int(sample_rate * frame_length / 1e3)
     self.frame_step = int(sample_rate * frame_step / 1e3)
     self.fft_length = fft_length if fft_length else int(2**(
         np.ceil(np.log2(self.frame_length))))
 
-  def featurize(self, waveform):
-    """Extract spectrogram feature tensors from the waveform."""
-    return self._compute_linear_spectrogram(waveform)
-
-  def _compute_linear_spectrogram(self, waveform):
-    """Compute the linear-scale, magnitude spectrograms for the input waveform.
-
-    Args:
-      waveform: a float32 audio tensor.
-
-    Returns:
-      a float32 tensor with shape [len, num_bins]
-    """
-    # `stfts` is a complex64 Tensor representing the Short-time Fourier
-    # Transform of each signal in `signals`. Its shape is
-    # [?, fft_unique_bins] where fft_unique_bins = fft_length // 2 + 1.
-    stfts = tf.contrib.signal.stft(
-        waveform,
-        frame_length=self.frame_length,
-        frame_step=self.frame_step,
-        fft_length=self.fft_length,
-        window_fn=self.window_fn,
-        pad_end=True)
-
-    # An energy spectrogram is the magnitude of the complex-valued STFT.
-    # A float32 Tensor of shape [?, 257].
-    magnitude_spectrograms = tf.abs(stfts)
-    return magnitude_spectrograms
-
-  def _compute_mel_filterbank_features(self, waveform):
-    """Compute the mel filterbank features."""
-    raise NotImplementedError("MFCC feature extraction not supported yet.")
+
+
+def compute_label_feature(text, token_to_idx):
+  """Convert string to a list of integers."""
+  tokens = list(text.strip().lower())
+  feats = [token_to_idx[token] for token in tokens]
+  return feats
 
 
 class TextFeaturizer(object):

@@ -114,9 +88,3 @@ class TextFeaturizer(object):
         self.idx_to_token[idx] = line
         self.speech_labels += line
         idx += 1
-
-  def featurize(self, text):
-    """Convert string to a list of integers."""
-    tokens = list(text.strip().lower())
-    feats = [self.token_to_idx[token] for token in tokens]
-    return feats
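Editor's note: a quick numeric check of the framing math above. This sketch assumes the AudioFeaturizer defaults shown in the diff (16 kHz audio, 25 ms frames, 10 ms stride, fft_length defaulting to the next power of two above the frame length) and mirrors the compute_spectrogram_feature call, including the way it passes frame_step as scipy's noverlap argument.

import numpy as np
from scipy import signal

sample_rate = 16000
frame_length = int(sample_rate * 25 / 1e3)  # 400 samples per frame
frame_step = int(sample_rate * 10 / 1e3)    # 160 samples per stride
fft_length = int(2**(np.ceil(np.log2(frame_length))))  # 512 fft bins

waveform = np.random.rand(sample_rate).astype(np.float32)  # 1 s of noise
_, _, stft = signal.stft(
    waveform, nperseg=frame_length, noverlap=frame_step, nfft=fft_length)
# Transpose to [time_steps, feature_num_bins], as in the diff.
spectrogram = np.transpose(np.absolute(stft), (1, 0))
print(spectrogram.shape)  # (time_steps, fft_length // 2 + 1) == (?, 257)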
research/deep_speech/decoder.py (new file, 0 → 100644; view file @ 1635e561)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deep speech decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from nltk.metrics import distance
from six.moves import xrange
import tensorflow as tf


class DeepSpeechDecoder(object):
  """Basic decoder class from which all other decoders inherit.

  Implements several helper functions. Subclasses should implement the
  decode() method.
  """

  def __init__(self, labels, blank_index=28, space_index=27):
    """Decoder initialization.

    Arguments:
      labels (string): mapping from integers to characters.
      blank_index (int, optional): index for the blank '_' character.
        Defaults to 0.
      space_index (int, optional): index for the space ' ' character.
        Defaults to 28.
    """
    # e.g. labels = "[a-z]' _"
    self.labels = labels
    self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
    self.blank_index = blank_index
    self.space_index = space_index

  def convert_to_strings(self, sequences, sizes=None):
    """Given a list of numeric sequences, returns the corresponding strings."""
    strings = []
    for x in xrange(len(sequences)):
      seq_len = sizes[x] if sizes is not None else len(sequences[x])
      string = self._convert_to_string(sequences[x], seq_len)
      strings.append(string)
    return strings

  def _convert_to_string(self, sequence, sizes):
    return ''.join([self.int_to_char[sequence[i]] for i in range(sizes)])

  def process_strings(self, sequences, remove_repetitions=False):
    """Process strings.

    Given a list of strings, removes blanks and replaces the space character
    with space. Option to remove repetitions (e.g. 'abbca' -> 'abca').

    Arguments:
      sequences: list of 1-d array of integers
      remove_repetitions (boolean, optional): If true, repeating characters
        are removed. Defaults to False.

    Returns:
      The processed string.
    """
    processed_strings = []
    for sequence in sequences:
      string = self.process_string(remove_repetitions, sequence).strip()
      processed_strings.append(string)
    return processed_strings

  def process_string(self, remove_repetitions, sequence):
    """Process each given sequence."""
    seq_string = ''
    for i, char in enumerate(sequence):
      if char != self.int_to_char[self.blank_index]:
        # if this char is a repetition and remove_repetitions=true, skip.
        if remove_repetitions and i != 0 and char == sequence[i - 1]:
          pass
        elif char == self.labels[self.space_index]:
          seq_string += ' '
        else:
          seq_string += char
    return seq_string

  def wer(self, output, target):
    """Computes the Word Error Rate (WER).

    WER is defined as the edit distance between the two provided sentences
    after tokenizing to words.

    Args:
      output: string of the decoded output.
      target: a string for the true transcript.

    Returns:
      A float number for the WER of the current sentence pair.
    """
    # Map each word to a new char.
    words = set(output.split() + target.split())
    word2char = dict(zip(words, range(len(words))))

    new_output = [chr(word2char[w]) for w in output.split()]
    new_target = [chr(word2char[w]) for w in target.split()]

    return distance.edit_distance(''.join(new_output), ''.join(new_target))

  def cer(self, output, target):
    """Computes the Character Error Rate (CER).

    CER is defined as the edit distance between the given strings.

    Args:
      output: a string of the decoded output.
      target: a string for the ground truth transcript.

    Returns:
      A float number denoting the CER for the current sentence pair.
    """
    return distance.edit_distance(output, target)

  def batch_wer(self, decoded_output, targets):
    """Compute the aggregate WER for each batch.

    Args:
      decoded_output: 2d array of integers for the decoded output of a batch.
      targets: 2d array of integers for the labels of a batch.

    Returns:
      A float number for the aggregated WER for the current batch output.
    """
    # Convert numeric representation to string.
    decoded_strings = self.convert_to_strings(decoded_output)
    decoded_strings = self.process_strings(
        decoded_strings, remove_repetitions=True)
    target_strings = self.convert_to_strings(targets)
    target_strings = self.process_strings(
        target_strings, remove_repetitions=True)
    wer = 0
    for i in xrange(len(decoded_strings)):
      wer += self.wer(decoded_strings[i], target_strings[i]) / float(
          len(target_strings[i].split()))
    return wer

  def batch_cer(self, decoded_output, targets):
    """Compute the aggregate CER for each batch.

    Args:
      decoded_output: 2d array of integers for the decoded output of a batch.
      targets: 2d array of integers for the labels of a batch.

    Returns:
      A float number for the aggregated CER for the current batch output.
    """
    # Convert numeric representation to string.
    decoded_strings = self.convert_to_strings(decoded_output)
    decoded_strings = self.process_strings(
        decoded_strings, remove_repetitions=True)
    target_strings = self.convert_to_strings(targets)
    target_strings = self.process_strings(
        target_strings, remove_repetitions=True)
    cer = 0
    for i in xrange(len(decoded_strings)):
      cer += self.cer(decoded_strings[i], target_strings[i]) / float(
          len(target_strings[i]))
    return cer

  def decode(self, sequences, sizes=None):
    """Perform sequence decoding.

    Given a matrix of character probabilities, returns the decoder's best
    guess of the transcription.

    Arguments:
      sequences: 2D array of character probabilities, where sequences[c, t]
        is the probability of character c at time t.
      sizes(optional): Size of each sequence in the mini-batch.

    Returns:
      string: sequence of the model's best guess for the transcription.
    """
    strings = self.convert_to_strings(sequences, sizes)
    return self.process_strings(strings, remove_repetitions=True)


class GreedyDecoder(DeepSpeechDecoder):
  """Greedy decoder."""

  def decode(self, logits, seq_len):
    # Reshape to [max_time, batch_size, num_classes]
    logits = tf.transpose(logits, (1, 0, 2))
    decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)
    decoded_dense = tf.Session().run(
        tf.sparse_to_dense(
            decoded[0].indices, decoded[0].dense_shape, decoded[0].values))
    result = self.convert_to_strings(decoded_dense)
    return self.process_strings(result, remove_repetitions=True), decoded_dense
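Editor's note: the wer() method above uses a classic trick — encode each distinct word as a single character, then reuse character-level edit distance at the word level. A self-contained sketch follows; the local edit_distance is a stand-in for nltk's distance.edit_distance so the example runs without nltk.

def edit_distance(a, b):
  """Classic Levenshtein distance via dynamic programming."""
  prev = list(range(len(b) + 1))
  for i, ca in enumerate(a, 1):
    cur = [i]
    for j, cb in enumerate(b, 1):
      cur.append(min(prev[j] + 1,                  # deletion
                     cur[j - 1] + 1,               # insertion
                     prev[j - 1] + (ca != cb)))    # substitution
    prev = cur
  return prev[len(b)]


output, target = "the cat sat", "the cat sat down"
# Map each distinct word to one character, as in wer() above.
words = set(output.split() + target.split())
word2char = dict(zip(words, range(len(words))))
new_output = [chr(word2char[w]) for w in output.split()]
new_target = [chr(word2char[w]) for w in target.split()]
print(edit_distance("".join(new_output), "".join(new_target)))  # 1 word edit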
research/deep_speech/deep_speech.py (view file @ 1635e561)

@@ -25,15 +25,82 @@ import tensorflow as tf
 # pylint: enable=g-bad-import-order
 
 import data.dataset as dataset
+import decoder
 import deep_speech_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import model_helpers
 
 # Default vocabulary file
 _VOCABULARY_FILE = os.path.join(
     os.path.dirname(__file__), "data/vocabulary.txt")
+# Evaluation metrics
+_WER_KEY = "WER"
+_CER_KEY = "CER"
+
+
+def evaluate_model(estimator, batch_size, speech_labels, targets,
+                   input_fn_eval):
+  """Evaluate the model performance using WER and CER as metrics.
+
+  WER: Word Error Rate
+  CER: Character Error Rate
+
+  Args:
+    estimator: estimator to evaluate.
+    batch_size: size of a mini-batch.
+    speech_labels: a string specifying all the characters in the vocabulary.
+    targets: a list of list of integers for the featurized transcript.
+    input_fn_eval: data input function for evaluation.
+
+  Returns:
+    Evaluation result containing 'wer' and 'cer' as two metrics.
+  """
+  # Get predictions
+  predictions = estimator.predict(
+      input_fn=input_fn_eval, yield_single_examples=False)
+
+  y_preds = []
+  input_lengths = []
+  for p in predictions:
+    y_preds.append(p["y_pred"])
+    input_lengths.append(p["ctc_input_length"])
+
+  num_of_examples = len(targets)
+  total_wer, total_cer = 0, 0
+  greedy_decoder = decoder.GreedyDecoder(speech_labels)
+  for i in range(len(y_preds)):
+    # Compute the CER and WER for the current batch,
+    # and aggregate to total_cer, total_wer.
+    y_pred_tensor = tf.convert_to_tensor(y_preds[i])
+    batch_targets = targets[i * batch_size:(i + 1) * batch_size]
+    seq_len = tf.squeeze(input_lengths[i], axis=1)
+    # Perform decoding
+    _, decoded_output = greedy_decoder.decode(y_pred_tensor, seq_len)
+    # Compute CER.
+    batch_cer = greedy_decoder.batch_cer(decoded_output, batch_targets)
+    total_cer += batch_cer
+    # Compute WER.
+    batch_wer = greedy_decoder.batch_wer(decoded_output, batch_targets)
+    total_wer += batch_wer
+
+  # Get mean value
+  total_cer /= num_of_examples
+  total_wer /= num_of_examples
+
+  global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
+  eval_results = {
+      _WER_KEY: total_wer,
+      _CER_KEY: total_cer,
+      tf.GraphKeys.GLOBAL_STEP: global_step,
+  }
+
+  return eval_results
+
+
 def convert_keras_to_estimator(keras_model, num_gpus):
...

@@ -136,7 +203,7 @@ def run_deep_speech(_):
     return dataset.input_fn(per_device_batch_size, train_speech_dataset)
 
-  def input_fn_eval():  # pylint: disable=unused-variable
+  def input_fn_eval():
     return dataset.input_fn(per_device_batch_size, eval_speech_dataset)
...

@@ -148,22 +215,23 @@ def run_deep_speech(_):
     estimator.train(input_fn=input_fn_train, hooks=train_hooks)
 
-    # Evaluate (TODO)
-    # tf.logging.info("Starting to evaluate.")
-    # eval_results = evaluate_model(
-    #     estimator, keras_model, data_set.speech_labels, [], input_fn_eval)
-    # benchmark_logger.log_evaluation_result(eval_results)
-
-    # If some evaluation threshold is met
-    # Log the HR and NDCG results.
-    # wer = eval_results[_WER_KEY]
-    # cer = eval_results[_CER_KEY]
-    # tf.logging.info(
-    #     "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
-    #         cycle_index + 1, wer, cer))
-    # if model_helpers.past_stop_threshold(FLAGS.wer_threshold, wer):
-    #   break
+    # Evaluation
+    tf.logging.info("Starting to evaluate...")
+
+    eval_results = evaluate_model(
+        estimator, flags_obj.batch_size, eval_speech_dataset.speech_labels,
+        eval_speech_dataset.labels, input_fn_eval)
+
+    # Log the WER and CER results.
+    benchmark_logger.log_evaluation_result(eval_results)
+    tf.logging.info(
+        "Iteration {}: WER = {:.2f}, CER = {:.2f}".format(
+            cycle_index + 1, eval_results[_WER_KEY], eval_results[_CER_KEY]))
+
+    # If some evaluation threshold is met
+    if model_helpers.past_stop_threshold(
+        flags_obj.wer_threshold, eval_results[_WER_KEY]):
+      break
 
   # Clear the session explicitly to avoid session delete error
   tf.keras.backend.clear_session()
...

@@ -189,8 +257,8 @@ def define_deep_speech_flags():
   flags_core.set_defaults(
       model_dir="/tmp/deep_speech_model/",
       export_dir="/tmp/deep_speech_saved_model/",
-      train_epochs=10,
-      batch_size=32,
+      train_epochs=2,
+      batch_size=4,
       hooks="")
 
   # Deep speech flags
...
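Editor's note: a toy numeric sketch of how evaluate_model aggregates its metrics. Each batch_wer/batch_cer call returns a sum of per-example rates for one batch slice of targets, and dividing the running total by the number of examples yields the mean over the whole evaluation set. The per-example values below are made up for illustration.

batch_size = 4
per_example_wer = [0.0, 0.25, 0.5, 0.0, 1.0, 0.2, 0.0, 0.1]  # two batches

total_wer = 0.0
for i in range(len(per_example_wer) // batch_size):
  # Same slicing as targets[i * batch_size:(i + 1) * batch_size] above.
  batch = per_example_wer[i * batch_size:(i + 1) * batch_size]
  total_wer += sum(batch)          # what one batch_wer() call contributes
total_wer /= len(per_example_wer)  # mean WER over all examples
print(total_wer)  # 0.25625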