Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3311242a
Unverified
Commit
3311242a
authored
Aug 15, 2020
by
moneypi
Committed by
GitHub
Aug 14, 2020
Browse files
1. update to tf2.x for deep_speech (#8696)
Update to TF 2 for deep_speech
parent
ccf7da9d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
78 deletions
+54
-78
research/deep_speech/data/dataset.py
research/deep_speech/data/dataset.py
+4
-4
research/deep_speech/data/download.py
research/deep_speech/data/download.py
+16
-16
research/deep_speech/deep_speech.py
research/deep_speech/deep_speech.py
+19
-35
research/deep_speech/deep_speech_model.py
research/deep_speech/deep_speech_model.py
+15
-23
No files found.
research/deep_speech/data/dataset.py
View file @
3311242a
...
@@ -71,8 +71,8 @@ class DatasetConfig(object):
...
@@ -71,8 +71,8 @@ class DatasetConfig(object):
"""
"""
self
.
audio_config
=
audio_config
self
.
audio_config
=
audio_config
assert
tf
.
gfile
.
E
xists
(
data_path
)
assert
tf
.
io
.
gfile
.
e
xists
(
data_path
)
assert
tf
.
gfile
.
E
xists
(
vocab_file_path
)
assert
tf
.
io
.
gfile
.
e
xists
(
vocab_file_path
)
self
.
data_path
=
data_path
self
.
data_path
=
data_path
self
.
vocab_file_path
=
vocab_file_path
self
.
vocab_file_path
=
vocab_file_path
self
.
sortagrad
=
sortagrad
self
.
sortagrad
=
sortagrad
...
@@ -125,8 +125,8 @@ def _preprocess_data(file_path):
...
@@ -125,8 +125,8 @@ def _preprocess_data(file_path):
A list of tuples (wav_filename, wav_filesize, transcript) sorted by
A list of tuples (wav_filename, wav_filesize, transcript) sorted by
file_size.
file_size.
"""
"""
tf
.
logging
.
info
(
"Loading data set {}"
.
format
(
file_path
))
tf
.
compat
.
v1
.
logging
.
info
(
"Loading data set {}"
.
format
(
file_path
))
with
tf
.
gfile
.
Open
(
file_path
,
"r"
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
"r"
)
as
f
:
lines
=
f
.
read
().
splitlines
()
lines
=
f
.
read
().
splitlines
()
# Skip the csv header in lines[0].
# Skip the csv header in lines[0].
lines
=
lines
[
1
:]
lines
=
lines
[
1
:]
...
...
research/deep_speech/data/download.py
View file @
3311242a
...
@@ -59,13 +59,13 @@ def download_and_extract(directory, url):
...
@@ -59,13 +59,13 @@ def download_and_extract(directory, url):
url: the url to download the data file.
url: the url to download the data file.
"""
"""
if
not
tf
.
gfile
.
E
xists
(
directory
):
if
not
tf
.
io
.
gfile
.
e
xists
(
directory
):
tf
.
gfile
.
M
ake
D
irs
(
directory
)
tf
.
io
.
gfile
.
m
ake
d
irs
(
directory
)
_
,
tar_filepath
=
tempfile
.
mkstemp
(
suffix
=
".tar.gz"
)
_
,
tar_filepath
=
tempfile
.
mkstemp
(
suffix
=
".tar.gz"
)
try
:
try
:
tf
.
logging
.
info
(
"Downloading %s to %s"
%
(
url
,
tar_filepath
))
tf
.
compat
.
v1
.
logging
.
info
(
"Downloading %s to %s"
%
(
url
,
tar_filepath
))
def
_progress
(
count
,
block_size
,
total_size
):
def
_progress
(
count
,
block_size
,
total_size
):
sys
.
stdout
.
write
(
"
\r
>> Downloading {} {:.1f}%"
.
format
(
sys
.
stdout
.
write
(
"
\r
>> Downloading {} {:.1f}%"
.
format
(
...
@@ -75,12 +75,12 @@ def download_and_extract(directory, url):
...
@@ -75,12 +75,12 @@ def download_and_extract(directory, url):
urllib
.
request
.
urlretrieve
(
url
,
tar_filepath
,
_progress
)
urllib
.
request
.
urlretrieve
(
url
,
tar_filepath
,
_progress
)
print
()
print
()
statinfo
=
os
.
stat
(
tar_filepath
)
statinfo
=
os
.
stat
(
tar_filepath
)
tf
.
logging
.
info
(
tf
.
compat
.
v1
.
logging
.
info
(
"Successfully downloaded %s, size(bytes): %d"
%
(
url
,
statinfo
.
st_size
))
"Successfully downloaded %s, size(bytes): %d"
%
(
url
,
statinfo
.
st_size
))
with
tarfile
.
open
(
tar_filepath
,
"r"
)
as
tar
:
with
tarfile
.
open
(
tar_filepath
,
"r"
)
as
tar
:
tar
.
extractall
(
directory
)
tar
.
extractall
(
directory
)
finally
:
finally
:
tf
.
gfile
.
R
emove
(
tar_filepath
)
tf
.
io
.
gfile
.
r
emove
(
tar_filepath
)
def
convert_audio_and_split_transcript
(
input_dir
,
source_name
,
target_name
,
def
convert_audio_and_split_transcript
(
input_dir
,
source_name
,
target_name
,
...
@@ -112,18 +112,18 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
...
@@ -112,18 +112,18 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
output_file: the name of the newly generated csv file. e.g. test-clean.csv
output_file: the name of the newly generated csv file. e.g. test-clean.csv
"""
"""
tf
.
logging
.
info
(
"Preprocessing audio and transcript for %s"
%
source_name
)
tf
.
compat
.
v1
.
logging
.
info
(
"Preprocessing audio and transcript for %s"
%
source_name
)
source_dir
=
os
.
path
.
join
(
input_dir
,
source_name
)
source_dir
=
os
.
path
.
join
(
input_dir
,
source_name
)
target_dir
=
os
.
path
.
join
(
input_dir
,
target_name
)
target_dir
=
os
.
path
.
join
(
input_dir
,
target_name
)
if
not
tf
.
gfile
.
E
xists
(
target_dir
):
if
not
tf
.
io
.
gfile
.
e
xists
(
target_dir
):
tf
.
gfile
.
M
ake
D
irs
(
target_dir
)
tf
.
io
.
gfile
.
m
ake
d
irs
(
target_dir
)
files
=
[]
files
=
[]
tfm
=
Transformer
()
tfm
=
Transformer
()
# Convert all FLAC file into WAV format. At the same time, generate the csv
# Convert all FLAC file into WAV format. At the same time, generate the csv
# file.
# file.
for
root
,
_
,
filenames
in
tf
.
gfile
.
W
alk
(
source_dir
):
for
root
,
_
,
filenames
in
tf
.
io
.
gfile
.
w
alk
(
source_dir
):
for
filename
in
fnmatch
.
filter
(
filenames
,
"*.trans.txt"
):
for
filename
in
fnmatch
.
filter
(
filenames
,
"*.trans.txt"
):
trans_file
=
os
.
path
.
join
(
root
,
filename
)
trans_file
=
os
.
path
.
join
(
root
,
filename
)
with
codecs
.
open
(
trans_file
,
"r"
,
"utf-8"
)
as
fin
:
with
codecs
.
open
(
trans_file
,
"r"
,
"utf-8"
)
as
fin
:
...
@@ -137,7 +137,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
...
@@ -137,7 +137,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
# Convert FLAC to WAV.
# Convert FLAC to WAV.
flac_file
=
os
.
path
.
join
(
root
,
seqid
+
".flac"
)
flac_file
=
os
.
path
.
join
(
root
,
seqid
+
".flac"
)
wav_file
=
os
.
path
.
join
(
target_dir
,
seqid
+
".wav"
)
wav_file
=
os
.
path
.
join
(
target_dir
,
seqid
+
".wav"
)
if
not
tf
.
gfile
.
E
xists
(
wav_file
):
if
not
tf
.
io
.
gfile
.
e
xists
(
wav_file
):
tfm
.
build
(
flac_file
,
wav_file
)
tfm
.
build
(
flac_file
,
wav_file
)
wav_filesize
=
os
.
path
.
getsize
(
wav_file
)
wav_filesize
=
os
.
path
.
getsize
(
wav_file
)
...
@@ -149,7 +149,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
...
@@ -149,7 +149,7 @@ def convert_audio_and_split_transcript(input_dir, source_name, target_name,
df
=
pandas
.
DataFrame
(
df
=
pandas
.
DataFrame
(
data
=
files
,
columns
=
[
"wav_filename"
,
"wav_filesize"
,
"transcript"
])
data
=
files
,
columns
=
[
"wav_filename"
,
"wav_filesize"
,
"transcript"
])
df
.
to_csv
(
csv_file_path
,
index
=
False
,
sep
=
"
\t
"
)
df
.
to_csv
(
csv_file_path
,
index
=
False
,
sep
=
"
\t
"
)
tf
.
logging
.
info
(
"Successfully generated csv file {}"
.
format
(
csv_file_path
))
tf
.
compat
.
v1
.
logging
.
info
(
"Successfully generated csv file {}"
.
format
(
csv_file_path
))
def
download_and_process_datasets
(
directory
,
datasets
):
def
download_and_process_datasets
(
directory
,
datasets
):
...
@@ -160,10 +160,10 @@ def download_and_process_datasets(directory, datasets):
...
@@ -160,10 +160,10 @@ def download_and_process_datasets(directory, datasets):
datasets: list of dataset names that will be downloaded and processed.
datasets: list of dataset names that will be downloaded and processed.
"""
"""
tf
.
logging
.
info
(
"Preparing LibriSpeech dataset: {}"
.
format
(
tf
.
compat
.
v1
.
logging
.
info
(
"Preparing LibriSpeech dataset: {}"
.
format
(
","
.
join
(
datasets
)))
","
.
join
(
datasets
)))
for
dataset
in
datasets
:
for
dataset
in
datasets
:
tf
.
logging
.
info
(
"Preparing dataset %s"
,
dataset
)
tf
.
compat
.
v1
.
logging
.
info
(
"Preparing dataset %s"
,
dataset
)
dataset_dir
=
os
.
path
.
join
(
directory
,
dataset
)
dataset_dir
=
os
.
path
.
join
(
directory
,
dataset
)
download_and_extract
(
dataset_dir
,
LIBRI_SPEECH_URLS
[
dataset
])
download_and_extract
(
dataset_dir
,
LIBRI_SPEECH_URLS
[
dataset
])
convert_audio_and_split_transcript
(
convert_audio_and_split_transcript
(
...
@@ -185,8 +185,8 @@ def define_data_download_flags():
...
@@ -185,8 +185,8 @@ def define_data_download_flags():
def
main
(
_
):
def
main
(
_
):
if
not
tf
.
gfile
.
E
xists
(
FLAGS
.
data_dir
):
if
not
tf
.
io
.
gfile
.
e
xists
(
FLAGS
.
data_dir
):
tf
.
gfile
.
M
ake
D
irs
(
FLAGS
.
data_dir
)
tf
.
io
.
gfile
.
m
ake
d
irs
(
FLAGS
.
data_dir
)
if
FLAGS
.
train_only
:
if
FLAGS
.
train_only
:
download_and_process_datasets
(
download_and_process_datasets
(
...
@@ -202,7 +202,7 @@ def main(_):
...
@@ -202,7 +202,7 @@ def main(_):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
define_data_download_flags
()
define_data_download_flags
()
FLAGS
=
absl_flags
.
FLAGS
FLAGS
=
absl_flags
.
FLAGS
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
research/deep_speech/deep_speech.py
View file @
3311242a
...
@@ -61,25 +61,10 @@ def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
...
@@ -61,25 +61,10 @@ def compute_length_after_conv(max_time_steps, ctc_time_steps, input_length):
Returns:
Returns:
the ctc_input_length after convolution layer.
the ctc_input_length after convolution layer.
"""
"""
ctc_input_length
=
tf
.
to_float
(
tf
.
multiply
(
ctc_input_length
=
tf
.
cast
(
tf
.
multiply
(
input_length
,
ctc_time_steps
))
input_length
,
ctc_time_steps
),
dtype
=
tf
.
float32
)
return
tf
.
to_int32
(
tf
.
floordiv
(
return
tf
.
cast
(
tf
.
math
.
floordiv
(
ctc_input_length
,
tf
.
to_float
(
max_time_steps
)))
ctc_input_length
,
tf
.
cast
(
max_time_steps
,
dtype
=
tf
.
float32
)),
dtype
=
tf
.
int32
)
def
ctc_loss
(
label_length
,
ctc_input_length
,
labels
,
logits
):
"""Computes the ctc loss for the current batch of predictions."""
label_length
=
tf
.
to_int32
(
tf
.
squeeze
(
label_length
))
ctc_input_length
=
tf
.
to_int32
(
tf
.
squeeze
(
ctc_input_length
))
sparse_labels
=
tf
.
to_int32
(
tf
.
keras
.
backend
.
ctc_label_dense_to_sparse
(
labels
,
label_length
))
y_pred
=
tf
.
log
(
tf
.
transpose
(
logits
,
perm
=
[
1
,
0
,
2
])
+
tf
.
keras
.
backend
.
epsilon
())
return
tf
.
expand_dims
(
tf
.
nn
.
ctc_loss
(
labels
=
sparse_labels
,
inputs
=
y_pred
,
sequence_length
=
ctc_input_length
),
axis
=
1
)
def
evaluate_model
(
estimator
,
speech_labels
,
entries
,
input_fn_eval
):
def
evaluate_model
(
estimator
,
speech_labels
,
entries
,
input_fn_eval
):
...
@@ -123,11 +108,11 @@ def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
...
@@ -123,11 +108,11 @@ def evaluate_model(estimator, speech_labels, entries, input_fn_eval):
total_cer
/=
num_of_examples
total_cer
/=
num_of_examples
total_wer
/=
num_of_examples
total_wer
/=
num_of_examples
global_step
=
estimator
.
get_variable_value
(
tf
.
GraphKeys
.
GLOBAL_STEP
)
global_step
=
estimator
.
get_variable_value
(
tf
.
compat
.
v1
.
GraphKeys
.
GLOBAL_STEP
)
eval_results
=
{
eval_results
=
{
_WER_KEY
:
total_wer
,
_WER_KEY
:
total_wer
,
_CER_KEY
:
total_cer
,
_CER_KEY
:
total_cer
,
tf
.
GraphKeys
.
GLOBAL_STEP
:
global_step
,
tf
.
compat
.
v1
.
GraphKeys
.
GLOBAL_STEP
:
global_step
,
}
}
return
eval_results
return
eval_results
...
@@ -163,7 +148,7 @@ def model_fn(features, labels, mode, params):
...
@@ -163,7 +148,7 @@ def model_fn(features, labels, mode, params):
logits
=
model
(
features
,
training
=
False
)
logits
=
model
(
features
,
training
=
False
)
predictions
=
{
predictions
=
{
"classes"
:
tf
.
argmax
(
logits
,
axis
=
2
),
"classes"
:
tf
.
argmax
(
logits
,
axis
=
2
),
"probabilities"
:
tf
.
nn
.
softmax
(
logits
)
,
"probabilities"
:
logits
,
"logits"
:
logits
"logits"
:
logits
}
}
return
tf
.
estimator
.
EstimatorSpec
(
return
tf
.
estimator
.
EstimatorSpec
(
...
@@ -172,17 +157,16 @@ def model_fn(features, labels, mode, params):
...
@@ -172,17 +157,16 @@ def model_fn(features, labels, mode, params):
# In training mode.
# In training mode.
logits
=
model
(
features
,
training
=
True
)
logits
=
model
(
features
,
training
=
True
)
probs
=
tf
.
nn
.
softmax
(
logits
)
ctc_input_length
=
compute_length_after_conv
(
ctc_input_length
=
compute_length_after_conv
(
tf
.
shape
(
features
)[
1
],
tf
.
shape
(
prob
s
)[
1
],
input_length
)
tf
.
shape
(
features
)[
1
],
tf
.
shape
(
logit
s
)[
1
],
input_length
)
# Compute CTC loss
# Compute CTC loss
loss
=
tf
.
reduce_mean
(
ctc_l
os
s
(
loss
=
tf
.
reduce_mean
(
tf
.
keras
.
backend
.
ctc_batch_c
os
t
(
label
_length
,
ctc_input_length
,
label
s
,
probs
))
label
s
,
logits
,
ctc_input_length
,
label
_length
))
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
flags_obj
.
learning_rate
)
optimizer
=
tf
.
compat
.
v1
.
train
.
AdamOptimizer
(
learning_rate
=
flags_obj
.
learning_rate
)
global_step
=
tf
.
train
.
get_or_create_global_step
()
global_step
=
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()
minimize_op
=
optimizer
.
minimize
(
loss
,
global_step
=
global_step
)
minimize_op
=
optimizer
.
minimize
(
loss
,
global_step
=
global_step
)
update_ops
=
tf
.
get_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
)
update_ops
=
tf
.
compat
.
v1
.
get_collection
(
tf
.
compat
.
v1
.
GraphKeys
.
UPDATE_OPS
)
# Create the train_op that groups both minimize_ops and update_ops
# Create the train_op that groups both minimize_ops and update_ops
train_op
=
tf
.
group
(
minimize_op
,
update_ops
)
train_op
=
tf
.
group
(
minimize_op
,
update_ops
)
...
@@ -239,9 +223,9 @@ def per_device_batch_size(batch_size, num_gpus):
...
@@ -239,9 +223,9 @@ def per_device_batch_size(batch_size, num_gpus):
def
run_deep_speech
(
_
):
def
run_deep_speech
(
_
):
"""Run deep speech training and eval loop."""
"""Run deep speech training and eval loop."""
tf
.
set_random_seed
(
flags_obj
.
seed
)
tf
.
compat
.
v1
.
set_random_seed
(
flags_obj
.
seed
)
# Data preprocessing
# Data preprocessing
tf
.
logging
.
info
(
"Data preprocessing..."
)
tf
.
compat
.
v1
.
logging
.
info
(
"Data preprocessing..."
)
train_speech_dataset
=
generate_dataset
(
flags_obj
.
train_data_dir
)
train_speech_dataset
=
generate_dataset
(
flags_obj
.
train_data_dir
)
eval_speech_dataset
=
generate_dataset
(
flags_obj
.
eval_data_dir
)
eval_speech_dataset
=
generate_dataset
(
flags_obj
.
eval_data_dir
)
...
@@ -287,7 +271,7 @@ def run_deep_speech(_):
...
@@ -287,7 +271,7 @@ def run_deep_speech(_):
total_training_cycle
=
(
flags_obj
.
train_epochs
//
total_training_cycle
=
(
flags_obj
.
train_epochs
//
flags_obj
.
epochs_between_evals
)
flags_obj
.
epochs_between_evals
)
for
cycle_index
in
range
(
total_training_cycle
):
for
cycle_index
in
range
(
total_training_cycle
):
tf
.
logging
.
info
(
"Starting a training cycle: %d/%d"
,
tf
.
compat
.
v1
.
logging
.
info
(
"Starting a training cycle: %d/%d"
,
cycle_index
+
1
,
total_training_cycle
)
cycle_index
+
1
,
total_training_cycle
)
# Perform batch_wise dataset shuffling
# Perform batch_wise dataset shuffling
...
@@ -298,7 +282,7 @@ def run_deep_speech(_):
...
@@ -298,7 +282,7 @@ def run_deep_speech(_):
estimator
.
train
(
input_fn
=
input_fn_train
)
estimator
.
train
(
input_fn
=
input_fn_train
)
# Evaluation
# Evaluation
tf
.
logging
.
info
(
"Starting to evaluate..."
)
tf
.
compat
.
v1
.
logging
.
info
(
"Starting to evaluate..."
)
eval_results
=
evaluate_model
(
eval_results
=
evaluate_model
(
estimator
,
eval_speech_dataset
.
speech_labels
,
estimator
,
eval_speech_dataset
.
speech_labels
,
...
@@ -306,7 +290,7 @@ def run_deep_speech(_):
...
@@ -306,7 +290,7 @@ def run_deep_speech(_):
# Log the WER and CER results.
# Log the WER and CER results.
benchmark_logger
.
log_evaluation_result
(
eval_results
)
benchmark_logger
.
log_evaluation_result
(
eval_results
)
tf
.
logging
.
info
(
tf
.
compat
.
v1
.
logging
.
info
(
"Iteration {}: WER = {:.2f}, CER = {:.2f}"
.
format
(
"Iteration {}: WER = {:.2f}, CER = {:.2f}"
.
format
(
cycle_index
+
1
,
eval_results
[
_WER_KEY
],
eval_results
[
_CER_KEY
]))
cycle_index
+
1
,
eval_results
[
_WER_KEY
],
eval_results
[
_CER_KEY
]))
...
@@ -425,7 +409,7 @@ def main(_):
...
@@ -425,7 +409,7 @@ def main(_):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
define_deep_speech_flags
()
define_deep_speech_flags
()
flags_obj
=
flags
.
FLAGS
flags_obj
=
flags
.
FLAGS
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
...
...
research/deep_speech/deep_speech_model.py
View file @
3311242a
...
@@ -22,9 +22,9 @@ import tensorflow as tf
...
@@ -22,9 +22,9 @@ import tensorflow as tf
# Supported rnn cells.
# Supported rnn cells.
SUPPORTED_RNNS
=
{
SUPPORTED_RNNS
=
{
"lstm"
:
tf
.
contrib
.
rnn
.
Basic
LSTMCell
,
"lstm"
:
tf
.
keras
.
layers
.
LSTMCell
,
"rnn"
:
tf
.
contrib
.
rnn
.
RNNCell
,
"rnn"
:
tf
.
keras
.
layers
.
Simple
RNNCell
,
"gru"
:
tf
.
contrib
.
rnn
.
GRUCell
,
"gru"
:
tf
.
keras
.
layers
.
GRUCell
,
}
}
# Parameters for batch normalization.
# Parameters for batch normalization.
...
@@ -53,9 +53,8 @@ def batch_norm(inputs, training):
...
@@ -53,9 +53,8 @@ def batch_norm(inputs, training):
Returns:
Returns:
tensor output from batch norm layer.
tensor output from batch norm layer.
"""
"""
return
tf
.
layers
.
batch_normalization
(
return
tf
.
keras
.
layers
.
BatchNormalization
(
inputs
=
inputs
,
momentum
=
_BATCH_NORM_DECAY
,
epsilon
=
_BATCH_NORM_EPSILON
,
momentum
=
_BATCH_NORM_DECAY
,
epsilon
=
_BATCH_NORM_EPSILON
)(
inputs
,
training
=
training
)
fused
=
True
,
training
=
training
)
def
_conv_bn_layer
(
inputs
,
padding
,
filters
,
kernel_size
,
strides
,
layer_id
,
def
_conv_bn_layer
(
inputs
,
padding
,
filters
,
kernel_size
,
strides
,
layer_id
,
...
@@ -81,10 +80,10 @@ def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
...
@@ -81,10 +80,10 @@ def _conv_bn_layer(inputs, padding, filters, kernel_size, strides, layer_id,
inputs
=
tf
.
pad
(
inputs
=
tf
.
pad
(
inputs
,
inputs
,
[[
0
,
0
],
[
padding
[
0
],
padding
[
0
]],
[
padding
[
1
],
padding
[
1
]],
[
0
,
0
]])
[[
0
,
0
],
[
padding
[
0
],
padding
[
0
]],
[
padding
[
1
],
padding
[
1
]],
[
0
,
0
]])
inputs
=
tf
.
layers
.
c
onv2
d
(
inputs
=
tf
.
keras
.
layers
.
C
onv2
D
(
inputs
=
inputs
,
filters
=
filters
,
kernel_size
=
kernel_size
,
strides
=
strides
,
filters
=
filters
,
kernel_size
=
kernel_size
,
strides
=
strides
,
padding
=
"valid"
,
use_bias
=
False
,
activation
=
tf
.
nn
.
relu6
,
padding
=
"valid"
,
use_bias
=
False
,
activation
=
tf
.
nn
.
relu6
,
name
=
"cnn_{}"
.
format
(
layer_id
))
name
=
"cnn_{}"
.
format
(
layer_id
))
(
inputs
)
return
batch_norm
(
inputs
,
training
)
return
batch_norm
(
inputs
,
training
)
...
@@ -109,24 +108,16 @@ def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,
...
@@ -109,24 +108,16 @@ def _rnn_layer(inputs, rnn_cell, rnn_hidden_size, layer_id, is_batch_norm,
if
is_batch_norm
:
if
is_batch_norm
:
inputs
=
batch_norm
(
inputs
,
training
)
inputs
=
batch_norm
(
inputs
,
training
)
# Construct forward/backward RNN cells.
fw_cell
=
rnn_cell
(
num_units
=
rnn_hidden_size
,
name
=
"rnn_fw_{}"
.
format
(
layer_id
))
bw_cell
=
rnn_cell
(
num_units
=
rnn_hidden_size
,
name
=
"rnn_bw_{}"
.
format
(
layer_id
))
if
is_bidirectional
:
if
is_bidirectional
:
outputs
,
_
=
tf
.
nn
.
bidirectional_dynamic_rnn
(
rnn_outputs
=
tf
.
keras
.
layers
.
Bidirectional
(
cell_fw
=
fw_cell
,
cell_bw
=
bw_cell
,
inputs
=
inputs
,
dtype
=
tf
.
float32
,
tf
.
keras
.
layers
.
RNN
(
rnn_cell
(
rnn_hidden_size
),
swap_memory
=
True
)
return_sequences
=
True
))(
inputs
)
rnn_outputs
=
tf
.
concat
(
outputs
,
-
1
)
else
:
else
:
rnn_outputs
=
tf
.
nn
.
dynamic_rnn
(
rnn_outputs
=
tf
.
keras
.
layers
.
RNN
(
fw
_cell
,
inputs
,
dtype
=
tf
.
float32
,
swap_memory
=
True
)
rnn
_cell
(
rnn_hidden_size
),
return_sequences
=
True
)(
inputs
)
return
rnn_outputs
return
rnn_outputs
class
DeepSpeech2
(
object
):
class
DeepSpeech2
(
object
):
"""Define DeepSpeech2 model."""
"""Define DeepSpeech2 model."""
...
@@ -179,7 +170,8 @@ class DeepSpeech2(object):
...
@@ -179,7 +170,8 @@ class DeepSpeech2(object):
# FC layer with batch norm.
# FC layer with batch norm.
inputs
=
batch_norm
(
inputs
,
training
)
inputs
=
batch_norm
(
inputs
,
training
)
logits
=
tf
.
layers
.
dense
(
inputs
,
self
.
num_classes
,
use_bias
=
self
.
use_bias
)
logits
=
tf
.
keras
.
layers
.
Dense
(
self
.
num_classes
,
use_bias
=
self
.
use_bias
,
activation
=
"softmax"
)(
inputs
)
return
logits
return
logits
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment