OpenDAS / Fairseq

Commit c394d7d1, authored Sep 28, 2024 by "change"
Commit message: init
Showing the first 20 of the 347 changed files, with 3217 additions and 0 deletions.
examples/speech_recognition/data/replabels.py                    (+70, -0)
examples/speech_recognition/datasets/asr_prep_json.py            (+125, -0)
examples/speech_recognition/datasets/prepare-librispeech.sh      (+88, -0)
examples/speech_recognition/infer.py                             (+427, -0)
examples/speech_recognition/kaldi/__init__.py                    (+0, -0)
examples/speech_recognition/kaldi/add-self-loop-simple.cc        (+95, -0)
examples/speech_recognition/kaldi/config/kaldi_initializer.yaml  (+8, -0)
examples/speech_recognition/kaldi/kaldi_decoder.py               (+244, -0)
examples/speech_recognition/kaldi/kaldi_initializer.py           (+698, -0)
examples/speech_recognition/models/__init__.py                   (+8, -0)
examples/speech_recognition/models/vggtransformer.py             (+1019, -0)
examples/speech_recognition/models/w2l_conv_glu_enc.py           (+177, -0)
examples/speech_recognition/new/README.md                        (+43, -0)
examples/speech_recognition/new/__init__.py                      (+0, -0)
examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml       (+26, -0)
examples/speech_recognition/new/conf/infer.yaml                  (+25, -0)
examples/speech_recognition/new/decoders/__init__.py             (+0, -0)
examples/speech_recognition/new/decoders/base_decoder.py         (+62, -0)
examples/speech_recognition/new/decoders/decoder.py              (+32, -0)
examples/speech_recognition/new/decoders/decoder_config.py       (+70, -0)
examples/speech_recognition/data/replabels.py (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Replabel transforms for use with flashlight's ASG criterion.
"""


def replabel_symbol(i):
    """
    Replabel symbols used in flashlight, currently just "1", "2", ...
    This prevents training with numeral tokens, so this might change in the future
    """
    return str(i)


def pack_replabels(tokens, dictionary, max_reps):
    """
    Pack a token sequence so that repeated symbols are replaced by replabels
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    replabel_value_to_idx = [0] * (max_reps + 1)
    for i in range(1, max_reps + 1):
        replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))

    result = []
    prev_token = -1
    num_reps = 0
    for token in tokens:
        if token == prev_token and num_reps < max_reps:
            num_reps += 1
        else:
            if num_reps > 0:
                result.append(replabel_value_to_idx[num_reps])
                num_reps = 0
            result.append(token)
            prev_token = token
    if num_reps > 0:
        result.append(replabel_value_to_idx[num_reps])
    return result


def unpack_replabels(tokens, dictionary, max_reps):
    """
    Unpack a token sequence so that replabels are replaced by repeated symbols
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    replabel_idx_to_value = {}
    for i in range(1, max_reps + 1):
        replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i

    result = []
    prev_token = -1
    for token in tokens:
        try:
            for _ in range(replabel_idx_to_value[token]):
                result.append(prev_token)
            prev_token = -1
        except KeyError:
            result.append(token)
            prev_token = token
    return result
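A minimal usage sketch of the pack/unpack round trip (not part of the commit). It assumes a fairseq Dictionary to which the replabel symbols have already been added; the symbol names and the max_reps value below are illustrative only.

from fairseq.data import Dictionary
from examples.speech_recognition.data.replabels import pack_replabels, unpack_replabels

d = Dictionary()
for i in range(1, 4):            # register replabels "1", "2", "3"
    d.add_symbol(str(i))
a_idx = d.add_symbol("a")        # a toy acoustic unit

tokens = [a_idx, a_idx, a_idx]                   # "a a a"
packed = pack_replabels(tokens, d, max_reps=3)   # "a" followed by replabel "2"
assert unpack_replabels(packed, d, max_reps=3) == tokens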
examples/speech_recognition/datasets/asr_prep_json.py (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import concurrent.futures
import json
import multiprocessing
import os
from collections import namedtuple
from itertools import chain

import sentencepiece as spm
from fairseq.data import Dictionary


MILLISECONDS_TO_SECONDS = 0.001


def process_sample(aud_path, label, utt_id, sp, tgt_dict):
    import torchaudio

    input = {}
    output = {}
    si, ei = torchaudio.info(aud_path)
    input["length_ms"] = int(
        si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
    )
    input["path"] = aud_path

    token = " ".join(sp.EncodeAsPieces(label))
    ids = tgt_dict.encode_line(token, append_eos=False)

    output["text"] = label
    output["token"] = token
    output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
    return {utt_id: {"input": input, "output": output}}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio-dirs",
        nargs="+",
        default=["-"],
        required=True,
        help="input directories with audio files",
    )
    parser.add_argument(
        "--labels",
        required=True,
        help="aggregated input labels with format <ID LABEL> per line",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--spm-model",
        required=True,
        help="sentencepiece model to use for encoding",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--dictionary",
        required=True,
        help="file to load fairseq dictionary from",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
    parser.add_argument(
        "--output",
        required=True,
        type=argparse.FileType("w"),
        help="path to save json output",
    )
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.spm_model.name)

    tgt_dict = Dictionary.load(args.dictionary)

    labels = {}
    for line in args.labels:
        (utt_id, label) = line.split(" ", 1)
        labels[utt_id] = label
    if len(labels) == 0:
        raise Exception("No labels found in ", args.labels.name)

    Sample = namedtuple("Sample", "aud_path utt_id")
    samples = []
    for path, _, files in chain.from_iterable(
        os.walk(path) for path in args.audio_dirs
    ):
        for f in files:
            if f.endswith(args.audio_format):
                if len(os.path.splitext(f)) != 2:
                    raise Exception("Expect <utt_id.extension> file name. Got: ", f)
                utt_id = os.path.splitext(f)[0]
                if utt_id not in labels:
                    continue
                samples.append(Sample(os.path.join(path, f), utt_id))

    utts = {}
    num_cpu = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
        future_to_sample = {
            executor.submit(
                process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
            ): s
            for s in samples
        }
        for future in concurrent.futures.as_completed(future_to_sample):
            try:
                data = future.result()
            except Exception as exc:
                print("generated an exception: ", exc)
            else:
                utts.update(data)
    json.dump({"utts": utts}, args.output, indent=4)


if __name__ == "__main__":
    main()
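For orientation, a hedged illustration (not part of the commit) of the JSON layout this script emits. The utterance id, path, text, and token ids below are made up; only the keys mirror process_sample() and main() above.

example = {
    "utts": {
        "1272-128104-0000": {
            "input": {
                "length_ms": 5855,
                "path": "/data/LibriSpeech/dev-clean/1272/128104/1272-128104-0000.flac",
            },
            "output": {
                "text": "MISTER QUILTER IS THE APOSTLE ...",
                "token": "▁MISTER ▁QUILTER ▁IS ▁THE ▁APOSTLE ...",
                "tokenid": "103, 2441, 14, 4, 7570",
            },
        }
    }
}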
examples/speech_recognition/datasets/prepare-librispeech.sh (new file, mode 100644)

#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Prepare librispeech dataset

base_url=www.openslr.org/resources/12
train_dir=train_960

if [ "$#" -ne 2 ]; then
  echo "Usage: $0 <download_dir> <out_dir>"
  echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
  exit 1
fi

download_dir=${1%/}
out_dir=${2%/}

fairseq_root=~/fairseq-py/

mkdir -p ${out_dir}
cd ${out_dir} || exit

nbpe=5000
bpemode=unigram

if [ ! -d "$fairseq_root" ]; then
  echo "$0: Please set correct fairseq_root"
  exit 1
fi

echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
  url=$base_url/$part.tar.gz
  if ! wget -P $download_dir $url; then
    echo "$0: wget failed for $url"
    exit 1
  fi
  if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
    echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
    exit 1
  fi
done

echo "Merge all train packs into one"
mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
for part in train-clean-100 train-clean-360 train-other-500; do
  mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
done
echo "Merge train text"
find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text

# Use combined dev-clean and dev-other as validation set
find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text

dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
echo "Dictionary preparation"
mkdir -p data/lang_char/
echo "<unk> 3" > ${dict}
echo "</s> 2" >> ${dict}
echo "<pad> 1" >> ${dict}
cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
wc -l ${dict}

echo "Prepare train and test jsons"
for part in train_960 test-other test-clean; do
  python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
done
# fairseq expects to find train.json and valid.json during training
mv train_960.json train.json

echo "Prepare valid json"
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json

cp ${fairseq_dict} ./dict.txt
cp ${bpemodel}.model ./spm.model
examples/speech_recognition/infer.py (new file, mode 100644)

#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import ast
import logging
import math
import os
import sys

import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter


logging.basicConfig()
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def add_asr_eval_argument(parser):
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="wfstlm on dictionary output units",
    )
    try:
        parser.add_argument(
            "--lm-weight",
            "--lm_weight",
            type=float,
            default=0.2,
            help="weight for lm while interpolating with neural score",
        )
    except:
        pass
    parser.add_argument(
        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
    )
    parser.add_argument(
        "--w2l-decoder",
        choices=["viterbi", "kenlm", "fairseqlm"],
        help="use a w2l decoder",
    )
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
    parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
    parser.add_argument("--beam-threshold", type=float, default=25.0)
    parser.add_argument("--beam-size-token", type=float, default=100)
    parser.add_argument("--word-score", type=float, default=1.0)
    parser.add_argument("--unk-weight", type=float, default=-math.inf)
    parser.add_argument("--sil-weight", type=float, default=0.0)
    parser.add_argument(
        "--dump-emissions",
        type=str,
        default=None,
        help="if present, dumps emissions into this file and exits",
    )
    parser.add_argument(
        "--dump-features",
        type=str,
        default=None,
        help="if present, dumps features into this file and exits",
    )
    parser.add_argument(
        "--load-emissions",
        type=str,
        default=None,
        help="if present, loads emissions from this file",
    )
    return parser


def check_args(args):
    # assert args.path is not None, "--path required for generation!"
    # assert args.results_path is not None, "--results_path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task, models):
    return task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)


def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(hyp_pieces, speaker, id),
                file=res_files["hypo.units"],
            )
            print(
                "{} ({}-{})".format(hyp_words, speaker, id),
                file=res_files["hypo.words"],
            )

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(tgt_pieces, speaker, id),
                file=res_files["ref.units"],
            )
            print(
                "{} ({}-{})".format(tgt_words, speaker, id),
                file=res_files["ref.words"],
            )

        if not args.quiet:
            logger.info("HYPO:" + hyp_words)
            logger.info("TARGET:" + tgt_words)
            logger.info("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)


def prepare_result_files(args):
    def get_res_file(file_prefix):
        if args.num_shards > 1:
            file_prefix = f"{args.shard_id}_{file_prefix}"
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        return open(path, "w", buffering=1)

    if not args.results_path:
        return None

    return {
        "hypo.words": get_res_file("hypo.word"),
        "hypo.units": get_res_file("hypo.units"),
        "ref.words": get_res_file("ref.word"),
        "ref.units": get_res_file("ref.units"),
    }


def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


class ExistingEmissionsDecoder(object):
    def __init__(self, decoder, emissions):
        self.decoder = decoder
        self.emissions = emissions

    def generate(self, models, sample, **unused):
        ids = sample["id"].cpu().numpy()
        try:
            emissions = np.stack(self.emissions[ids])
        except:
            print([x.shape for x in self.emissions[ids]])
            raise Exception("invalid sizes")
        emissions = torch.from_numpy(emissions)
        return self.decoder.decode(emissions)


def main(args, task=None, model_state=None):
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(args.path, separator="\\"),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = (
        utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
    )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = (
                            encoder_out["encoder_padding_mask"][i].cpu().numpy()
                            if encoder_out["encoder_padding_mask"] is not None
                            else None
                        )
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = (
                    sample["target"][i, :]
                    if "target_label" not in sample
                    else sample["target_label"][i, :]
                )
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                num_sentences / gen_timer.sum,
                1.0 / gen_timer.avg,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer


def make_parser():
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    return parser


def cli_main():
    parser = make_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
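A hedged sketch (not part of the commit) of driving infer.py from Python rather than from the shell. The manifest directory, checkpoint path, subset name, and decoder flags below are placeholders and assume a wav2vec-style CTC setup; the flashlight bindings are needed for the viterbi decoder.

from examples.speech_recognition.infer import make_parser, main
from fairseq import options

parser = make_parser()
argv = [
    "/path/to/manifest_dir",            # data directory (placeholder)
    "--task", "audio_pretraining",
    "--path", "/path/to/checkpoint.pt",
    "--gen-subset", "dev_other",
    "--criterion", "ctc",
    "--w2l-decoder", "viterbi",
    "--post-process", "letter",
    "--results-path", "/tmp/decode_results",
]
args = options.parse_args_and_arch(parser, input_args=argv)
task, wer = main(args)                  # main() returns the task and the measured WER
print(wer)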
examples/speech_recognition/kaldi/__init__.py (new empty file, mode 100644)
examples/speech_recognition/kaldi/add-self-loop-simple.cc (new file, mode 100644)

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <iostream>

#include "fstext/fstext-lib.h" // @manual
#include "util/common-utils.h" // @manual

/*
 * This program is to modify a FST without self-loop by:
 * for each incoming arc with non-eps input symbol, add a self-loop arc
 * with that non-eps symbol as input and eps as output.
 *
 * This is to make sure the resultant FST can do deduplication for repeated
 * symbols, which is very common in acoustic model
 *
 */
namespace {
int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) {
  typedef fst::MutableArcIterator<fst::StdVectorFst> IterType;

  int32 num_states_before = fst->NumStates();
  fst::MakePrecedingInputSymbolsSame(false, fst);
  int32 num_states_after = fst->NumStates();
  KALDI_LOG << "There are " << num_states_before
            << " states in the original FST; "
            << " after MakePrecedingInputSymbolsSame, there are "
            << num_states_after << " states " << std::endl;

  auto weight_one = fst::StdArc::Weight::One();

  int32 num_arc_added = 0;

  fst::StdArc self_loop_arc;
  self_loop_arc.weight = weight_one;

  int32 num_states = fst->NumStates();
  std::vector<std::set<int32>> incoming_non_eps_label_per_state(num_states);

  for (int32 state = 0; state < num_states; state++) {
    for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) {
      fst::StdArc arc(aiter.Value());
      if (arc.ilabel != 0) {
        incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel);
      }
    }
  }

  for (int32 state = 0; state < num_states; state++) {
    if (!incoming_non_eps_label_per_state[state].empty()) {
      auto& ilabel_set = incoming_non_eps_label_per_state[state];
      for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) {
        self_loop_arc.ilabel = *it;
        self_loop_arc.olabel = 0;
        self_loop_arc.nextstate = state;
        fst->AddArc(state, self_loop_arc);
        num_arc_added++;
      }
    }
  }
  return num_arc_added;
}

void print_usage() {
  std::cout << "add-self-loop-simple usage:\n"
               "\tadd-self-loop-simple <in-fst> <out-fst>\n";
}
} // namespace

int main(int argc, char** argv) {
  if (argc != 3) {
    print_usage();
    exit(1);
  }

  auto input = argv[1];
  auto output = argv[2];

  auto fst = fst::ReadFstKaldi(input);
  auto num_states = fst->NumStates();
  KALDI_LOG << "Loading FST from " << input << " with " << num_states
            << " states." << std::endl;

  int32 num_arc_added = AddSelfLoopsSimple(fst);
  KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl;

  fst::WriteFstKaldi(*fst, std::string(output));
  KALDI_LOG << "Writing FST to " << output << std::endl;

  delete fst;
}
examples/speech_recognition/kaldi/config/kaldi_initializer.yaml (new file, mode 100644)

# @package _group_

data_dir: ???
fst_dir: ???
in_labels: ???
kaldi_root: ???
lm_arpa: ???
blank_symbol: <s>
examples/speech_recognition/kaldi/kaldi_decoder.py (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from concurrent.futures import ThreadPoolExecutor
import logging
from omegaconf import MISSING
import os
import torch
from typing import Optional
import warnings

from dataclasses import dataclass
from fairseq.dataclass import FairseqDataclass

from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi


logger = logging.getLogger(__name__)


@dataclass
class KaldiDecoderConfig(FairseqDataclass):
    hlg_graph_path: Optional[str] = None
    output_dict: str = MISSING

    kaldi_initializer_config: Optional[KaldiInitializerConfig] = None

    acoustic_scale: float = 0.5
    max_active: int = 10000
    beam_delta: float = 0.5
    hash_ratio: float = 2.0

    is_lattice: bool = False
    lattice_beam: float = 10.0
    prune_interval: int = 25
    determinize_lattice: bool = True
    prune_scale: float = 0.1
    max_mem: int = 0
    phone_determinize: bool = True
    word_determinize: bool = True
    minimize: bool = True

    num_threads: int = 1


class KaldiDecoder(object):
    def __init__(
        self,
        cfg: KaldiDecoderConfig,
        beam: int,
        nbest: int = 1,
    ):
        try:
            from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
            from kaldi.base import set_verbose_level
            from kaldi.decoder import (
                FasterDecoder,
                FasterDecoderOptions,
                LatticeFasterDecoder,
                LatticeFasterDecoderOptions,
            )
            from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
            from kaldi.fstext import read_fst_kaldi, SymbolTable
        except:
            warnings.warn(
                "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi"
            )

        # set_verbose_level(2)

        self.acoustic_scale = cfg.acoustic_scale
        self.nbest = nbest

        if cfg.hlg_graph_path is None:
            assert (
                cfg.kaldi_initializer_config is not None
            ), "Must provide hlg graph path or kaldi initializer config"
            cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)

        assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path

        if cfg.is_lattice:
            self.dec_cls = LatticeFasterDecoder
            opt_cls = LatticeFasterDecoderOptions
            self.rec_cls = LatticeFasterRecognizer
        else:
            assert self.nbest == 1, "nbest > 1 requires lattice decoder"
            self.dec_cls = FasterDecoder
            opt_cls = FasterDecoderOptions
            self.rec_cls = FasterRecognizer

        self.decoder_options = opt_cls()
        self.decoder_options.beam = beam
        self.decoder_options.max_active = cfg.max_active
        self.decoder_options.beam_delta = cfg.beam_delta
        self.decoder_options.hash_ratio = cfg.hash_ratio

        if cfg.is_lattice:
            self.decoder_options.lattice_beam = cfg.lattice_beam
            self.decoder_options.prune_interval = cfg.prune_interval
            self.decoder_options.determinize_lattice = cfg.determinize_lattice
            self.decoder_options.prune_scale = cfg.prune_scale
            det_opts = DeterminizeLatticePhonePrunedOptions()
            det_opts.max_mem = cfg.max_mem
            det_opts.phone_determinize = cfg.phone_determinize
            det_opts.word_determinize = cfg.word_determinize
            det_opts.minimize = cfg.minimize
            self.decoder_options.det_opts = det_opts

        self.output_symbols = {}
        with open(cfg.output_dict, "r") as f:
            for line in f:
                items = line.rstrip().split()
                assert len(items) == 2
                self.output_symbols[int(items[1])] = items[0]

        logger.info(f"Loading FST from {cfg.hlg_graph_path}")

        self.fst = read_fst_kaldi(cfg.hlg_graph_path)
        self.symbol_table = SymbolTable.read_text(cfg.output_dict)

        self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions, padding = self.get_emissions(models, encoder_input)
        return self.decode(emissions, padding)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        model = models[0]

        all_encoder_out = [m(**encoder_input) for m in models]

        if len(all_encoder_out) > 1:

            if "encoder_out" in all_encoder_out[0]:
                encoder_out = {
                    "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
                }
                padding = encoder_out["encoder_padding_mask"]
            else:
                encoder_out = {
                    "logits": sum(e["logits"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "padding_mask": all_encoder_out[0]["padding_mask"],
                }
                padding = encoder_out["padding_mask"]
        else:
            encoder_out = all_encoder_out[0]
            padding = (
                encoder_out["padding_mask"]
                if "padding_mask" in encoder_out
                else encoder_out["encoder_padding_mask"]
            )

        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out, normalize=True)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        return (
            emissions.cpu().float().transpose(0, 1),
            padding.cpu() if padding is not None and padding.any() else None,
        )

    def decode_one(self, logits, padding):
        from kaldi.matrix import Matrix

        decoder = self.dec_cls(self.fst, self.decoder_options)
        asr = self.rec_cls(
            decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
        )

        if padding is not None:
            logits = logits[~padding]

        mat = Matrix(logits.numpy())

        out = asr.decode(mat)

        if self.nbest > 1:
            from kaldi.fstext import shortestpath
            from kaldi.fstext.utils import (
                convert_compact_lattice_to_lattice,
                convert_lattice_to_std,
                convert_nbest_to_list,
                get_linear_symbol_sequence,
            )

            lat = out["lattice"]

            sp = shortestpath(lat, nshortest=self.nbest)

            sp = convert_compact_lattice_to_lattice(sp)
            sp = convert_lattice_to_std(sp)
            seq = convert_nbest_to_list(sp)

            results = []
            for s in seq:
                _, o, w = get_linear_symbol_sequence(s)
                words = list(self.output_symbols[z] for z in o)
                results.append(
                    {
                        "tokens": words,
                        "words": words,
                        "score": w.value,
                        "emissions": logits,
                    }
                )
            return results
        else:
            words = out["text"].split()
            return [
                {
                    "tokens": words,
                    "words": words,
                    "score": out["likelihood"],
                    "emissions": logits,
                }
            ]

    def decode(self, emissions, padding):
        if padding is None:
            padding = [None] * len(emissions)

        ret = list(
            map(
                lambda e, p: self.executor.submit(self.decode_one, e, p),
                emissions,
                padding,
            )
        )
        return ret
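A hedged construction sketch (not part of the commit): instantiating the decoder with a prebuilt HLG graph. The paths below are placeholders, pykaldi must be installed, and the models/sample passed to generate() are assumed to come from a fairseq task as in infer.py.

from examples.speech_recognition.kaldi.kaldi_decoder import (
    KaldiDecoder,
    KaldiDecoderConfig,
)

cfg = KaldiDecoderConfig(
    hlg_graph_path="/path/to/HLG.ltr.lm.fst",    # placeholder
    output_dict="/path/to/kaldi_dict.lm.txt",    # placeholder
    acoustic_scale=1.0,
    is_lattice=True,
)
decoder = KaldiDecoder(cfg, beam=15, nbest=1)
# Each call to decoder.generate(models, sample) returns one future per utterance;
# call .result() on each future to obtain its list of hypotheses.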
examples/speech_recognition/kaldi/kaldi_initializer.py (new file, mode 100644)

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import hydra
from hydra.core.config_store import ConfigStore
import logging
from omegaconf import MISSING, OmegaConf
import os
import os.path as osp
from pathlib import Path
import subprocess
from typing import Optional

from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import FairseqDataclass

script_dir = Path(__file__).resolve().parent
config_path = script_dir / "config"


logger = logging.getLogger(__name__)


@dataclass
class KaldiInitializerConfig(FairseqDataclass):
    data_dir: str = MISSING
    fst_dir: Optional[str] = None
    in_labels: str = MISSING
    out_labels: Optional[str] = None
    wav2letter_lexicon: Optional[str] = None
    lm_arpa: str = MISSING
    kaldi_root: str = MISSING
    blank_symbol: str = "<s>"
    silence_symbol: Optional[str] = None


def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path:
    in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt"
    if not in_units_file.exists():

        logger.info(f"Creating {in_units_file}")

        with open(in_units_file, "w") as f:
            print("<eps> 0", file=f)
            i = 1
            for symb in vocab.symbols[vocab.nspecial :]:
                if not symb.startswith("madeupword"):
                    print(f"{symb} {i}", file=f)
                    i += 1
    return in_units_file


def create_lexicon(
    cfg: KaldiInitializerConfig,
    fst_dir: Path,
    unique_label: str,
    in_units_file: Path,
    out_words_file: Path,
) -> (Path, Path):

    disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt"
    lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt"
    disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt"
    if (
        not lexicon_file.exists()
        or not disambig_lexicon_file.exists()
        or not disambig_in_units_file.exists()
    ):
        logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})")

        assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels

        if cfg.wav2letter_lexicon is not None:
            lm_words = set()
            with open(out_words_file, "r") as lm_dict_f:
                for line in lm_dict_f:
                    lm_words.add(line.split()[0])

            num_skipped = 0
            total = 0
            with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open(
                lexicon_file, "w"
            ) as out_f:
                for line in w2l_lex_f:
                    items = line.rstrip().split("\t")
                    assert len(items) == 2, items
                    if items[0] in lm_words:
                        print(items[0], items[1], file=out_f)
                    else:
                        num_skipped += 1
                        logger.debug(
                            f"Skipping word {items[0]} as it was not found in LM"
                        )
                    total += 1
            if num_skipped > 0:
                logger.warning(
                    f"Skipped {num_skipped} out of {total} words as they were not found in LM"
                )
        else:
            with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f:
                for line in in_f:
                    symb = line.split()[0]
                    if symb != "<eps>" and symb != "<ctc_blank>" and symb != "<SIL>":
                        print(symb, symb, file=out_f)

        lex_disambig_path = (
            Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_lex_disambig.pl"
        )
        res = subprocess.run(
            [lex_disambig_path, lexicon_file, disambig_lexicon_file],
            check=True,
            capture_output=True,
        )
        ndisambig = int(res.stdout)
        disamib_path = Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_disambig.pl"
        res = subprocess.run(
            [disamib_path, "--include-zero", in_units_file, str(ndisambig)],
            check=True,
            capture_output=True,
        )
        with open(disambig_in_units_file, "wb") as f:
            f.write(res.stdout)

    return disambig_lexicon_file, disambig_in_units_file


def create_G(
    kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str
) -> (Path, Path):

    out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt"
    grammar_graph = fst_dir / f"G_{arpa_base}.fst"
    if not grammar_graph.exists() or not out_words_file.exists():
        logger.info(f"Creating {grammar_graph}")
        arpa2fst = kaldi_root / "src/lmbin/arpa2fst"
        subprocess.run(
            [
                arpa2fst,
                "--disambig-symbol=#0",
                f"--write-symbol-table={out_words_file}",
                lm_arpa,
                grammar_graph,
            ],
            check=True,
        )
    return grammar_graph, out_words_file


def create_L(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    lexicon_file: Path,
    in_units_file: Path,
    out_words_file: Path,
) -> Path:
    lexicon_graph = fst_dir / f"L.{unique_label}.fst"

    if not lexicon_graph.exists():
        logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})")
        make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl"
        fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
        fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
        fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

        def write_disambig_symbol(file):
            with open(file, "r") as f:
                for line in f:
                    items = line.rstrip().split()
                    if items[0] == "#0":
                        out_path = str(file) + "_disamig"
                        with open(out_path, "w") as out_f:
                            print(items[1], file=out_f)
                            return out_path

            return None

        in_disambig_sym = write_disambig_symbol(in_units_file)
        assert in_disambig_sym is not None
        out_disambig_sym = write_disambig_symbol(out_words_file)
        assert out_disambig_sym is not None

        try:
            with open(lexicon_graph, "wb") as out_f:
                res = subprocess.run(
                    [make_lex, lexicon_file], capture_output=True, check=True
                )
                assert len(res.stderr) == 0, res.stderr.decode("utf-8")
                res = subprocess.run(
                    [
                        fstcompile,
                        f"--isymbols={in_units_file}",
                        f"--osymbols={out_words_file}",
                        "--keep_isymbols=false",
                        "--keep_osymbols=false",
                    ],
                    input=res.stdout,
                    capture_output=True,
                )
                assert len(res.stderr) == 0, res.stderr.decode("utf-8")
                res = subprocess.run(
                    [fstaddselfloops, in_disambig_sym, out_disambig_sym],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstarcsort, "--sort_type=olabel"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(lexicon_graph)
            raise
        except AssertionError:
            os.remove(lexicon_graph)
            raise

    return lexicon_graph


def create_LG(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    lexicon_graph: Path,
    grammar_graph: Path,
) -> Path:
    lg_graph = fst_dir / f"LG.{unique_label}.fst"

    if not lg_graph.exists():
        logger.info(f"Creating {lg_graph}")

        fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
        fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
        fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
        fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial"
        fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

        try:
            with open(lg_graph, "wb") as out_f:
                res = subprocess.run(
                    [fsttablecompose, lexicon_graph, grammar_graph],
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [
                        fstdeterminizestar,
                        "--use-log=true",
                    ],
                    input=res.stdout,
                    capture_output=True,
                )
                res = subprocess.run(
                    [fstminimizeencoded],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstpushspecial],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstarcsort, "--sort_type=ilabel"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(lg_graph)
            raise

    return lg_graph


def create_H(
    kaldi_root: Path,
    fst_dir: Path,
    disambig_out_units_file: Path,
    in_labels: str,
    vocab: Dictionary,
    blk_sym: str,
    silence_symbol: Optional[str],
) -> (Path, Path, Path):
    h_graph = (
        fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst"
    )
    h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt"
    disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int")
    disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int")
    if (
        not h_graph.exists()
        or not h_out_units_file.exists()
        or not disambig_in_units_file_int.exists()
    ):
        logger.info(f"Creating {h_graph}")
        eps_sym = "<eps>"

        num_disambig = 0
        osymbols = []

        with open(disambig_out_units_file, "r") as f, open(
            disambig_out_units_file_int, "w"
        ) as out_f:
            for line in f:
                symb, id = line.rstrip().split()
                if line.startswith("#"):
                    num_disambig += 1
                    print(id, file=out_f)
                else:
                    if len(osymbols) == 0:
                        assert symb == eps_sym, symb
                    osymbols.append((symb, id))

        i_idx = 0
        isymbols = [(eps_sym, 0)]

        imap = {}

        for i, s in enumerate(vocab.symbols):
            i_idx += 1
            isymbols.append((s, i_idx))
            imap[s] = i_idx

        fst_str = []

        node_idx = 0
        root_node = node_idx

        special_symbols = [blk_sym]
        if silence_symbol is not None:
            special_symbols.append(silence_symbol)

        for ss in special_symbols:
            fst_str.append("{} {} {} {}".format(root_node, root_node, ss, eps_sym))

        for symbol, _ in osymbols:
            if symbol == eps_sym or symbol.startswith("#"):
                continue

            node_idx += 1
            # 1. from root to emitting state
            fst_str.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol))
            # 2. from emitting state back to root
            fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))
            # 3. from emitting state to optional blank state
            pre_node = node_idx
            node_idx += 1
            for ss in special_symbols:
                fst_str.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym))
            # 4. from blank state back to root
            fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))

        fst_str.append("{}".format(root_node))

        fst_str = "\n".join(fst_str)
        h_str = str(h_graph)
        isym_file = h_str + ".isym"

        with open(isym_file, "w") as f:
            for sym, id in isymbols:
                f.write("{} {}\n".format(sym, id))

        with open(h_out_units_file, "w") as f:
            for sym, id in osymbols:
                f.write("{} {}\n".format(sym, id))

        with open(disambig_in_units_file_int, "w") as f:
            disam_sym_id = len(isymbols)
            for _ in range(num_disambig):
                f.write("{}\n".format(disam_sym_id))
                disam_sym_id += 1

        fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
        fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
        fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

        try:
            with open(h_graph, "wb") as out_f:
                res = subprocess.run(
                    [
                        fstcompile,
                        f"--isymbols={isym_file}",
                        f"--osymbols={h_out_units_file}",
                        "--keep_isymbols=false",
                        "--keep_osymbols=false",
                    ],
                    input=str.encode(fst_str),
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [
                        fstaddselfloops,
                        disambig_in_units_file_int,
                        disambig_out_units_file_int,
                    ],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstarcsort, "--sort_type=olabel"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(h_graph)
            raise

    return h_graph, h_out_units_file, disambig_in_units_file_int


def create_HLGa(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    h_graph: Path,
    lg_graph: Path,
    disambig_in_words_file_int: Path,
) -> Path:
    hlga_graph = fst_dir / f"HLGa.{unique_label}.fst"

    if not hlga_graph.exists():
        logger.info(f"Creating {hlga_graph}")

        fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
        fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
        fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
        fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
        fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"

        try:
            with open(hlga_graph, "wb") as out_f:
                res = subprocess.run(
                    [
                        fsttablecompose,
                        h_graph,
                        lg_graph,
                    ],
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstdeterminizestar, "--use-log=true"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmsymbols, disambig_in_words_file_int],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmepslocal],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstminimizeencoded],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(hlga_graph)
            raise

    return hlga_graph


def create_HLa(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    h_graph: Path,
    l_graph: Path,
    disambig_in_words_file_int: Path,
) -> Path:
    hla_graph = fst_dir / f"HLa.{unique_label}.fst"

    if not hla_graph.exists():
        logger.info(f"Creating {hla_graph}")

        fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
        fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
        fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
        fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
        fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"

        try:
            with open(hla_graph, "wb") as out_f:
                res = subprocess.run(
                    [
                        fsttablecompose,
                        h_graph,
                        l_graph,
                    ],
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstdeterminizestar, "--use-log=true"],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmsymbols, disambig_in_words_file_int],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstrmepslocal],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                res = subprocess.run(
                    [fstminimizeencoded],
                    input=res.stdout,
                    capture_output=True,
                    check=True,
                )
                out_f.write(res.stdout)
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            os.remove(hla_graph)
            raise

    return hla_graph


def create_HLG(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    hlga_graph: Path,
    prefix: str = "HLG",
) -> Path:
    hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst"

    if not hlg_graph.exists():
        logger.info(f"Creating {hlg_graph}")

        add_self_loop = script_dir / "add-self-loop-simple"
        kaldi_src = kaldi_root / "src"
        kaldi_lib = kaldi_src / "lib"

        try:
            if not add_self_loop.exists():
                fst_include = kaldi_root / "tools/openfst-1.6.7/include"
                add_self_loop_src = script_dir / "add-self-loop-simple.cc"

                subprocess.run(
                    [
                        "c++",
                        f"-I{kaldi_src}",
                        f"-I{fst_include}",
                        f"-L{kaldi_lib}",
                        add_self_loop_src,
                        "-lkaldi-base",
                        "-lkaldi-fstext",
                        "-o",
                        add_self_loop,
                    ],
                    check=True,
                )

            my_env = os.environ.copy()
            my_env["LD_LIBRARY_PATH"] = f"{kaldi_lib}:{my_env['LD_LIBRARY_PATH']}"

            subprocess.run(
                [
                    add_self_loop,
                    hlga_graph,
                    hlg_graph,
                ],
                check=True,
                capture_output=True,
                env=my_env,
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
            raise

    return hlg_graph


def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path:
    if cfg.fst_dir is None:
        cfg.fst_dir = osp.join(cfg.data_dir, "kaldi")
    if cfg.out_labels is None:
        cfg.out_labels = cfg.in_labels

    kaldi_root = Path(cfg.kaldi_root)
    data_dir = Path(cfg.data_dir)
    fst_dir = Path(cfg.fst_dir)
    fst_dir.mkdir(parents=True, exist_ok=True)

    arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0]
    unique_label = f"{cfg.in_labels}.{arpa_base}"

    with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f:
        vocab = Dictionary.load(f)

    in_units_file = create_units(fst_dir, cfg.in_labels, vocab)

    grammar_graph, out_words_file = create_G(
        kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base
    )

    disambig_lexicon_file, disambig_L_in_units_file = create_lexicon(
        cfg, fst_dir, unique_label, in_units_file, out_words_file
    )

    h_graph, h_out_units_file, disambig_in_units_file_int = create_H(
        kaldi_root,
        fst_dir,
        disambig_L_in_units_file,
        cfg.in_labels,
        vocab,
        cfg.blank_symbol,
        cfg.silence_symbol,
    )
    lexicon_graph = create_L(
        kaldi_root,
        fst_dir,
        unique_label,
        disambig_lexicon_file,
        disambig_L_in_units_file,
        out_words_file,
    )
    lg_graph = create_LG(
        kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph
    )
    hlga_graph = create_HLGa(
        kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int
    )
    hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph)

    # for debugging
    # hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int)
    # hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped")
    # create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped")

    return hlg_graph


@hydra.main(config_path=config_path, config_name="kaldi_initializer")
def cli_main(cfg: KaldiInitializerConfig) -> None:
    container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
    cfg = OmegaConf.create(container)
    OmegaConf.set_struct(cfg, True)
    initalize_kaldi(cfg)


if __name__ == "__main__":

    logging.root.setLevel(logging.INFO)
    logging.basicConfig(level=logging.INFO)

    try:
        from hydra._internal.utils import (
            get_args,
        )  # pylint: disable=import-outside-toplevel

        cfg_name = get_args().config_name or "kaldi_initializer"
    except ImportError:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "kaldi_initializer"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=KaldiInitializerConfig)

    cli_main()
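A hedged sketch (not part of the commit): calling the initializer programmatically instead of through the Hydra entry point. All paths and the label name are placeholders, and data_dir is expected to contain dict.<in_labels>.txt as read by initalize_kaldi() above.

from examples.speech_recognition.kaldi.kaldi_initializer import (
    KaldiInitializerConfig,
    initalize_kaldi,
)

cfg = KaldiInitializerConfig(
    data_dir="/path/to/manifest_dir",   # placeholder; must contain dict.ltr.txt
    in_labels="ltr",                    # placeholder label name
    lm_arpa="/path/to/lm.arpa",         # placeholder
    kaldi_root="/path/to/kaldi",        # placeholder
)
hlg_path = initalize_kaldi(cfg)         # builds G, L, H, LG, HLGa and returns the HLG fst path
print(hlg_path)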
examples/speech_recognition/models/__init__.py (new file, mode 100644)

import importlib
import os


# Automatically import every model module in this package so that its
# register_model decorators run.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        model_name = file[: file.find(".py")]
        importlib.import_module("examples.speech_recognition.models." + model_name)
examples/speech_recognition/models/vggtransformer.py
0 → 100644
View file @
c394d7d1
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
argparse
import
math
from
collections.abc
import
Iterable
import
torch
import
torch.nn
as
nn
from
examples.speech_recognition.data.data_utils
import
lengths_to_encoder_padding_mask
from
fairseq
import
utils
from
fairseq.models
import
(
FairseqEncoder
,
FairseqEncoderDecoderModel
,
FairseqEncoderModel
,
FairseqIncrementalDecoder
,
register_model
,
register_model_architecture
,
)
from
fairseq.modules
import
(
LinearizedConvolution
,
TransformerDecoderLayer
,
TransformerEncoderLayer
,
VGGBlock
,
)
@
register_model
(
"asr_vggtransformer"
)
class
VGGTransformerModel
(
FairseqEncoderDecoderModel
):
"""
Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660
"""
def
__init__
(
self
,
encoder
,
decoder
):
super
().
__init__
(
encoder
,
decoder
)
@
staticmethod
def
add_args
(
parser
):
"""Add model-specific arguments to the parser."""
parser
.
add_argument
(
"--input-feat-per-channel"
,
type
=
int
,
metavar
=
"N"
,
help
=
"encoder input dimension per input channel"
,
)
parser
.
add_argument
(
"--vggblock-enc-config"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
an array of tuples each containing the configuration of one vggblock:
[(out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
use_layer_norm), ...])
"""
,
)
parser
.
add_argument
(
"--transformer-enc-config"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
""""
a tuple containing the configuration of the encoder transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ...]')
"""
,
)
parser
.
add_argument
(
"--enc-output-dim"
,
type
=
int
,
metavar
=
"N"
,
help
=
"""
encoder output dimension, can be None. If specified, projecting the
transformer output to the specified dimension"""
,
)
parser
.
add_argument
(
"--in-channels"
,
type
=
int
,
metavar
=
"N"
,
help
=
"number of encoder input channels"
,
)
parser
.
add_argument
(
"--tgt-embed-dim"
,
type
=
int
,
metavar
=
"N"
,
help
=
"embedding dimension of the decoder target tokens"
,
)
parser
.
add_argument
(
"--transformer-dec-config"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
a tuple containing the configuration of the decoder transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ...]
"""
,
)
parser
.
add_argument
(
"--conv-dec-config"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
an array of tuples for the decoder 1-D convolution config
[(out_channels, conv_kernel_size, use_layer_norm), ...]"""
,
)
@
classmethod
def
build_encoder
(
cls
,
args
,
task
):
return
VGGTransformerEncoder
(
input_feat_per_channel
=
args
.
input_feat_per_channel
,
vggblock_config
=
eval
(
args
.
vggblock_enc_config
),
transformer_config
=
eval
(
args
.
transformer_enc_config
),
encoder_output_dim
=
args
.
enc_output_dim
,
in_channels
=
args
.
in_channels
,
)
@
classmethod
def
build_decoder
(
cls
,
args
,
task
):
return
TransformerDecoder
(
dictionary
=
task
.
target_dictionary
,
embed_dim
=
args
.
tgt_embed_dim
,
transformer_config
=
eval
(
args
.
transformer_dec_config
),
conv_config
=
eval
(
args
.
conv_dec_config
),
encoder_output_dim
=
args
.
enc_output_dim
,
)
@
classmethod
def
build_model
(
cls
,
args
,
task
):
"""Build a new model instance."""
# make sure that all args are properly defaulted
# (in case there are any new ones)
base_architecture
(
args
)
encoder
=
cls
.
build_encoder
(
args
,
task
)
decoder
=
cls
.
build_decoder
(
args
,
task
)
return
cls
(
encoder
,
decoder
)
def
get_normalized_probs
(
self
,
net_output
,
log_probs
,
sample
=
None
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs
=
super
().
get_normalized_probs
(
net_output
,
log_probs
,
sample
)
lprobs
.
batch_first
=
True
return
lprobs
DEFAULT_ENC_VGGBLOCK_CONFIG
=
((
32
,
3
,
2
,
2
,
False
),)
*
2
DEFAULT_ENC_TRANSFORMER_CONFIG
=
((
256
,
4
,
1024
,
True
,
0.2
,
0.2
,
0.2
),)
*
2
# 256: embedding dimension
# 4: number of attention heads
# 1024: FFN dimension
# True: apply LayerNorm before (dropout + residual) instead of after
# 0.2 (dropout): dropout after MultiheadAttention and the second FC
# 0.2 (attention_dropout): dropout in MultiheadAttention
# 0.2 (relu_dropout): dropout after ReLU
DEFAULT_DEC_TRANSFORMER_CONFIG
=
((
256
,
2
,
1024
,
True
,
0.2
,
0.2
,
0.2
),)
*
2
DEFAULT_DEC_CONV_CONFIG
=
((
256
,
3
,
True
),)
*
2
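# Editorial sketch (illustration, not part of the original file): on the
# command line these configs are passed as Python-literal strings and parsed
# with eval() inside build_encoder()/build_decoder(), e.g.
#
#     cfg = eval("((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2")
#     assert cfg == DEFAULT_ENC_TRANSFORMER_CONFIG
#     input_dim, num_heads, ffn_dim, normalize_before, dropout, attn_dropout, relu_dropout = cfg[0]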
# TODO: replace the transformer encoder config one-liner with explicit args
# to get rid of this transformation
def
prepare_transformer_encoder_params
(
input_dim
,
num_heads
,
ffn_dim
,
normalize_before
,
dropout
,
attention_dropout
,
relu_dropout
,
):
args
=
argparse
.
Namespace
()
args
.
encoder_embed_dim
=
input_dim
args
.
encoder_attention_heads
=
num_heads
args
.
attention_dropout
=
attention_dropout
args
.
dropout
=
dropout
args
.
activation_dropout
=
relu_dropout
args
.
encoder_normalize_before
=
normalize_before
args
.
encoder_ffn_embed_dim
=
ffn_dim
return
args
def
prepare_transformer_decoder_params
(
input_dim
,
num_heads
,
ffn_dim
,
normalize_before
,
dropout
,
attention_dropout
,
relu_dropout
,
):
args
=
argparse
.
Namespace
()
args
.
decoder_embed_dim
=
input_dim
args
.
decoder_attention_heads
=
num_heads
args
.
attention_dropout
=
attention_dropout
args
.
dropout
=
dropout
args
.
activation_dropout
=
relu_dropout
args
.
decoder_normalize_before
=
normalize_before
args
.
decoder_ffn_embed_dim
=
ffn_dim
return
args
class
VGGTransformerEncoder
(
FairseqEncoder
):
"""VGG + Transformer encoder"""
def
__init__
(
self
,
input_feat_per_channel
,
vggblock_config
=
DEFAULT_ENC_VGGBLOCK_CONFIG
,
transformer_config
=
DEFAULT_ENC_TRANSFORMER_CONFIG
,
encoder_output_dim
=
512
,
in_channels
=
1
,
transformer_context
=
None
,
transformer_sampling
=
None
,
):
"""constructor for VGGTransformerEncoder
Args:
- input_feat_per_channel: feature dim (not including stacked,
just base feature)
- in_channels: # input channels (e.g., if stacking 8 feature vectors
together, this is 8)
- vggblock_config: configuration of vggblock, see comments on
DEFAULT_ENC_VGGBLOCK_CONFIG
- transformer_config: configuration of transformer layer, see comments
on DEFAULT_ENC_TRANSFORMER_CONFIG
- encoder_output_dim: final transformer output embedding dimension
- transformer_context: (left, right) if set, self-attention will be focused
on (t-left, t+right)
- transformer_sampling: an iterable of int, must match with
len(transformer_config), transformer_sampling[i] indicates sampling
factor for the i-th transformer layer, applied after the multihead
attention and feedforward part
"""
super
().
__init__
(
None
)
self
.
num_vggblocks
=
0
if
vggblock_config
is
not
None
:
if
not
isinstance
(
vggblock_config
,
Iterable
):
raise
ValueError
(
"vggblock_config is not iterable"
)
self
.
num_vggblocks
=
len
(
vggblock_config
)
self
.
conv_layers
=
nn
.
ModuleList
()
self
.
in_channels
=
in_channels
self
.
input_dim
=
input_feat_per_channel
self
.
pooling_kernel_sizes
=
[]
if
vggblock_config
is
not
None
:
for
_
,
config
in
enumerate
(
vggblock_config
):
(
out_channels
,
conv_kernel_size
,
pooling_kernel_size
,
num_conv_layers
,
layer_norm
,
)
=
config
self
.
conv_layers
.
append
(
VGGBlock
(
in_channels
,
out_channels
,
conv_kernel_size
,
pooling_kernel_size
,
num_conv_layers
,
input_dim
=
input_feat_per_channel
,
layer_norm
=
layer_norm
,
)
)
self
.
pooling_kernel_sizes
.
append
(
pooling_kernel_size
)
in_channels
=
out_channels
input_feat_per_channel
=
self
.
conv_layers
[
-
1
].
output_dim
transformer_input_dim
=
self
.
infer_conv_output_dim
(
self
.
in_channels
,
self
.
input_dim
)
# transformer_input_dim is the output dimension of VGG part
self
.
validate_transformer_config
(
transformer_config
)
self
.
transformer_context
=
self
.
parse_transformer_context
(
transformer_context
)
self
.
transformer_sampling
=
self
.
parse_transformer_sampling
(
transformer_sampling
,
len
(
transformer_config
)
)
self
.
transformer_layers
=
nn
.
ModuleList
()
if
transformer_input_dim
!=
transformer_config
[
0
][
0
]:
self
.
transformer_layers
.
append
(
Linear
(
transformer_input_dim
,
transformer_config
[
0
][
0
])
)
self
.
transformer_layers
.
append
(
TransformerEncoderLayer
(
prepare_transformer_encoder_params
(
*
transformer_config
[
0
])
)
)
for
i
in
range
(
1
,
len
(
transformer_config
)):
if
transformer_config
[
i
-
1
][
0
]
!=
transformer_config
[
i
][
0
]:
self
.
transformer_layers
.
append
(
Linear
(
transformer_config
[
i
-
1
][
0
],
transformer_config
[
i
][
0
])
)
self
.
transformer_layers
.
append
(
TransformerEncoderLayer
(
prepare_transformer_encoder_params
(
*
transformer_config
[
i
])
)
)
self
.
encoder_output_dim
=
encoder_output_dim
self
.
transformer_layers
.
extend
(
[
Linear
(
transformer_config
[
-
1
][
0
],
encoder_output_dim
),
LayerNorm
(
encoder_output_dim
),
]
)
def
forward
(
self
,
src_tokens
,
src_lengths
,
**
kwargs
):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
bsz
,
max_seq_len
,
_
=
src_tokens
.
size
()
x
=
src_tokens
.
view
(
bsz
,
max_seq_len
,
self
.
in_channels
,
self
.
input_dim
)
x
=
x
.
transpose
(
1
,
2
).
contiguous
()
# (B, C, T, feat)
for
layer_idx
in
range
(
len
(
self
.
conv_layers
)):
x
=
self
.
conv_layers
[
layer_idx
](
x
)
bsz
,
_
,
output_seq_len
,
_
=
x
.
size
()
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
x
=
x
.
transpose
(
1
,
2
).
transpose
(
0
,
1
)
x
=
x
.
contiguous
().
view
(
output_seq_len
,
bsz
,
-
1
)
input_lengths
=
src_lengths
.
clone
()
for
s
in
self
.
pooling_kernel_sizes
:
input_lengths
=
(
input_lengths
.
float
()
/
s
).
ceil
().
long
()
encoder_padding_mask
,
_
=
lengths_to_encoder_padding_mask
(
input_lengths
,
batch_first
=
True
)
if
not
encoder_padding_mask
.
any
():
encoder_padding_mask
=
None
subsampling_factor
=
int
(
max_seq_len
*
1.0
/
output_seq_len
+
0.5
)
attn_mask
=
self
.
lengths_to_attn_mask
(
input_lengths
,
subsampling_factor
)
transformer_layer_idx
=
0
for
layer_idx
in
range
(
len
(
self
.
transformer_layers
)):
if
isinstance
(
self
.
transformer_layers
[
layer_idx
],
TransformerEncoderLayer
):
x
=
self
.
transformer_layers
[
layer_idx
](
x
,
encoder_padding_mask
,
attn_mask
)
if
self
.
transformer_sampling
[
transformer_layer_idx
]
!=
1
:
sampling_factor
=
self
.
transformer_sampling
[
transformer_layer_idx
]
x
,
encoder_padding_mask
,
attn_mask
=
self
.
slice
(
x
,
encoder_padding_mask
,
attn_mask
,
sampling_factor
)
transformer_layer_idx
+=
1
else
:
x
=
self
.
transformer_layers
[
layer_idx
](
x
)
# encoder_padding_mask is a (T x B) tensor; its [t, b] element indicates
# whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
return
{
"encoder_out"
:
x
,
# (T, B, C)
"encoder_padding_mask"
:
encoder_padding_mask
.
t
()
if
encoder_padding_mask
is
not
None
else
None
,
# (B, T) --> (T, B)
}
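    # Editorial sketch (illustration, not part of the original file): the valid
    # lengths shrink with every pooling layer exactly like the sequence does,
    # via the ceil-division loop above, e.g. with pooling_kernel_sizes = [2, 2]:
    #
    #     src_lengths = torch.tensor([100, 75])
    #     # after first pooling:  ceil([50.0, 37.5]) -> [50, 38]
    #     # after second pooling: ceil([25.0, 19.0]) -> [25, 19]
    #
    # so encoder_padding_mask is computed on [25, 19], matching output_seq_len.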
def
infer_conv_output_dim
(
self
,
in_channels
,
input_dim
):
sample_seq_len
=
200
sample_bsz
=
10
x
=
torch
.
randn
(
sample_bsz
,
in_channels
,
sample_seq_len
,
input_dim
)
for
i
,
_
in
enumerate
(
self
.
conv_layers
):
x
=
self
.
conv_layers
[
i
](
x
)
x
=
x
.
transpose
(
1
,
2
)
mb
,
seq
=
x
.
size
()[:
2
]
return
x
.
contiguous
().
view
(
mb
,
seq
,
-
1
).
size
(
-
1
)
def
validate_transformer_config
(
self
,
transformer_config
):
for
config
in
transformer_config
:
input_dim
,
num_heads
=
config
[:
2
]
if
input_dim
%
num_heads
!=
0
:
msg
=
(
"ERROR in transformer config {}: "
.
format
(
config
)
+
"input dimension {} "
.
format
(
input_dim
)
+
"not dividable by number of heads {}"
.
format
(
num_heads
)
)
raise
ValueError
(
msg
)
def
parse_transformer_context
(
self
,
transformer_context
):
"""
transformer_context can be the following:
- None; indicates no context is used, i.e.,
transformer can access full context
- a tuple/list of two int; indicates left and right context,
any number <0 indicates infinite context
* e.g., (5, 6) indicates that for query at x_t, transformer can
access [t-5, t+6] (inclusive)
* e.g., (-1, 6) indicates that for query at x_t, transformer can
access [0, t+6] (inclusive)
"""
if
transformer_context
is
None
:
return
None
if
not
isinstance
(
transformer_context
,
Iterable
):
raise
ValueError
(
"transformer context must be Iterable if it is not None"
)
if
len
(
transformer_context
)
!=
2
:
raise
ValueError
(
"transformer context must have length 2"
)
left_context
=
transformer_context
[
0
]
if
left_context
<
0
:
left_context
=
None
right_context
=
transformer_context
[
1
]
if
right_context
<
0
:
right_context
=
None
if
left_context
is
None
and
right_context
is
None
:
return
None
return
(
left_context
,
right_context
)
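        # Editorial illustration (not part of the original file): with the
        # semantics documented above, parse_transformer_context((-1, 6))
        # returns (None, 6) -- unlimited left context, 6 frames of right
        # context -- while (-1, -1) and None both yield None, i.e.
        # full-context self-attention.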
def
parse_transformer_sampling
(
self
,
transformer_sampling
,
num_layers
):
"""
parsing transformer sampling configuration
Args:
- transformer_sampling, accepted input:
* None, indicating no sampling
* an Iterable with int (>0) as element
- num_layers, expected number of transformer layers, must match with
the length of transformer_sampling if it is not None
Returns:
- A tuple with length num_layers
"""
if
transformer_sampling
is
None
:
return
(
1
,)
*
num_layers
if
not
isinstance
(
transformer_sampling
,
Iterable
):
raise
ValueError
(
"transformer_sampling must be an iterable if it is not None"
)
if
len
(
transformer_sampling
)
!=
num_layers
:
raise
ValueError
(
"transformer_sampling {} does not match with the number "
"of layers {}"
.
format
(
transformer_sampling
,
num_layers
)
)
for
layer
,
value
in
enumerate
(
transformer_sampling
):
if
not
isinstance
(
value
,
int
):
raise
ValueError
(
"Invalid value in transformer_sampling: "
)
if
value
<
1
:
raise
ValueError
(
"{} layer's subsampling is {}."
.
format
(
layer
,
value
)
+
" This is not allowed! "
)
return
transformer_sampling
def
slice
(
self
,
embedding
,
padding_mask
,
attn_mask
,
sampling_factor
):
"""
embedding is a (T, B, D) tensor
padding_mask is a (B, T) tensor or None
attn_mask is a (T, T) tensor or None
"""
embedding
=
embedding
[::
sampling_factor
,
:,
:]
if
padding_mask
is
not
None
:
padding_mask
=
padding_mask
[:,
::
sampling_factor
]
if
attn_mask
is
not
None
:
attn_mask
=
attn_mask
[::
sampling_factor
,
::
sampling_factor
]
return
embedding
,
padding_mask
,
attn_mask
def
lengths_to_attn_mask
(
self
,
input_lengths
,
subsampling_factor
=
1
):
"""
create attention mask according to sequence lengths and transformer
context
Args:
- input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
the length of b-th sequence
- subsampling_factor: int
* Note that the left_context and right_context is specified in
the input frame-level while input to transformer may already
go through subsampling (e.g., the use of striding in vggblock)
we use subsampling_factor to scale the left/right context
Return:
- a (T, T) binary tensor or None, where T is max(input_lengths)
* if self.transformer_context is None, None
* if left_context is None,
* attn_mask[t, t + right_context + 1:] = 1
* others = 0
* if right_context is None,
* attn_mask[t, 0:t - left_context] = 1
* others = 0
* otherwise
* attn_mask[t, t - left_context: t + right_context + 1] = 0
* others = 1
"""
if
self
.
transformer_context
is
None
:
return
None
maxT
=
torch
.
max
(
input_lengths
).
item
()
attn_mask
=
torch
.
zeros
(
maxT
,
maxT
)
left_context
=
self
.
transformer_context
[
0
]
right_context
=
self
.
transformer_context
[
1
]
if
left_context
is
not
None
:
left_context
=
math
.
ceil
(
self
.
transformer_context
[
0
]
/
subsampling_factor
)
if
right_context
is
not
None
:
right_context
=
math
.
ceil
(
self
.
transformer_context
[
1
]
/
subsampling_factor
)
for
t
in
range
(
maxT
):
if
left_context
is
not
None
:
st
=
0
en
=
max
(
st
,
t
-
left_context
)
attn_mask
[
t
,
st
:
en
]
=
1
if
right_context
is
not
None
:
st
=
t
+
right_context
+
1
st
=
min
(
st
,
maxT
-
1
)
attn_mask
[
t
,
st
:]
=
1
return
attn_mask
.
to
(
input_lengths
.
device
)
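        # Editorial worked example (not part of the original file): for
        # transformer_context = (2, -1), parsed to (left=2, right=None), and
        # input lengths whose maximum is 5, the loop above produces
        #
        #     attn_mask = [[0, 0, 0, 0, 0],
        #                  [0, 0, 0, 0, 0],
        #                  [0, 0, 0, 0, 0],
        #                  [1, 0, 0, 0, 0],
        #                  [1, 1, 0, 0, 0]]
        #
        # i.e. query t may not attend to frames earlier than t - 2; with no
        # right-context limit, nothing to the right is masked.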
def
reorder_encoder_out
(
self
,
encoder_out
,
new_order
):
encoder_out
[
"encoder_out"
]
=
encoder_out
[
"encoder_out"
].
index_select
(
1
,
new_order
)
if
encoder_out
[
"encoder_padding_mask"
]
is
not
None
:
encoder_out
[
"encoder_padding_mask"
]
=
encoder_out
[
"encoder_padding_mask"
].
index_select
(
1
,
new_order
)
return
encoder_out
class
TransformerDecoder
(
FairseqIncrementalDecoder
):
"""
Transformer decoder consisting of len(transformer_config) layers, each a
:class:`TransformerDecoderLayer`, preceded by a stack of 1-D convolutions
described by conv_config.
Args:
    dictionary (~fairseq.data.Dictionary): decoding dictionary
    embed_dim (int): embedding dimension of the decoder target tokens
    transformer_config: iterable of per-layer transformer configuration
        tuples (see DEFAULT_DEC_TRANSFORMER_CONFIG)
    conv_config: iterable of per-layer 1-D convolution configuration
        tuples (see DEFAULT_DEC_CONV_CONFIG)
    encoder_output_dim (int): dimension of the encoder output
"""
def
__init__
(
self
,
dictionary
,
embed_dim
=
512
,
transformer_config
=
DEFAULT_ENC_TRANSFORMER_CONFIG
,
conv_config
=
DEFAULT_DEC_CONV_CONFIG
,
encoder_output_dim
=
512
,
):
super
().
__init__
(
dictionary
)
vocab_size
=
len
(
dictionary
)
self
.
padding_idx
=
dictionary
.
pad
()
self
.
embed_tokens
=
Embedding
(
vocab_size
,
embed_dim
,
self
.
padding_idx
)
self
.
conv_layers
=
nn
.
ModuleList
()
for
i
in
range
(
len
(
conv_config
)):
out_channels
,
kernel_size
,
layer_norm
=
conv_config
[
i
]
if
i
==
0
:
conv_layer
=
LinearizedConv1d
(
embed_dim
,
out_channels
,
kernel_size
,
padding
=
kernel_size
-
1
)
else
:
conv_layer
=
LinearizedConv1d
(
conv_config
[
i
-
1
][
0
],
out_channels
,
kernel_size
,
padding
=
kernel_size
-
1
,
)
self
.
conv_layers
.
append
(
conv_layer
)
if
layer_norm
:
self
.
conv_layers
.
append
(
nn
.
LayerNorm
(
out_channels
))
self
.
conv_layers
.
append
(
nn
.
ReLU
())
self
.
layers
=
nn
.
ModuleList
()
if
conv_config
[
-
1
][
0
]
!=
transformer_config
[
0
][
0
]:
self
.
layers
.
append
(
Linear
(
conv_config
[
-
1
][
0
],
transformer_config
[
0
][
0
]))
self
.
layers
.
append
(
TransformerDecoderLayer
(
prepare_transformer_decoder_params
(
*
transformer_config
[
0
])
)
)
for
i
in
range
(
1
,
len
(
transformer_config
)):
if
transformer_config
[
i
-
1
][
0
]
!=
transformer_config
[
i
][
0
]:
self
.
layers
.
append
(
Linear
(
transformer_config
[
i
-
1
][
0
],
transformer_config
[
i
][
0
])
)
self
.
layers
.
append
(
TransformerDecoderLayer
(
prepare_transformer_decoder_params
(
*
transformer_config
[
i
])
)
)
self
.
fc_out
=
Linear
(
transformer_config
[
-
1
][
0
],
vocab_size
)
def
forward
(
self
,
prev_output_tokens
,
encoder_out
=
None
,
incremental_state
=
None
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
target_padding_mask
=
(
(
prev_output_tokens
==
self
.
padding_idx
).
to
(
prev_output_tokens
.
device
)
if
incremental_state
is
None
else
None
)
if
incremental_state
is
not
None
:
prev_output_tokens
=
prev_output_tokens
[:,
-
1
:]
# embed tokens
x
=
self
.
embed_tokens
(
prev_output_tokens
)
# B x T x C -> T x B x C
x
=
self
.
_transpose_if_training
(
x
,
incremental_state
)
for
layer
in
self
.
conv_layers
:
if
isinstance
(
layer
,
LinearizedConvolution
):
x
=
layer
(
x
,
incremental_state
)
else
:
x
=
layer
(
x
)
# B x T x C -> T x B x C
x
=
self
.
_transpose_if_inference
(
x
,
incremental_state
)
# decoder layers
for
layer
in
self
.
layers
:
if
isinstance
(
layer
,
TransformerDecoderLayer
):
x
,
*
_
=
layer
(
x
,
(
encoder_out
[
"encoder_out"
]
if
encoder_out
is
not
None
else
None
),
(
encoder_out
[
"encoder_padding_mask"
].
t
()
if
encoder_out
[
"encoder_padding_mask"
]
is
not
None
else
None
),
incremental_state
,
self_attn_mask
=
(
self
.
buffered_future_mask
(
x
)
if
incremental_state
is
None
else
None
),
self_attn_padding_mask
=
(
target_padding_mask
if
incremental_state
is
None
else
None
),
)
else
:
x
=
layer
(
x
)
# T x B x C -> B x T x C
x
=
x
.
transpose
(
0
,
1
)
x
=
self
.
fc_out
(
x
)
return
x
,
None
def
buffered_future_mask
(
self
,
tensor
):
dim
=
tensor
.
size
(
0
)
if
(
not
hasattr
(
self
,
"_future_mask"
)
or
self
.
_future_mask
is
None
or
self
.
_future_mask
.
device
!=
tensor
.
device
):
self
.
_future_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
tensor
.
new
(
dim
,
dim
)),
1
)
if
self
.
_future_mask
.
size
(
0
)
<
dim
:
self
.
_future_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
self
.
_future_mask
.
resize_
(
dim
,
dim
)),
1
)
return
self
.
_future_mask
[:
dim
,
:
dim
]
def
_transpose_if_training
(
self
,
x
,
incremental_state
):
if
incremental_state
is
None
:
x
=
x
.
transpose
(
0
,
1
)
return
x
def
_transpose_if_inference
(
self
,
x
,
incremental_state
):
if
incremental_state
:
x
=
x
.
transpose
(
0
,
1
)
return
x
@
register_model
(
"asr_vggtransformer_encoder"
)
class
VGGTransformerEncoderModel
(
FairseqEncoderModel
):
def
__init__
(
self
,
encoder
):
super
().
__init__
(
encoder
)
@
staticmethod
def
add_args
(
parser
):
"""Add model-specific arguments to the parser."""
parser
.
add_argument
(
"--input-feat-per-channel"
,
type
=
int
,
metavar
=
"N"
,
help
=
"encoder input dimension per input channel"
,
)
parser
.
add_argument
(
"--vggblock-enc-config"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
an array of tuples each containing the configuration of one vggblock
[(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...]
"""
,
)
parser
.
add_argument
(
"--transformer-enc-config"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
a tuple containing the configuration of the Transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ]"""
,
)
parser
.
add_argument
(
"--enc-output-dim"
,
type
=
int
,
metavar
=
"N"
,
help
=
"encoder output dimension, projecting the LSTM output"
,
)
parser
.
add_argument
(
"--in-channels"
,
type
=
int
,
metavar
=
"N"
,
help
=
"number of encoder input channels"
,
)
parser
.
add_argument
(
"--transformer-context"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
either None or a tuple of two ints, indicating left/right context a
transformer can have access to"""
,
)
parser
.
add_argument
(
"--transformer-sampling"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
either None or a tuple of ints, indicating sampling factor in each layer"""
,
)
@
classmethod
def
build_model
(
cls
,
args
,
task
):
"""Build a new model instance."""
base_architecture_enconly
(
args
)
encoder
=
VGGTransformerEncoderOnly
(
vocab_size
=
len
(
task
.
target_dictionary
),
input_feat_per_channel
=
args
.
input_feat_per_channel
,
vggblock_config
=
eval
(
args
.
vggblock_enc_config
),
transformer_config
=
eval
(
args
.
transformer_enc_config
),
encoder_output_dim
=
args
.
enc_output_dim
,
in_channels
=
args
.
in_channels
,
transformer_context
=
eval
(
args
.
transformer_context
),
transformer_sampling
=
eval
(
args
.
transformer_sampling
),
)
return
cls
(
encoder
)
def
get_normalized_probs
(
self
,
net_output
,
log_probs
,
sample
=
None
):
# net_output['encoder_out'] is a (T, B, D) tensor
lprobs
=
super
().
get_normalized_probs
(
net_output
,
log_probs
,
sample
)
# lprobs is a (T, B, D) tensor
# we need to transpose to get a (B, T, D) tensor
lprobs
=
lprobs
.
transpose
(
0
,
1
).
contiguous
()
lprobs
.
batch_first
=
True
return
lprobs
class
VGGTransformerEncoderOnly
(
VGGTransformerEncoder
):
def
__init__
(
self
,
vocab_size
,
input_feat_per_channel
,
vggblock_config
=
DEFAULT_ENC_VGGBLOCK_CONFIG
,
transformer_config
=
DEFAULT_ENC_TRANSFORMER_CONFIG
,
encoder_output_dim
=
512
,
in_channels
=
1
,
transformer_context
=
None
,
transformer_sampling
=
None
,
):
super
().
__init__
(
input_feat_per_channel
=
input_feat_per_channel
,
vggblock_config
=
vggblock_config
,
transformer_config
=
transformer_config
,
encoder_output_dim
=
encoder_output_dim
,
in_channels
=
in_channels
,
transformer_context
=
transformer_context
,
transformer_sampling
=
transformer_sampling
,
)
self
.
fc_out
=
Linear
(
self
.
encoder_output_dim
,
vocab_size
)
def
forward
(
self
,
src_tokens
,
src_lengths
,
**
kwargs
):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
enc_out
=
super
().
forward
(
src_tokens
,
src_lengths
)
x
=
self
.
fc_out
(
enc_out
[
"encoder_out"
])
# x = F.log_softmax(x, dim=-1)
# Note: no need this line, because model.get_normalized_prob will call
# log_softmax
return
{
"encoder_out"
:
x
,
# (T, B, C)
"encoder_padding_mask"
:
enc_out
[
"encoder_padding_mask"
],
# (T, B)
}
def
max_positions
(
self
):
"""Maximum input length supported by the encoder."""
return
(
1e6
,
1e6
)
# an arbitrarily large number
def
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
):
m
=
nn
.
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
=
padding_idx
)
# nn.init.uniform_(m.weight, -0.1, 0.1)
# nn.init.constant_(m.weight[padding_idx], 0)
return
m
def
Linear
(
in_features
,
out_features
,
bias
=
True
,
dropout
=
0
):
"""Linear layer (input: N x T x C)"""
m
=
nn
.
Linear
(
in_features
,
out_features
,
bias
=
bias
)
# m.weight.data.uniform_(-0.1, 0.1)
# if bias:
# m.bias.data.uniform_(-0.1, 0.1)
return
m
def
LinearizedConv1d
(
in_channels
,
out_channels
,
kernel_size
,
dropout
=
0
,
**
kwargs
):
"""Weight-normalized Conv1d layer optimized for decoding"""
m
=
LinearizedConvolution
(
in_channels
,
out_channels
,
kernel_size
,
**
kwargs
)
std
=
math
.
sqrt
((
4
*
(
1.0
-
dropout
))
/
(
m
.
kernel_size
[
0
]
*
in_channels
))
nn
.
init
.
normal_
(
m
.
weight
,
mean
=
0
,
std
=
std
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
return
nn
.
utils
.
weight_norm
(
m
,
dim
=
2
)
def
LayerNorm
(
embedding_dim
):
m
=
nn
.
LayerNorm
(
embedding_dim
)
return
m
# seq2seq models
def
base_architecture
(
args
):
args
.
input_feat_per_channel
=
getattr
(
args
,
"input_feat_per_channel"
,
40
)
args
.
vggblock_enc_config
=
getattr
(
args
,
"vggblock_enc_config"
,
DEFAULT_ENC_VGGBLOCK_CONFIG
)
args
.
transformer_enc_config
=
getattr
(
args
,
"transformer_enc_config"
,
DEFAULT_ENC_TRANSFORMER_CONFIG
)
args
.
enc_output_dim
=
getattr
(
args
,
"enc_output_dim"
,
512
)
args
.
in_channels
=
getattr
(
args
,
"in_channels"
,
1
)
args
.
tgt_embed_dim
=
getattr
(
args
,
"tgt_embed_dim"
,
128
)
args
.
transformer_dec_config
=
getattr
(
args
,
"transformer_dec_config"
,
DEFAULT_ENC_TRANSFORMER_CONFIG
)
args
.
conv_dec_config
=
getattr
(
args
,
"conv_dec_config"
,
DEFAULT_DEC_CONV_CONFIG
)
args
.
transformer_context
=
getattr
(
args
,
"transformer_context"
,
"None"
)
@
register_model_architecture
(
"asr_vggtransformer"
,
"vggtransformer_1"
)
def
vggtransformer_1
(
args
):
args
.
input_feat_per_channel
=
getattr
(
args
,
"input_feat_per_channel"
,
80
)
args
.
vggblock_enc_config
=
getattr
(
args
,
"vggblock_enc_config"
,
"[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args
.
transformer_enc_config
=
getattr
(
args
,
"transformer_enc_config"
,
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14"
,
)
args
.
enc_output_dim
=
getattr
(
args
,
"enc_output_dim"
,
1024
)
args
.
tgt_embed_dim
=
getattr
(
args
,
"tgt_embed_dim"
,
128
)
args
.
conv_dec_config
=
getattr
(
args
,
"conv_dec_config"
,
"((256, 3, True),) * 4"
)
args
.
transformer_dec_config
=
getattr
(
args
,
"transformer_dec_config"
,
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4"
,
)
@
register_model_architecture
(
"asr_vggtransformer"
,
"vggtransformer_2"
)
def
vggtransformer_2
(
args
):
args
.
input_feat_per_channel
=
getattr
(
args
,
"input_feat_per_channel"
,
80
)
args
.
vggblock_enc_config
=
getattr
(
args
,
"vggblock_enc_config"
,
"[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args
.
transformer_enc_config
=
getattr
(
args
,
"transformer_enc_config"
,
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16"
,
)
args
.
enc_output_dim
=
getattr
(
args
,
"enc_output_dim"
,
1024
)
args
.
tgt_embed_dim
=
getattr
(
args
,
"tgt_embed_dim"
,
512
)
args
.
conv_dec_config
=
getattr
(
args
,
"conv_dec_config"
,
"((256, 3, True),) * 4"
)
args
.
transformer_dec_config
=
getattr
(
args
,
"transformer_dec_config"
,
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6"
,
)
@
register_model_architecture
(
"asr_vggtransformer"
,
"vggtransformer_base"
)
def
vggtransformer_base
(
args
):
args
.
input_feat_per_channel
=
getattr
(
args
,
"input_feat_per_channel"
,
80
)
args
.
vggblock_enc_config
=
getattr
(
args
,
"vggblock_enc_config"
,
"[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args
.
transformer_enc_config
=
getattr
(
args
,
"transformer_enc_config"
,
"((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
)
args
.
enc_output_dim
=
getattr
(
args
,
"enc_output_dim"
,
512
)
args
.
tgt_embed_dim
=
getattr
(
args
,
"tgt_embed_dim"
,
512
)
args
.
conv_dec_config
=
getattr
(
args
,
"conv_dec_config"
,
"((256, 3, True),) * 4"
)
args
.
transformer_dec_config
=
getattr
(
args
,
"transformer_dec_config"
,
"((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
)
# Size estimations:
# Encoder:
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3 = 258K
# Transformer:
# - input dimension adapter: 2560 x 512 -> 1.31M
# - transformer_layers (x12) --> 37.74M
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
# * FFN weight: 512*2048*2 = 2.097M
# - output dimension adapter: 512 x 512 -> 0.26 M
# Decoder:
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
# - transformer_layer: (x6) --> 25.16M
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
# * FFN: 512*2048*2 = 2.097M
# Final FC:
# - FC: 512*5000 = 256K (assuming vocab size 5K)
# In total:
# ~65 M
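# Editorial check (illustration only, not part of the original file): the
# per-layer estimates above follow directly from the shapes, e.g. for one
# 512-dim encoder layer with a 2048-dim FFN:
#
#     attn = 512 * 512 * 3 + 512 * 512   # q/k/v in-proj + out-proj ≈ 1.05M
#     ffn  = 512 * 2048 * 2              # two fully connected layers ≈ 2.10M
#     (attn + ffn) * 12                  # ≈ 37.7M, the "transformer_layers (x12)" line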
# CTC models
def
base_architecture_enconly
(
args
):
args
.
input_feat_per_channel
=
getattr
(
args
,
"input_feat_per_channel"
,
40
)
args
.
vggblock_enc_config
=
getattr
(
args
,
"vggblock_enc_config"
,
"[(32, 3, 2, 2, True)] * 2"
)
args
.
transformer_enc_config
=
getattr
(
args
,
"transformer_enc_config"
,
"((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
)
args
.
enc_output_dim
=
getattr
(
args
,
"enc_output_dim"
,
512
)
args
.
in_channels
=
getattr
(
args
,
"in_channels"
,
1
)
args
.
transformer_context
=
getattr
(
args
,
"transformer_context"
,
"None"
)
args
.
transformer_sampling
=
getattr
(
args
,
"transformer_sampling"
,
"None"
)
@
register_model_architecture
(
"asr_vggtransformer_encoder"
,
"vggtransformer_enc_1"
)
def
vggtransformer_enc_1
(
args
):
# vggtransformer_1 is the same as vggtransformer_enc_big, except the number
# of layers is increased to 16
# keep it here for backward compatibility purposes
args
.
input_feat_per_channel
=
getattr
(
args
,
"input_feat_per_channel"
,
80
)
args
.
vggblock_enc_config
=
getattr
(
args
,
"vggblock_enc_config"
,
"[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args
.
transformer_enc_config
=
getattr
(
args
,
"transformer_enc_config"
,
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16"
,
)
args
.
enc_output_dim
=
getattr
(
args
,
"enc_output_dim"
,
1024
)
examples/speech_recognition/models/w2l_conv_glu_enc.py
0 → 100644
View file @
c394d7d1
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
fairseq.models
import
(
FairseqEncoder
,
FairseqEncoderModel
,
register_model
,
register_model_architecture
,
)
from
fairseq.modules.fairseq_dropout
import
FairseqDropout
default_conv_enc_config = """[
(400, 13, 170, 0.2),
(440, 14, 0, 0.214),
(484, 15, 0, 0.22898),
(532, 16, 0, 0.2450086),
(584, 17, 0, 0.262159202),
(642, 18, 0, 0.28051034614),
(706, 19, 0, 0.30014607037),
(776, 20, 0, 0.321156295296),
(852, 21, 0, 0.343637235966),
(936, 22, 0, 0.367691842484),
(1028, 23, 0, 0.393430271458),
(1130, 24, 0, 0.42097039046),
(1242, 25, 0, 0.450438317792),
(1366, 26, 0, 0.481969000038),
(1502, 27, 0, 0.51570683004),
(1652, 28, 0, 0.551806308143),
(1816, 29, 0, 0.590432749713),
]"""
@
register_model
(
"asr_w2l_conv_glu_encoder"
)
class
W2lConvGluEncoderModel
(
FairseqEncoderModel
):
def
__init__
(
self
,
encoder
):
super
().
__init__
(
encoder
)
@
staticmethod
def
add_args
(
parser
):
"""Add model-specific arguments to the parser."""
parser
.
add_argument
(
"--input-feat-per-channel"
,
type
=
int
,
metavar
=
"N"
,
help
=
"encoder input dimension per input channel"
,
)
parser
.
add_argument
(
"--in-channels"
,
type
=
int
,
metavar
=
"N"
,
help
=
"number of encoder input channels"
,
)
parser
.
add_argument
(
"--conv-enc-config"
,
type
=
str
,
metavar
=
"EXPR"
,
help
=
"""
an array of tuples each containing the configuration of one conv layer
[(out_channels, kernel_size, padding, dropout), ...]
"""
,
)
@
classmethod
def
build_model
(
cls
,
args
,
task
):
"""Build a new model instance."""
conv_enc_config
=
getattr
(
args
,
"conv_enc_config"
,
default_conv_enc_config
)
encoder
=
W2lConvGluEncoder
(
vocab_size
=
len
(
task
.
target_dictionary
),
input_feat_per_channel
=
args
.
input_feat_per_channel
,
in_channels
=
args
.
in_channels
,
conv_enc_config
=
eval
(
conv_enc_config
),
)
return
cls
(
encoder
)
def
get_normalized_probs
(
self
,
net_output
,
log_probs
,
sample
=
None
):
lprobs
=
super
().
get_normalized_probs
(
net_output
,
log_probs
,
sample
)
lprobs
.
batch_first
=
False
return
lprobs
class
W2lConvGluEncoder
(
FairseqEncoder
):
def
__init__
(
self
,
vocab_size
,
input_feat_per_channel
,
in_channels
,
conv_enc_config
):
super
().
__init__
(
None
)
self
.
input_dim
=
input_feat_per_channel
if
in_channels
!=
1
:
raise
ValueError
(
"only 1 input channel is currently supported"
)
self
.
conv_layers
=
nn
.
ModuleList
()
self
.
linear_layers
=
nn
.
ModuleList
()
self
.
dropouts
=
[]
cur_channels
=
input_feat_per_channel
for
out_channels
,
kernel_size
,
padding
,
dropout
in
conv_enc_config
:
layer
=
nn
.
Conv1d
(
cur_channels
,
out_channels
,
kernel_size
,
padding
=
padding
)
layer
.
weight
.
data
.
mul_
(
math
.
sqrt
(
3
))
# match wav2letter init
self
.
conv_layers
.
append
(
nn
.
utils
.
weight_norm
(
layer
))
self
.
dropouts
.
append
(
FairseqDropout
(
dropout
,
module_name
=
self
.
__class__
.
__name__
)
)
if
out_channels
%
2
!=
0
:
raise
ValueError
(
"odd # of out_channels is incompatible with GLU"
)
cur_channels
=
out_channels
//
2
# halved by GLU
for
out_channels
in
[
2
*
cur_channels
,
vocab_size
]:
layer
=
nn
.
Linear
(
cur_channels
,
out_channels
)
layer
.
weight
.
data
.
mul_
(
math
.
sqrt
(
3
))
self
.
linear_layers
.
append
(
nn
.
utils
.
weight_norm
(
layer
))
cur_channels
=
out_channels
//
2
def
forward
(
self
,
src_tokens
,
src_lengths
,
**
kwargs
):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
B
,
T
,
_
=
src_tokens
.
size
()
x
=
src_tokens
.
transpose
(
1
,
2
).
contiguous
()
# (B, feat, T) assuming C == 1
for
layer_idx
in
range
(
len
(
self
.
conv_layers
)):
x
=
self
.
conv_layers
[
layer_idx
](
x
)
x
=
F
.
glu
(
x
,
dim
=
1
)
x
=
self
.
dropouts
[
layer_idx
](
x
)
x
=
x
.
transpose
(
1
,
2
).
contiguous
()
# (B, T, 908)
x
=
self
.
linear_layers
[
0
](
x
)
x
=
F
.
glu
(
x
,
dim
=
2
)
x
=
self
.
dropouts
[
-
1
](
x
)
x
=
self
.
linear_layers
[
1
](
x
)
assert
x
.
size
(
0
)
==
B
assert
x
.
size
(
1
)
==
T
encoder_out
=
x
.
transpose
(
0
,
1
)
# (T, B, vocab_size)
# TODO: find a simpler/more elegant way to build this mask with PyTorch APIs
encoder_padding_mask
=
(
torch
.
arange
(
T
).
view
(
1
,
T
).
expand
(
B
,
-
1
).
to
(
x
.
device
)
>=
src_lengths
.
view
(
B
,
1
).
expand
(
-
1
,
T
)
).
t
()
# (B x T) -> (T x B)
return
{
"encoder_out"
:
encoder_out
,
# (T, B, vocab_size)
"encoder_padding_mask"
:
encoder_padding_mask
,
# (T, B)
}
def
reorder_encoder_out
(
self
,
encoder_out
,
new_order
):
encoder_out
[
"encoder_out"
]
=
encoder_out
[
"encoder_out"
].
index_select
(
1
,
new_order
)
encoder_out
[
"encoder_padding_mask"
]
=
encoder_out
[
"encoder_padding_mask"
].
index_select
(
1
,
new_order
)
return
encoder_out
def
max_positions
(
self
):
"""Maximum input length supported by the encoder."""
return
(
1e6
,
1e6
)
# an arbitrarily large number
@
register_model_architecture
(
"asr_w2l_conv_glu_encoder"
,
"w2l_conv_glu_enc"
)
def
w2l_conv_glu_enc
(
args
):
args
.
input_feat_per_channel
=
getattr
(
args
,
"input_feat_per_channel"
,
80
)
args
.
in_channels
=
getattr
(
args
,
"in_channels"
,
1
)
args
.
conv_enc_config
=
getattr
(
args
,
"conv_enc_config"
,
default_conv_enc_config
)
examples/speech_recognition/new/README.md
0 → 100644
View file @
c394d7d1
# Flashlight Decoder

This script runs decoding for pre-trained speech recognition models.

## Usage

Assuming a few variables:

```bash
checkpoint=<path-to-checkpoint>
data=<path-to-data-directory>
lm_model=<path-to-language-model>
lexicon=<path-to-lexicon>
```

Example usage for decoding a fine-tuned Wav2Vec model:

```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
    task=audio_pretraining \
    task.data=$data \
    task.labels=ltr \
    common_eval.path=$checkpoint \
    decoding.type=kenlm \
    decoding.lexicon=$lexicon \
    decoding.lmpath=$lm_model \
    dataset.gen_subset=dev_clean,dev_other,test_clean,test_other
```

Example usage with Ax to sweep WER parameters (requires `pip install hydra-ax-sweeper`):

```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
    hydra/sweeper=ax \
    task=audio_pretraining \
    task.data=$data \
    task.labels=ltr \
    common_eval.path=$checkpoint \
    decoding.type=kenlm \
    decoding.lexicon=$lexicon \
    decoding.lmpath=$lm_model \
    dataset.gen_subset=dev_other
```
examples/speech_recognition/new/__init__.py
0 → 100644
View file @
c394d7d1
examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml
0 → 100644
View file @
c394d7d1
# @package hydra.sweeper
_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
max_batch_size: null
ax_config:
  max_trials: 128
  early_stop:
    minimize: true
    max_epochs_without_improvement: 32
    epsilon: 1.0e-05
  experiment:
    name: ${dataset.gen_subset}
    objective_name: wer
    minimize: true
    parameter_constraints: null
    outcome_constraints: null
    status_quo: null
  client:
    verbose_logging: false
    random_seed: null
  params:
    decoding.lmweight:
      type: range
      bounds: [0.0, 5.0]
    decoding.wordscore:
      type: range
      bounds: [-5.0, 5.0]
examples/speech_recognition/new/conf/infer.yaml
0 → 100644
View file @
c394d7d1
# @package _group_
defaults:
  - task: null
  - model: null

hydra:
  run:
    dir: ${common_eval.results_path}/${dataset.gen_subset}
  sweep:
    dir: ${common_eval.results_path}
    subdir: ${dataset.gen_subset}

common_eval:
  results_path: null
  path: null
  post_process: letter
  quiet: true

dataset:
  max_tokens: 1000000
  gen_subset: test

distributed_training:
  distributed_world_size: 1

decoding:
  beam: 5
  type: viterbi
examples/speech_recognition/new/decoders/__init__.py
0 → 100644
View file @
c394d7d1
examples/speech_recognition/new/decoders/base_decoder.py
0 → 100644
View file @
c394d7d1
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools as it
from typing import Any, Dict, List

import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.fairseq_model import FairseqModel


class BaseDecoder:
    def __init__(self, tgt_dict: Dictionary) -> None:
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)

        self.blank = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        if "<sep>" in tgt_dict.indices:
            self.silence = tgt_dict.index("<sep>")
        elif "|" in tgt_dict.indices:
            self.silence = tgt_dict.index("|")
        else:
            self.silence = tgt_dict.eos()

    def generate(
        self, models: List[FairseqModel], sample: Dict[str, Any], **unused
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(
        self,
        models: List[FairseqModel],
        encoder_input: Dict[str, Any],
    ) -> torch.FloatTensor:
        model = models[0]
        encoder_out = model(**encoder_input)
        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        raise NotImplementedError
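
# Editorial illustration (not part of the original file): get_tokens() performs
# the usual CTC collapse -- merge adjacent repeats, then drop blanks. With a
# blank index of 0:
#
#     idxs = torch.IntTensor([0, 5, 5, 0, 0, 7, 7, 7, 0, 5])
#     # itertools.groupby -> [0, 5, 0, 7, 0, 5]; drop blanks -> [5, 7, 5]
#     decoder.get_tokens(idxs)  # -> tensor([5, 7, 5]) for a decoder whose blank == 0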
examples/speech_recognition/new/decoders/decoder.py
0 → 100644
View file @
c394d7d1
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

from fairseq.data.dictionary import Dictionary

from .decoder_config import DecoderConfig, FlashlightDecoderConfig
from .base_decoder import BaseDecoder


def Decoder(
    cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary
) -> BaseDecoder:
    if cfg.type == "viterbi":
        from .viterbi_decoder import ViterbiDecoder

        return ViterbiDecoder(tgt_dict)
    if cfg.type == "kenlm":
        from .flashlight_decoder import KenLMDecoder

        return KenLMDecoder(cfg, tgt_dict)
    if cfg.type == "fairseqlm":
        from .flashlight_decoder import FairseqLMDecoder

        return FairseqLMDecoder(cfg, tgt_dict)
    raise NotImplementedError(f"Invalid decoder name: {cfg.name}")
examples/speech_recognition/new/decoders/decoder_config.py
0 → 100644
View file @
c394d7d1
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import Optional

from fairseq.dataclass.configs import FairseqDataclass
from fairseq.dataclass.constants import ChoiceEnum
from omegaconf import MISSING


DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"])


@dataclass
class DecoderConfig(FairseqDataclass):
    type: DECODER_CHOICES = field(
        default="viterbi",
        metadata={"help": "The type of decoder to use"},
    )


@dataclass
class FlashlightDecoderConfig(FairseqDataclass):
    nbest: int = field(
        default=1,
        metadata={"help": "Number of decodings to return"},
    )
    unitlm: bool = field(
        default=False,
        metadata={"help": "If set, use unit language model"},
    )
    lmpath: str = field(
        default=MISSING,
        metadata={"help": "Language model for KenLM decoder"},
    )
    lexicon: Optional[str] = field(
        default=None,
        metadata={"help": "Lexicon for Flashlight decoder"},
    )
    beam: int = field(
        default=50,
        metadata={"help": "Number of beams to use for decoding"},
    )
    beamthreshold: float = field(
        default=50.0,
        metadata={"help": "Threshold for beam search decoding"},
    )
    beamsizetoken: Optional[int] = field(
        default=None, metadata={"help": "Beam size to use"}
    )
    wordscore: float = field(
        default=-1,
        metadata={"help": "Word score for KenLM decoder"},
    )
    unkweight: float = field(
        default=-math.inf,
        metadata={"help": "Unknown weight for KenLM decoder"},
    )
    silweight: float = field(
        default=0,
        metadata={"help": "Silence weight for KenLM decoder"},
    )
    lmweight: float = field(
        default=2,
        metadata={"help": "Weight for LM while interpolating score"},
    )
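
# Editorial illustration (not part of the original file): these dataclasses are
# what the `decoding` group in conf/infer.yaml populates. A hypothetical manual
# use of the factory defined in decoder.py could look like:
#
#     from examples.speech_recognition.new.decoders.decoder import Decoder
#     cfg = DecoderConfig(type="viterbi")
#     decoder = Decoder(cfg, tgt_dict)  # tgt_dict: a fairseq Dictionary
#
# For "kenlm" or "fairseqlm", build a FlashlightDecoderConfig instead and set
# at least lmpath (and usually lexicon, lmweight, wordscore).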