OpenDAS / Torchaudio — commit ed175137

Increasing test coverage (ASR demo) (#248)

Authored Aug 21, 2019 by jamarshon; committed by cpuhrsch on Aug 21, 2019
Parent: 42a705d5

Showing 10 changed files with 328 additions and 187 deletions (+328 −187)
.gitignore                              +5    −0
.travis.yml                             +10   −5
build_tools/travis/install.sh           +24   −1
build_tools/travis/test_script.sh       +12   −0
examples/interactive_asr/README.md      +24   −9
examples/interactive_asr/__init__.py    +1    −0
examples/interactive_asr/asr.py         +9    −172
examples/interactive_asr/utils.py       +187  −0
examples/test/__init__.py               +0    −0
examples/test/test_interactive_asr.py   +56   −0
.gitignore

@@ -110,3 +110,8 @@ ENV/
 test/assets/sinewave.wav
 torchaudio/version.py
 gen.yml
+
+# Examples
+examples/interactive_asr/data/*.txt
+examples/interactive_asr/data/*.model
+examples/interactive_asr/data/*.pt
.travis.yml

@@ -8,15 +8,19 @@ cache:
   directories:
     - /home/travis/download
 
-# This matrix tests that the code works on Python 3.5, 3.6, and passes lint.
+# This matrix tests that the code works on Python 2.7, 3.5, 3.6, 3.7, passes
+# lint and example tests.
 matrix:
   fast_finish: true
   include:
-    - env: PYTHON_VERSION="3.7"
-    - env: PYTHON_VERSION="3.6"
-    # TODO add this back in when there is a pytorch 1.2 for python 3.5
-    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_TESTS="true"
     - env: PYTHON_VERSION="2.7"
+    - env: PYTHON_VERSION="3.5"
+    - env: PYTHON_VERSION="3.6"
+    - env: PYTHON_VERSION="3.7"
+    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_INSTALL="true" SKIP_TESTS="true"
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"
+  allow_failures:
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"
 
 addons:
   apt:

@@ -24,6 +28,7 @@ addons:
       sox
       libsox-dev
       libsox-fmt-all
+      portaudio19-dev
 
 notifications:
   email: false
build_tools/travis/install.sh

@@ -51,7 +51,30 @@ source activate testenv
 pip install -r requirements.txt
 
 # Install the following only if running tests
-if [[ "$SKIP_TESTS" != "true" ]]; then
+if [[ "$SKIP_INSTALL" != "true" ]]; then
     # TorchAudio CPP Extensions
     python setup.py install
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    # Install dependencies
+    pip install sentencepiece PyAudio
+
+    if [[ ! -d $HOME/download/fairseq ]]; then
+        # Install fairseq from source
+        git clone https://github.com/pytorch/fairseq $HOME/download/fairseq
+    fi
+    pushd $HOME/download/fairseq
+    pip install --editable .
+    popd
+
+    mkdir -p $HOME/download/data
+    # Install dictionary, sentence piece model, and model
+    # These are cached so they are not downloaded if they already exist
+    wget -nc -O $HOME/download/data/dict.txt https://download.pytorch.org/models/audio/dict.txt || true
+    wget -nc -O $HOME/download/data/spm.model https://download.pytorch.org/models/audio/spm.model || true
+    wget -nc -O $HOME/download/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt || true
+fi
+
+echo "Finished installation"
build_tools/travis/test_script.sh

@@ -32,5 +32,17 @@ if [[ "$RUN_FLAKE8" == "true" ]]; then
 fi
 
 if [[ "$SKIP_TESTS" != "true" ]]; then
+    echo "run_tests"
     run_tests
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    echo "run_example_tests"
+    pushd examples
+    ASR_MODEL_PATH=$HOME/download/data/model.pt \
+    ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+    ASR_DATA_PATH=$HOME/download/data \
+    ASR_USER_DIR=$HOME/download/fairseq/examples/speech_recognition \
+    python -m unittest test/test_interactive_asr.py
+    popd
+fi
examples/interactive_asr/README.md

@@ -16,6 +16,9 @@ and the following models
 We recommend that you use [conda](https://docs.conda.io/en/latest/miniconda.html) to install the dependencies when available.
 
 ```bash
+# Assume that all commands are from the examples folder
+cd examples
+
 # Install dependencies
 conda install -c pytorch torchaudio
 conda install -c conda-forge librosa
@@ -23,26 +26,38 @@ conda install pyaudio
 pip install sentencepiece
 
 # Install fairseq from source
-git clone https://github.com/pytorch/fairseq
-cd fairseq
+git clone https://github.com/pytorch/fairseq interactive_asr/fairseq
+pushd interactive_asr/fairseq
 export CFLAGS='-stdlib=libc++'  # For Mac only
 pip install --editable .
-cd ..
+popd
 
 # Install dictionary, sentence piece model, and model
-wget -O ./data/dict.txt https://download.pytorch.org/models/audio/dict.txt
-wget -O ./data/spm.model https://download.pytorch.org/models/audio/spm.model
-wget -O ./data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
+wget -O interactive_asr/data/dict.txt https://download.pytorch.org/models/audio/dict.txt
+wget -O interactive_asr/data/spm.model https://download.pytorch.org/models/audio/spm.model
+wget -O interactive_asr/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
 ```
 
 ## Run
 On a file
 ```bash
-INPUT_FILE=./data/sample.wav
-python asr.py ./data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+INPUT_FILE=interactive_asr/data/sample.wav
+python -m interactive_asr.asr interactive_asr/data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
 ```
 
 As a microphone
 ```bash
-python asr.py ./data --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+python -m interactive_asr.asr interactive_asr/data --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
 ```
+
+To run the testcase associated with this example
+```bash
+ASR_MODEL_PATH=interactive_asr/data/model.pt \
+ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+ASR_DATA_PATH=interactive_asr/data \
+ASR_USER_DIR=interactive_asr/fairseq/examples/speech_recognition \
+python -m unittest test/test_interactive_asr.py
+```
examples/interactive_asr/__init__.py  (new file, mode 100644)

from . import utils, vad
examples/interactive_asr/asr.py

@@ -11,187 +11,24 @@ Run inference for pre-processed data with a trained model.
 import datetime as dt
 import logging
-import os
-import sys
-import time
-
-import torch
-import torchaudio
-import sentencepiece as spm
-from fairseq import options, tasks, utils
-from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.utils import import_user_module
-
-from vad import get_microphone_chunks
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-def add_asr_eval_argument(parser):
-    ...
-
-
-def check_args(args):
-    ...
-
-
-def process_predictions(args, hypos, sp, tgt_dict):
-    ...
-
-
-def optimize_models(args, use_cuda, models):
-    ...
-
-
-def calc_mean_invstddev(feature):
-    ...
-
-
-def calcMN(features):
-    ...
-
-
-def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
-    ...
+
+from fairseq import options
+
+from interactive_asr.utils import add_asr_eval_argument, setup_asr, get_microphone_transcription, transcribe_file
 
 
 def main(args):
-    check_args(args)
-    import_user_module(args)
-    if args.max_tokens is None and args.max_sentences is None:
-        args.max_tokens = 30000
-    logger.info(args)
-
-    use_cuda = torch.cuda.is_available() and not args.cpu
-
-    # Load dataset splits
-    task = tasks.setup_task(args)
-
-    # Set dictionary
-    tgt_dict = task.target_dictionary
-    if args.ctc or args.rnnt:
-        tgt_dict.add_symbol("<ctc_blank>")
-        if args.ctc:
-            logger.info("| decoding a ctc model")
-        if args.rnnt:
-            logger.info("| decoding a rnnt model")
-
-    # Load ensemble
-    logger.info("| loading model(s) from {}".format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(":"),
-        task,
-        model_arg_overrides=eval(args.model_overrides),  # noqa
-    )
-    optimize_models(args, use_cuda, models)
-
-    # Initialize generator
-    generator = task.build_generator(args)
-
-    sp = spm.SentencePieceProcessor()
-    sp.Load(os.path.join(args.data, "spm.model"))
-
-    print("READY!")
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    task, generator, models, sp, tgt_dict = setup_asr(args, logger)
     if args.input_file:
-        path = args.input_file
-        if not os.path.exists(path):
-            raise FileNotFoundError("Audio file not found: {}".format(path))
-        waveform, sample_rate = torchaudio.load_wav(path)
-        waveform = waveform.mean(0, True)
-        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-        print(sample_rate, waveform.shape)
-        start = time.time()
-        transcription = transcribe(waveform, args, task, generator, models, sp, tgt_dict)
-        end = time.time()
+        transcription_time, transcription = transcribe_file(args, task, generator, models, sp, tgt_dict)
         print("transcription:", transcription)
-        print(end - start)
+        print("transcription_time:", transcription_time)
     else:
-        for (waveform, sample_rate) in get_microphone_chunks():
-            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(
-                waveform.reshape(1, -1)
-            )
-            transcription = transcribe(waveform, args, task, generator, models, sp, tgt_dict)
+        print("READY!")
+        for transcription in get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
             print(
                 "{}: {}".format(
                     dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0]
...
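The slimmed-down asr.py now only builds a logger, calls setup_asr once, and then either transcribes a file or iterates over the microphone generator. The rest of the file is collapsed in this hunk, so the following driver is only a hedged sketch of how the entry point could be wired up programmatically; the use of fairseq's generation parser is an assumption based on the retained `from fairseq import options` import:

```python
# Hypothetical driver (not shown in this hunk): build args with fairseq's
# generation parser plus the ASR-specific flags, then call main().
from fairseq import options
from interactive_asr.asr import main
from interactive_asr.utils import add_asr_eval_argument

parser = options.get_generation_parser()   # standard fairseq generation options
parser = add_asr_eval_argument(parser)     # adds --input_file, --ctc, --rnnt, ...
args = options.parse_args_and_arch(parser, input_args=[
    "interactive_asr/data",                # positional data directory
    "--input_file", "interactive_asr/data/sample.wav",
    "--path", "interactive_asr/data/model.pt",
    "--task", "speech_recognition",
    "--user-dir", "interactive_asr/fairseq/examples/speech_recognition",
    "--max-tokens", "10000000", "--nbest", "1", "--beam", "40",
])
main(args)
```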
examples/interactive_asr/utils.py  (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import sys
import time

import torch
import torchaudio
import sentencepiece as spm

from fairseq import tasks
from fairseq.utils import load_ensemble_for_inference, import_user_module

from interactive_asr.vad import get_microphone_chunks


def add_asr_eval_argument(parser):
    parser.add_argument("--input_file", help="input file")
    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictonary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="wfstlm on dictonary output units",
    )
    parser.add_argument(
        "--lm_weight",
        default=0.2,
        help="weight for wfstlm while interpolating with neural score",
    )
    parser.add_argument(
        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
    )
    return parser


def check_args(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def process_predictions(args, hypos, sp, tgt_dict):
    res = []
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        res.append(hyp_words)
    return res


def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation
    """
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


def calc_mean_invstddev(feature):
    if len(feature.shape) != 2:
        raise ValueError("We expect the input feature to be 2-D tensor")
    mean = torch.mean(feature, dim=0)
    var = torch.var(feature, dim=0)
    # avoid division by ~zero
    if (var < sys.float_info.epsilon).any():
        return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
    return mean, 1.0 / torch.sqrt(var)


def calcMN(features):
    mean, invstddev = calc_mean_invstddev(features)
    res = (features - mean) * invstddev
    return res


def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
    num_features = 80
    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
    output_cmvn = calcMN(output.cpu().detach())

    # size (m, n)
    source = output_cmvn
    frames_lengths = torch.LongTensor([source.size(0)])

    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
    source.unsqueeze_(0)
    sample = {"net_input": {"src_tokens": source, "src_lengths": frames_lengths}}

    hypos = task.inference_step(generator, models, sample)

    assert len(hypos) == 1
    transcription = []
    for i in range(len(hypos)):
        # Process top predictions
        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
        transcription.append(hyp_words)

    return transcription


def setup_asr(args, logger):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, use_cuda, models)

    # Initialize generator
    generator = task.build_generator(args)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    return task, generator, models, sp, tgt_dict


def transcribe_file(args, task, generator, models, sp, tgt_dict):
    path = args.input_file
    if not os.path.exists(path):
        raise FileNotFoundError("Audio file not found: {}".format(path))
    waveform, sample_rate = torchaudio.load_wav(path)
    waveform = waveform.mean(0, True)
    waveform = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=16000
    )(waveform)

    start = time.time()
    transcription = transcribe(
        waveform, args, task, generator, models, sp, tgt_dict
    )
    transcription_time = time.time() - start
    return transcription_time, transcription


def get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
    for (waveform, sample_rate) in get_microphone_chunks():
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform.reshape(1, -1))
        transcription = transcribe(
            waveform, args, task, generator, models, sp, tgt_dict
        )
        yield transcription
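Within this module, calc_mean_invstddev and calcMN together apply per-utterance cepstral mean and variance normalization to the fbank features, where dim 0 is the frame axis. A small illustrative check (not part of the commit) that the normalized output has roughly zero mean and unit variance per mel bin; the fake feature matrix below is hypothetical:

```python
# Illustrative check: calcMN normalizes each of the 80 mel bins across frames.
import torch
from interactive_asr.utils import calc_mean_invstddev, calcMN

features = torch.randn(300, 80) * 3.0 + 5.0   # fake (frames, mel bins) fbank output
mean, invstddev = calc_mean_invstddev(features)
normalized = calcMN(features)

print(normalized.mean(dim=0).abs().max())     # close to 0
print(normalized.var(dim=0).mean())           # close to 1
```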
examples/test/__init__.py  (new file, mode 100644, empty)
examples/test/test_interactive_asr.py  (new file, mode 100644)

import argparse
import logging
import os
import unittest

from interactive_asr.utils import setup_asr, transcribe_file


class ASRTest(unittest.TestCase):
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    arguments_dict = {
        'path': '/scratch/jamarshon/downloads/model.pt',
        'input_file': '/scratch/jamarshon/audio/examples/interactive_asr/data/sample.wav',
        'data': '/scratch/jamarshon/downloads',
        'user_dir': '/scratch/jamarshon/fairseq-py/examples/speech_recognition',
        'no_progress_bar': False, 'log_interval': 1000, 'log_format': None,
        'tensorboard_logdir': '', 'tbmf_wrapper': False, 'seed': 1,
        'cpu': True, 'fp16': False, 'memory_efficient_fp16': False,
        'fp16_init_scale': 128, 'fp16_scale_window': None,
        'fp16_scale_tolerance': 0.0, 'min_loss_scale': 0.0001,
        'threshold_loss_scale': None, 'criterion': 'cross_entropy',
        'tokenizer': None, 'bpe': None, 'optimizer': 'nag',
        'lr_scheduler': 'fixed', 'task': 'speech_recognition',
        'num_workers': 0, 'skip_invalid_size_inputs_valid_test': False,
        'max_tokens': 10000000, 'max_sentences': None,
        'required_batch_size_multiple': 8, 'dataset_impl': None,
        'gen_subset': 'test', 'num_shards': 1, 'shard_id': 0,
        'remove_bpe': None, 'quiet': False, 'model_overrides': '{}',
        'results_path': None, 'beam': 40, 'nbest': 1, 'max_len_a': 0,
        'max_len_b': 200, 'min_len': 1, 'match_source_len': False,
        'no_early_stop': False, 'unnormalized': False,
        'no_beamable_mm': False, 'lenpen': 1, 'unkpen': 0,
        'replace_unk': None, 'sacrebleu': False, 'score_reference': False,
        'prefix_size': 0, 'no_repeat_ngram_size': 0, 'sampling': False,
        'sampling_topk': -1, 'sampling_topp': -1.0, 'temperature': 1.0,
        'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5,
        'print_alignment': False, 'ctc': False, 'rnnt': False,
        'kspmodel': None, 'wfstlm': None, 'rnnt_decoding_type': 'greedy',
        'lm_weight': 0.2, 'rnnt_len_penalty': -0.5, 'momentum': 0.99,
        'weight_decay': 0.0, 'force_anneal': None, 'lr_shrink': 0.1,
        'warmup_updates': 0,
    }

    arguments_dict['path'] = os.environ.get('ASR_MODEL_PATH', None)
    arguments_dict['input_file'] = os.environ.get('ASR_INPUT_FILE', None)
    arguments_dict['data'] = os.environ.get('ASR_DATA_PATH', None)
    arguments_dict['user_dir'] = os.environ.get('ASR_USER_DIR', None)
    args = argparse.Namespace(**arguments_dict)

    def test_transcribe_file(self):
        task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
        _, transcription = transcribe_file(
            self.args, task, generator, models, sp, tgt_dict
        )

        expected_transcription = [['THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG']]
        self.assertEqual(transcription, expected_transcription, msg=str(transcription))


if __name__ == "__main__":
    unittest.main()
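The test reads its four paths from the ASR_* environment variables at class-definition time, i.e. when the module is imported. A hedged sketch (not part of the commit) of driving the same test from Python rather than the shell one-liner used in test_script.sh; it assumes the working directory is the examples folder and that the data files are already in place:

```python
# Hypothetical runner: the ASR_* variables must be set before the test module
# is imported, because ASRTest builds arguments_dict at class-definition time.
import os
import unittest

os.environ.setdefault("ASR_MODEL_PATH", "interactive_asr/data/model.pt")
os.environ.setdefault("ASR_INPUT_FILE", "interactive_asr/data/sample.wav")
os.environ.setdefault("ASR_DATA_PATH", "interactive_asr/data")
os.environ.setdefault("ASR_USER_DIR", "interactive_asr/fairseq/examples/speech_recognition")

# Discover and run examples/test/test_interactive_asr.py (run from the examples folder).
suite = unittest.defaultTestLoader.discover("test", pattern="test_interactive_asr.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```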