Commit a7785cc6 authored by Sugon_ldc

delete soft link

parent 9a2a05ca
#!/usr/bin/env python
import sys
print('0 1 <eps> <eps>')
print('1 1 <blank> <eps>')
print('2 2 <blank> <eps>')
print('2 0 <eps> <eps>')
with open(sys.argv[1], 'r', encoding='utf8') as fin:
    node = 3
    for entry in fin:
        fields = entry.strip().split(' ')
        phone = fields[0]
        if phone == '<eps>' or phone == '<blank>':
            continue
        elif '#' in phone:  # disambiguation symbol
            print('{} {} {} {}'.format(0, 0, '<eps>', phone))
        else:
            print('{} {} {} {}'.format(1, node, phone, phone))
            print('{} {} {} {}'.format(node, node, phone, '<eps>'))
            print('{} {} {} {}'.format(node, 2, '<eps>', '<eps>'))
            node += 1
print('0')
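A minimal usage sketch for the token-FST generator above (paths are hypothetical; the input is a token table whose first column is the token, and the printed arcs are compiled with OpenFst):

    python tools/fst/ctc_token_fst.py data/lang/tokens.txt > T.txt
    fstcompile --isymbols=data/lang/tokens.txt --osymbols=data/lang/tokens.txt T.txt \
      | fstarcsort --sort_type=olabel > T.fst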
#!/usr/bin/env python
import sys
print('0 0 <blank> <eps>')
with open(sys.argv[1], 'r', encoding='utf8') as fin:
    node = 1
    for entry in fin:
        fields = entry.strip().split(' ')
        phone = fields[0]
        if phone == '<eps>' or phone == '<blank>':
            continue
        elif '#' in phone:  # disambiguation symbol
            print('{} {} {} {}'.format(0, 0, '<eps>', phone))
        else:
            print('{} {} {} {}'.format(0, node, phone, phone))
            print('{} {} {} {}'.format(node, node, phone, '<eps>'))
            print('{} {} {} {}'.format(node, 0, '<eps>', '<eps>'))
            node += 1
print('0')
#!/usr/bin/env python
import sys
def il(n):
    # input label id (token ids are shifted by 1, since label 0 is <eps>)
    return n + 1


def ol(n):
    # output label id
    return n + 1


def s(n):
    # state id
    return n


if __name__ == "__main__":
    with open(sys.argv[1]) as f:
        lines = f.readlines()
        phone_count = 0
        disambig_count = 0
        for line in lines:
            sp = line.split()
            phone = sp[0]
            if phone == '<eps>' or phone == '<blank>':
                continue
            if phone.startswith('#'):
                disambig_count += 1
            else:
                phone_count += 1

        # 1. add start state
        print('0 0 {} 0'.format(il(0)))

        # 2. 0 -> i, i -> i, i -> 0
        for i in range(1, phone_count + 1):
            print('0 {} {} {}'.format(s(i), il(i), ol(i)))
            print('{} {} {} 0'.format(s(i), s(i), il(i)))
            print('{} 0 {} 0'.format(s(i), il(0)))

        # 3. i -> other phone
        for i in range(1, phone_count + 1):
            for j in range(1, phone_count + 1):
                if i != j:
                    print('{} {} {} {}'.format(s(i), s(j), il(j), ol(j)))

        # 4. add disambiguation arcs on every final state
        for i in range(0, phone_count + 1):
            for j in range(phone_count + 2, phone_count + disambig_count + 2):
                print('{} {} {} {}'.format(s(i), s(i), 0, j))

        # 5. every i is final state
        for i in range(0, phone_count + 1):
            print(s(i))
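To make the numbering concrete, here is a worked sketch (assuming the file above is saved as tools/fst/ctc_token_fst_corrected.py). With a toy token table of <eps>, <blank>, two phones and one disambiguation symbol, the script prints a blank self-loop on state 0, one state per phone with entry/self-loop/return arcs, cross-phone arcs, a disambiguation self-loop on every state, and one final-state line per state:

    cat > tokens.txt <<EOF
    <eps> 0
    <blank> 1
    a 2
    b 3
    #0 4
    EOF
    python tools/fst/ctc_token_fst_corrected.py tokens.txt
    # expected first lines: '0 0 1 0' (blank self-loop), '0 1 2 2' (enter phone a), ...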
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# 2015 Guoguo Chen
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script replaces <eps> with #0 on the input side only of the G.fst
# acceptor.
while(<>){
  if (/\s+#0\s+/) {
    print STDERR "$0: ERROR: LM has word #0, " .
      "which is reserved as disambiguation symbol\n";
    exit 1;
  }
  s:^(\d+\s+\d+\s+)\<eps\>(\s+):$1#0$2:;
  print;
}
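A minimal usage sketch (this mirrors the G.fst pipeline later in this commit; arpa2fst, fstprint and fstcompile come from Kaldi/OpenFst, and the file names are illustrative):

    arpa2fst --read-symbol-table=words.txt --keep-symbols=true lm.arpa - | fstprint | \
      tools/fst/eps2disambig.pl | tools/fst/s2eps.pl | \
      fstcompile --isymbols=words.txt --osymbols=words.txt > G.fst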
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2010-2011 Microsoft Corporation
# 2013 Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# makes lexicon FST, in text form, from lexicon (pronunciation probabilities optional).
$pron_probs = 0;
if ((@ARGV > 0) && ($ARGV[0] eq "--pron-probs")) {
  $pron_probs = 1;
  shift @ARGV;
}
if (@ARGV != 1 && @ARGV != 3 && @ARGV != 4) {
  print STDERR "Usage: make_lexicon_fst.pl [--pron-probs] lexicon.txt [silprob silphone [sil_disambig_sym]] >lexiconfst.txt\n\n";
  print STDERR "Creates a lexicon FST that transduces phones to words, and may allow optional silence.\n\n";
  print STDERR "Note: ordinarily, each line of lexicon.txt is:\n";
  print STDERR " word phone1 phone2 ... phoneN;\n";
  print STDERR "if the --pron-probs option is used, each line is:\n";
  print STDERR " word pronunciation-probability phone1 phone2 ... phoneN.\n\n";
  print STDERR "The probability 'prob' will typically be between zero and one, and note that\n";
  print STDERR "it's generally helpful to normalize so the largest one for each word is 1.0, but\n";
  print STDERR "this is your responsibility.\n\n";
  print STDERR "The silence disambiguation symbol, e.g. something like #5, is used only\n";
  print STDERR "when creating a lexicon with disambiguation symbols, e.g. L_disambig.fst,\n";
  print STDERR "and was introduced to fix a particular case of non-determinism of decoding graphs.\n\n";
  exit(1);
}
$lexfn = shift @ARGV;
if (@ARGV == 0) {
  $silprob = 0.0;
} elsif (@ARGV == 2) {
  ($silprob,$silphone) = @ARGV;
} else {
  ($silprob,$silphone,$sildisambig) = @ARGV;
}
if ($silprob != 0.0) {
  $silprob < 1.0 || die "Sil prob cannot be >= 1.0";
  $silcost = -log($silprob);
  $nosilcost = -log(1.0 - $silprob);
}
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
if ( $silprob == 0.0 ) { # No optional silences: just have one (loop+final) state which is numbered zero.
  $loopstate = 0;
  $nextstate = 1;               # next unallocated state.
  while (<L>) {
    @A = split(" ", $_);
    @A == 0 && die "Empty lexicon line.";
    foreach $a (@A) {
      if ($a eq "<eps>") {
        die "Bad lexicon line $_ (<eps> is forbidden)";
      }
    }
    $w = shift @A;
    if (! $pron_probs) {
      $pron_cost = 0.0;
    } else {
      $pron_prob = shift @A;
      if (! defined $pron_prob || !($pron_prob > 0.0 && $pron_prob <= 1.0)) {
        die "Bad pronunciation probability in line $_";
      }
      $pron_cost = -log($pron_prob);
    }
    if ($pron_cost != 0.0) { $pron_cost_string = "\t$pron_cost"; } else { $pron_cost_string = ""; }

    $s = $loopstate;
    $word_or_eps = $w;
    while (@A > 0) {
      $p = shift @A;
      if (@A > 0) {
        $ns = $nextstate++;
      } else {
        $ns = $loopstate;
      }
      print "$s\t$ns\t$p\t$word_or_eps$pron_cost_string\n";
      $word_or_eps = "<eps>";
      $pron_cost_string = "";   # so we only print it on the first arc of the word.
      $s = $ns;
    }
  }
  print "$loopstate\t0\n";      # final-cost.
} else {                        # have silence probs.
  $startstate = 0;
  $loopstate = 1;
  $silstate = 2;                # state from where we go to loopstate after emitting silence.
  print "$startstate\t$loopstate\t<eps>\t<eps>\t$nosilcost\n"; # no silence.
  if (!defined $sildisambig) {
    print "$startstate\t$loopstate\t$silphone\t<eps>\t$silcost\n"; # silence.
    print "$silstate\t$loopstate\t$silphone\t<eps>\n";            # no cost.
    $nextstate = 3;
  } else {
    $disambigstate = 3;
    $nextstate = 4;
    print "$startstate\t$disambigstate\t$silphone\t<eps>\t$silcost\n"; # silence.
    print "$silstate\t$disambigstate\t$silphone\t<eps>\n";            # no cost.
    print "$disambigstate\t$loopstate\t$sildisambig\t<eps>\n"; # silence disambiguation symbol.
  }
  while (<L>) {
    @A = split(" ", $_);
    $w = shift @A;
    if (! $pron_probs) {
      $pron_cost = 0.0;
    } else {
      $pron_prob = shift @A;
      if (! defined $pron_prob || !($pron_prob > 0.0 && $pron_prob <= 1.0)) {
        die "Bad pronunciation probability in line $_";
      }
      $pron_cost = -log($pron_prob);
    }
    if ($pron_cost != 0.0) { $pron_cost_string = "\t$pron_cost"; } else { $pron_cost_string = ""; }
    $s = $loopstate;
    $word_or_eps = $w;
    while (@A > 0) {
      $p = shift @A;
      if (@A > 0) {
        $ns = $nextstate++;
        print "$s\t$ns\t$p\t$word_or_eps$pron_cost_string\n";
        $word_or_eps = "<eps>";
        $pron_cost_string = ""; $pron_cost = 0.0; # so we only print it the 1st time.
        $s = $ns;
      } elsif (!defined($silphone) || $p ne $silphone) {
        # This is non-deterministic but relatively compact,
        # and avoids epsilons.
        $local_nosilcost = $nosilcost + $pron_cost;
        $local_silcost = $silcost + $pron_cost;
        print "$s\t$loopstate\t$p\t$word_or_eps\t$local_nosilcost\n";
        print "$s\t$silstate\t$p\t$word_or_eps\t$local_silcost\n";
      } else {
        # no point putting opt-sil after silence word.
        print "$s\t$loopstate\t$p\t$word_or_eps$pron_cost_string\n";
      }
    }
  }
  print "$loopstate\t0\n";      # final-cost.
}
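A minimal usage sketch (per the usage message above; 0.5 is the optional-silence probability, SIL the silence phone, and #5 a hypothetical silence disambiguation symbol):

    tools/fst/make_lexicon_fst.pl lexicon.txt 0.5 SIL > L.txt
    tools/fst/make_lexicon_fst.pl --pron-probs lexicon_disambig.txt 0.5 SIL '#5' > L_disambig.txt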
#!/bin/bash
#
if [ -f path.sh ]; then . path.sh; fi
lm_dir=$1
src_lang=$2
tgt_lang=$3
arpa_lm=${lm_dir}/lm.arpa
[ ! -f $arpa_lm ] && echo No such file $arpa_lm && exit 1;
rm -rf $tgt_lang
cp -r $src_lang $tgt_lang
# Compose the language model to FST
cat $arpa_lm | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
grep -v -i '<unk>' | \
grep -v -i '<spoken_noise>' | \
arpa2fst --read-symbol-table=$tgt_lang/words.txt --keep-symbols=true - | fstprint | \
tools/fst/eps2disambig.pl | tools/fst/s2eps.pl | fstcompile --isymbols=$tgt_lang/words.txt \
--osymbols=$tgt_lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > $tgt_lang/G.fst
echo "Checking how stochastic G is (the first of these numbers should be small):"
fstisstochastic $tgt_lang/G.fst
# Compose the token, lexicon and language-model FST into the final decoding graph
fsttablecompose $tgt_lang/L.fst $tgt_lang/G.fst | fstdeterminizestar --use-log=true | \
fstminimizeencoded | fstarcsort --sort_type=ilabel > $tgt_lang/LG.fst || exit 1;
fsttablecompose $tgt_lang/T.fst $tgt_lang/LG.fst > $tgt_lang/TLG.fst || exit 1;
echo "Composing decoding graph TLG.fst succeeded"
#rm -r $tgt_lang/LG.fst # We don't need to keep this intermediate FST
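A minimal usage sketch (assuming the script above is saved as, e.g., tools/fst/make_tlg.sh; lm_dir must contain lm.arpa, and src_lang the words.txt, L.fst and T.fst built earlier):

    tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test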
#!/usr/bin/env python3
# encoding: utf-8
import sys
# sys.argv[1]: e2e model unit file(lang_char.txt)
# sys.argv[2]: raw lexicon file
# sys.argv[3]: output lexicon file
# sys.argv[4]: bpemodel
unit_table = set()
with open(sys.argv[1], 'r', encoding='utf8') as fin:
    for line in fin:
        unit = line.split()[0]
        unit_table.add(unit)


def contain_oov(units):
    for unit in units:
        if unit not in unit_table:
            return True
    return False


bpemode = len(sys.argv) > 4
if bpemode:
    import sentencepiece as spm
    sp = spm.SentencePieceProcessor()
    sp.Load(sys.argv[4])

lexicon_table = set()
with open(sys.argv[2], 'r', encoding='utf8') as fin, \
        open(sys.argv[3], 'w', encoding='utf8') as fout:
    for line in fin:
        word = line.split()[0]
        if word == 'SIL' and not bpemode:  # `sil` might be a valid piece in bpemodel
            continue
        elif word == '<SPOKEN_NOISE>':
            continue
        else:
            # each word only has one pronunciation for e2e system
            if word in lexicon_table:
                continue
            if bpemode:
                # We assume the lexicon is not code-switched,
                # i.e. no word mixes English and Chinese.
                # see PR https://github.com/wenet-e2e/wenet/pull/1693
                # and Issue https://github.com/wenet-e2e/wenet/issues/1653
                if word.encode('utf8').isalpha():
                    pieces = sp.EncodeAsPieces(word)
                else:
                    pieces = word
                if contain_oov(pieces):
                    print('Ignoring word {}, which contains an OOV unit'
                          .format(''.join(word).strip('▁')))
                    continue
                chars = ' '.join(
                    [p if p in unit_table else '<unk>' for p in pieces])
            else:
                # ignore words with OOV units
                if contain_oov(word):
                    print('Ignoring word {}, which contains an OOV unit'
                          .format(word))
                    continue
                # Optional: prepend ▁ to an English word, since we assume
                # the model unit of our e2e system is the char for now.
                if word.encode('utf8').isalpha() and '▁' in unit_table:
                    word = '▁' + word
                chars = ' '.join(word)  # word is a char list
            fout.write('{} {}\n'.format(word, chars))
            lexicon_table.add(word)
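A minimal usage sketch (argument order per the sys.argv comments at the top; the fourth argument is only passed in BPE mode, and all paths are illustrative):

    # char units
    python tools/fst/prepare_dict.py data/dict/units.txt lexicon_raw.txt lexicon.txt
    # BPE units
    python tools/fst/prepare_dict.py data/dict/units.txt lexicon_raw.txt lexicon.txt bpe.model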
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script removes lines whose third or fourth field (an FST arc's input
# or output symbol) is one of the OOVs listed in unk_list.txt. It is intended
# to remove arcs with OOVs on them from FSTs (probably compiled from ARPA LMs
# that contain OOVs).

if ( @ARGV < 1 || @ARGV > 2) {
  die "Usage: remove_oovs.pl unk_list.txt [ printed-fst ]\n";
}

$unklist = shift @ARGV;
open(S, "<$unklist") || die "Failed opening unknown-symbol list $unklist\n";
while(<S>){
  @A = split(" ", $_);
  @A == 1 || die "Bad line in unknown-symbol list: $_";
  $unk{$A[0]} = 1;
}

$num_removed = 0;
while(<>){
  @A = split(" ", $_);
  if(defined $unk{$A[2]} || defined $unk{$A[3]}) {
    $num_removed++;
  } else {
    print;
  }
}
print STDERR "remove_oovs.pl: removed $num_removed lines.\n";
#!/usr/bin/env python
import sys
print('0 0 <blank> <eps>')
with open(sys.argv[1], 'r', encoding='utf8') as fin:
    for entry in fin:
        fields = entry.strip().split(' ')
        phone = fields[0]
        if phone == '<eps>' or phone == '<blank>':
            continue
        elif '#' in phone:  # disambiguation symbol
            print('{} {} {} {}'.format(0, 0, '<eps>', phone))
        else:
            print('{} {} {} {}'.format(0, 0, phone, phone))
print('0')
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script replaces <s> and </s> with <eps> (on both input and output sides),
# for the G.fst acceptor.
while(<>){
  @A = split(" ", $_);
  if ( @A >= 4 ) {
    if ($A[2] eq "<s>" || $A[2] eq "</s>") { $A[2] = "<eps>"; }
    if ($A[3] eq "<s>" || $A[3] eq "</s>") { $A[3] = "<eps>"; }
  }
  print join("\t", @A) . "\n";
}
#!/bin/bash
set -e
echo "Running pre-commit flake8"
python tools/flake8_hook.py
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
# 2022 Binbin Zhang(binbzha@qq.com)
current_path=`pwd`
current_dir=`basename "$current_path"`
if [ "tools" != "$current_dir" ]; then
echo "You should run this script in tools/ directory!!"
exit 1
fi
! command -v gawk > /dev/null && \
echo "GNU awk is not installed so SRILM will probably not work correctly: refusing to install" && exit 1;
srilm_url="https://github.com/BitSpeech/SRILM/archive/refs/tags/1.7.3.tar.gz"
if [ ! -f ./srilm.tar.gz ]; then
  if ! wget -O ./srilm.tar.gz "$srilm_url"; then
    echo 'There was a problem downloading the file.'
    echo 'Check your internet connection and try again.'
    exit 1
  fi
fi
tar -zxvf srilm.tar.gz
mv SRILM-1.7.3 srilm
# set the SRILM variable in the top-level Makefile to this directory.
cd srilm
cp Makefile tmpf
cat tmpf | gawk -v pwd=`pwd` '/SRILM =/{printf("SRILM = %s\n", pwd); next;} {print;}' \
> Makefile || exit 1
rm tmpf
make || exit
cd ..
(
  [ ! -z "${SRILM}" ] && \
    echo >&2 "SRILM variable is already defined. Undefining..." && \
    unset SRILM

  [ -f ./env.sh ] && . ./env.sh

  [ ! -z "${SRILM}" ] && \
    echo >&2 "SRILM config is already in env.sh" && exit

  wd=`pwd`
  wd=`readlink -f $wd || pwd`

  echo "export SRILM=$wd/srilm"
  dirs="\${PATH}"
  for directory in $(cd srilm && find bin -type d ) ; do
    dirs="$dirs:\${SRILM}/$directory"
  done
  echo "export PATH=$dirs"
) >> env.sh
echo >&2 "Installation of SRILM finished successfully"
echo >&2 "Please source the tools/env.sh in your path.sh to enable it"
#!/bin/bash
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
# Copyright 2022 Ximalaya Speech Team (author: Xiang Lyu)
lexicon_dir=$1
lm_dir=$2
tgt_dir=$3
# k2 and icefall update very fast. The commits below are verified with this script:
# k2 3dc222f981b9fdbc8061b3782c3b385514a2d444, icefall 499ac24ecba64f687ff244c7d66baa5c222ecf0f
# For k2 installation, please refer to https://github.com/k2-fsa/k2/
python -c "import k2; print(k2.__file__)"
python -c "import torch; import _k2; print(_k2.__file__)"
# Prepare necessary icefall scripts
if [ ! -d tools/k2/icefall ]; then
  git clone --depth 1 https://github.com/k2-fsa/icefall.git tools/k2/icefall
fi
pip install -r tools/k2/icefall/requirements.txt
export PYTHONPATH=`pwd`/tools/k2/icefall:`pwd`/tools/k2/icefall/egs/aishell/ASR/local:$PYTHONPATH
# 8.1 Prepare char based lang
mkdir -p $tgt_dir
python tools/k2/prepare_char.py $lexicon_dir/units.txt $lm_dir/wordlist $tgt_dir
echo "Compile lexicon L.pt L_disambig.pt succeeded"
# 8.2 Prepare G
mkdir -p data/lm
python -m kaldilm \
--read-symbol-table="$tgt_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$lm_dir/lm.arpa > data/lm/G_3_gram.fst.txt
# 8.3 Compile HLG
python tools/k2/icefall/egs/aishell/ASR/local/compile_hlg.py --lang-dir $tgt_dir
echo "Compile decoding graph HLG.pt succeeded"
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
# Copyright 2022 Ximalaya Speech Team (author: Xiang Lyu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script generates the following files in the directory sys.argv[3]:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
- words.txt
"""
import sys
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
    Lexicon,
    add_disambig_symbols,
    add_self_loops,
    write_lexicon,
    write_mapping,
)
def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format).

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go

    arcs = []

    # The blank symbol <blank> is defined in local/train_bpe_model.py
    assert token2id["<blank>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        pieces = [
            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
        ]

        for i in range(len(pieces) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, pieces[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last piece of this word
        i = len(pieces) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, pieces[i], w, 0])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
    """Check if all the given tokens are in token symbol table.

    Args:
      token_sym_table:
        Token symbol table that contains all the valid tokens.
      tokens:
        A list of tokens.
    Returns:
      Return True if there is any token not in the token_sym_table,
      otherwise False.
    """
    for tok in tokens:
        if tok not in token_sym_table:
            return True
    return False
def generate_lexicon(
    token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
    """Generate a lexicon from a word list and token_sym_table.

    Args:
      token_sym_table:
        Token symbol table that maps tokens to token ids.
      words:
        A list of strings representing words.
    Returns:
      Return a list of (word, tokens) pairs; words containing characters
      that are not in the token table are skipped.
    """
    lexicon = []
    for word in words:
        chars = list(word.strip(" \t"))
        if contain_oov(token_sym_table, chars):
            continue
        lexicon.append((word, chars))

    # The OOV word is <UNK>
    lexicon.append(("<UNK>", ["<unk>"]))
    return lexicon
def generate_tokens(text_file: str) -> Dict[str, int]:
    """Generate tokens from the given text file.

    Args:
      text_file:
        A file that contains text lines to generate tokens.
    Returns:
      Return a dict whose keys are tokens and values are token ids ranged
      from 0 to len(keys) - 1.
    """
    token2id: Dict[str, int] = dict()
    with open(text_file, "r", encoding="utf-8") as f:
        for line in f:
            char, index = line.replace('\n', '').split()
            assert char not in token2id
            token2id[char] = int(index)
    assert token2id['<blank>'] == 0
    return token2id
def generate_words(text_file: str) -> Dict[str, int]:
    """Generate words from the given text file.

    Args:
      text_file:
        A file that contains text lines to generate words.
    Returns:
      Return a dict whose keys are words and values are word ids ranged
      from 0 to len(keys) - 1.
    """
    words = []
    with open(text_file, "r", encoding="utf-8") as f:
        for line in f:
            word = line.replace('\n', '')
            assert word not in words
            words.append(word)
    words.sort()

    # We put '<eps>' and '<UNK>' at the beginning of word2id,
    # and '#0', '<s>', '</s>' at the end of word2id
    words = [word for word in words
             if word not in ['<eps>', '<UNK>', '#0', '<s>', '</s>']]
    words.insert(0, '<eps>')
    words.insert(1, '<UNK>')
    words.append('#0')
    words.append('<s>')
    words.append('</s>')
    word2id = {j: i for i, j in enumerate(words)}
    return word2id
def main():
    token2id = generate_tokens(sys.argv[1])
    word2id = generate_words(sys.argv[2])
    tgt_dir = Path(sys.argv[3])

    words = [word for word in word2id.keys()
             if word not in
             ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]]
    lexicon = generate_lexicon(token2id, words)
    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    next_token_id = max(token2id.values()) + 1
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in token2id
        token2id[disambig] = next_token_id
        next_token_id += 1

    write_mapping(tgt_dir / "tokens.txt", token2id)
    write_mapping(tgt_dir / "words.txt", word2id)
    write_lexicon(tgt_dir / "lexicon.txt", lexicon)
    write_lexicon(tgt_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token2id,
        word2id=word2id,
    )
    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token2id,
        word2id=word2id,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), tgt_dir / "L.pt")
    torch.save(L_disambig.as_dict(), tgt_dir / "L_disambig.pt")


if __name__ == "__main__":
    main()
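A minimal usage sketch (argument order matches main(): token table, word list, output directory; this is also how the shell script earlier in this commit calls it):

    python tools/k2/prepare_char.py data/dict/units.txt data/local/lm/wordlist data/lang_char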
# Copyright (c) 2022 Horizon Inc. (author: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import logging

from typing import List

import librosa
import torch
import torchaudio
import yaml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import torchaudio.compliance.kaldi as kaldi

from wenet.utils.init_model import init_model
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.mask import make_pad_mask
from wenet.utils.common import replace_duplicates_with_blank
def get_args():
    parser = argparse.ArgumentParser(
        description='Analyze latency and plot CTC-Spike.')
    parser.add_argument('--config', required=True,
                        type=str, help='configuration file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--ckpt', required=True,
                        type=str, help='model checkpoint')
    parser.add_argument('--tag', required=True,
                        type=str, help='image subtitle')
    parser.add_argument('--wavscp', required=True,
                        type=str, help='wav.scp')
    parser.add_argument('--alignment', required=True,
                        type=str, help='force alignment, generated by Kaldi.')
    parser.add_argument('--chunk_size', required=True,
                        type=int, help='chunk size')
    parser.add_argument('--left_chunks', default=-1,
                        type=int, help='left chunks')
    parser.add_argument('--font', required=True,
                        type=str, help='font file')
    parser.add_argument('--dict', required=True,
                        type=str, help='dict file')
    parser.add_argument('--result_dir', required=True,
                        type=str, help='directory for saving pdfs')
    parser.add_argument('--model_type', default='ctc',
                        choices=['ctc', 'transducer'],
                        help='show latency metrics from ctc or rnn-t models')
    args = parser.parse_args()
    return args
def main():
    args = get_args()
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
    torch.manual_seed(777)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    symbol_table = read_symbol_table(args.dict)
    char_dict = {v: k for k, v in symbol_table.items()}

    # 1. Load model
    with open(args.config, 'r') as fin:
        conf = yaml.load(fin, Loader=yaml.FullLoader)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = init_model(conf)
    load_checkpoint(model, args.ckpt)
    model = model.eval().to(device)
    subsampling = model.encoder.embed.subsampling_rate
    eos = model.eos_symbol()

    with open(args.wavscp, 'r') as fin:
        wavs = fin.readlines()

    # 2. Forward model (get streaming_timestamps)
    timestamps = {}
    for idx, wav in enumerate(wavs):
        if idx % 100 == 0:
            logging.info("processed {}.".format(idx))
        key, wav = wav.strip().split(' ', 1)
        waveform, sr = torchaudio.load(wav)
        resample_rate = conf['dataset_conf']['resample_conf']['resample_rate']
        waveform = torchaudio.transforms.Resample(
            orig_freq=sr, new_freq=resample_rate)(waveform)
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=conf['dataset_conf']['fbank_conf']['num_mel_bins'],
            frame_length=conf['dataset_conf']['fbank_conf']['frame_length'],
            frame_shift=conf['dataset_conf']['fbank_conf']['frame_shift'],
            dither=0.0, energy_floor=0.0,
            sample_frequency=resample_rate,
        )
        speech = mat.unsqueeze(0).to(device)
        speech_lengths = torch.tensor([mat.size(0)]).to(device)
        # Let's assume batch_size = 1
        encoder_out, encoder_mask = model.encoder(
            speech, speech_lengths, args.chunk_size, args.left_chunks)
        maxlen = encoder_out.size(1)  # (B, maxlen, encoder_dim)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        # CTC greedy search
        if args.model_type == 'ctc':
            ctc_probs = model.ctc.log_softmax(
                encoder_out)  # (B, maxlen, vocab_size)
            topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
            topk_index = topk_index.view(1, maxlen)  # (B, maxlen)
            topk_prob = topk_prob.view(1, maxlen)  # (B, maxlen)
            mask = make_pad_mask(encoder_out_lens, maxlen)  # (B, maxlen)
            topk_index = topk_index.masked_fill_(mask, eos)  # (B, maxlen)
            topk_prob = topk_prob.masked_fill_(mask, 0.0)  # (B, maxlen)
            hyps = [hyp.tolist() for hyp in topk_index]
            hyps = [replace_duplicates_with_blank(hyp) for hyp in hyps]
            scores = [prob.tolist() for prob in topk_prob]
            timestamps[key] = [hyps[0], scores[0], wav]
        if args.model_type == 'transducer':
            hyps = []
            scores = []
            # fake padding
            padding = torch.zeros(1, 1).to(encoder_out.device)
            # sos
            pred_input_step = torch.tensor([model.blank]).reshape(1, 1)
            cache = model.predictor.init_state(1, method="zero",
                                               device=encoder_out.device)
            new_cache: List[torch.Tensor] = []
            t = 0
            hyps = []
            prev_out_nblk = True
            pred_out_step = None
            per_frame_max_noblk = 1
            per_frame_noblk = 0
            while t < encoder_out_lens:
                encoder_out_step = encoder_out[:, t:t + 1, :]  # [1, 1, E]
                if prev_out_nblk:
                    step_outs = model.predictor.forward_step(pred_input_step,
                                                             padding, cache)
                    pred_out_step, new_cache = step_outs[0], step_outs[1]
                joint_out_step = model.joint(encoder_out_step,
                                             pred_out_step)  # [1,1,v]
                joint_out_probs = joint_out_step.log_softmax(dim=-1)
                scores.append(torch.max(joint_out_probs).item())
                joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # []
                if joint_out_max != model.blank:
                    hyps.append(joint_out_max.item())
                    prev_out_nblk = True
                    per_frame_noblk = per_frame_noblk + 1
                    pred_input_step = joint_out_max.reshape(1, 1)
                    # state_m, state_c = clstate_out_m, state_out_c
                    cache = new_cache
                if joint_out_max == model.blank or \
                        per_frame_noblk >= per_frame_max_noblk:
                    if joint_out_max == model.blank:
                        prev_out_nblk = False
                        hyps.append(model.blank)
                    # TODO(Mddct): make t in chunk for streaming,
                    # or t shouldn't be too long to predict non-blank
                    t = t + 1
                    per_frame_noblk = 0
            timestamps[key] = [hyps, scores, wav]

    # 3. Analyze latency
    with open(args.alignment, 'r') as fin:
        aligns = fin.readlines()
    not_found, len_unequal, ignored = 0, 0, 0
    datas = []
    for align in aligns:
        key, align = align.strip().split(' ', 1)
        if key not in timestamps:
            not_found += 1
            continue
        fa, st = [], []  # force_alignment, streaming_timestamps
        text_fa, text_st = "", ""
        for i, token in enumerate(align.split()):
            if token != '<blank>':
                text_fa += token
                # NOTE(xcsong): W/O subsample
                fa.append(i * 10)
        # ignore alignment_errors >= 70ms
        frames_fa = len(align.split())
        frames_st = len(timestamps[key][0]) * subsampling
        if abs(frames_st - frames_fa) >= 7:
            ignored += 1
            continue
        for i, token_id in enumerate(timestamps[key][0]):
            if token_id != 0:
                text_st += char_dict[token_id]
                # NOTE(xcsong): W subsample
                st.append(i * subsampling * 10)
        if len(fa) != len(st):
            len_unequal += 1
            continue
        # datas[i] = [key, text_fa, text_st, list_of_diff,
        #             FirstTokenDelay, LastTokenDelay, AvgTokenDelay,
        #             streaming_timestamps, force_alignment]
        datas.append([key, text_fa, text_st,
                      [a - b for a, b in zip(st, fa)],
                      st[0] - fa[0], st[-1] - fa[-1],
                      (sum(st) - sum(fa)) / len(st),
                      timestamps[key], align.split()])
    logging.info("not found: {}, length unequal: {}, ignored: {}, "
                 "valid samples: {}".format(not_found, len_unequal,
                                            ignored, len(datas)))

    # 4. Plot and print
    num_datas = len(datas)
    names = ['FirstTokenDelay', 'LastTokenDelay', 'AvgTokenDelay']
    names_index = [4, 5, 6]
    parts = ['max', 'P90', 'P75', 'P50', 'P25', 'min']
    parts_index = [num_datas - 1, int(num_datas * 0.90), int(num_datas * 0.75),
                   int(num_datas * 0.50), int(num_datas * 0.25), 0]
    for name, name_idx in zip(names, names_index):
        def f(name_idx=name_idx):
            return name_idx
        datas.sort(key=lambda x: x[f()])
        logging.info("==========================")
        for p, i in zip(parts, parts_index):
            data = datas[i]
            # i.e., LastTokenDelay P90: 270.000 ms (wav_id: BAC009S0902W0144)
            logging.info("{} {}: {:.3f} ms (wav_id: {})".format(
                name, p, data[f()], datas[i][0]))
            font = fm.FontProperties(fname=args.font)
            plt.rcParams['axes.unicode_minus'] = False
            # we will have 2 sub-plots (force-align + streaming timestamps)
            # plus one wav-plot
            fig, axes = plt.subplots(figsize=(60, 60), nrows=3, ncols=1)
            for j in range(2):
                if j == 0:
                    # subplot-0: streaming_timestamps
                    plt_prefix = args.tag + "_" + name + "_" + p
                    x = np.arange(len(data[7][0])) * subsampling
                    hyps, scores = data[7][0], data[7][1]
                else:
                    # subplot-1: force_alignments
                    plt_prefix = "force_alignment"
                    x = np.arange(len(data[8]))
                    hyps = [symbol_table[d] for d in data[8]]
                    scores = [0.0] * len(data[8])
                axes[j].set_title(plt_prefix, fontsize=30)
                for frame, token, prob in zip(x, hyps, scores):
                    if char_dict[token] != '<blank>':
                        axes[j].bar(
                            frame, np.exp(prob),
                            label='{} {:.3f}'.format(
                                char_dict[token], np.exp(prob)),
                        )
                        axes[j].text(
                            frame, np.exp(prob),
                            '{} {:.3f} {}'.format(
                                char_dict[token], np.exp(prob), frame),
                            fontdict=dict(fontsize=24),
                            fontproperties=font,
                        )
                    else:
                        axes[j].bar(
                            frame, 0.01,
                            label='{} {:.3f}'.format(
                                char_dict[token], np.exp(prob)),
                        )
                axes[j].tick_params(labelsize=25)
            # subplot-2: wav
            # wav, hardcode sample_rate to 16000
            samples, sr = librosa.load(data[7][2], sr=16000)
            time = np.arange(0, len(samples)) * (1.0 / sr)
            axes[-1].plot(time, samples)
            # i.e., RESULT_DIR/LTD_P90_120ms_BAC009S0768W0342.pdf
            plt.savefig(args.result_dir + "/" + name + "_" +
                        p + "_" + str(data[f()]) + "ms" + "_" +
                        data[0] + ".pdf")


if __name__ == '__main__':
    main()
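A minimal usage sketch (all flags per get_args(); the script path and file names are hypothetical):

    python tools/analyze_latency.py --config train.yaml --ckpt final.pt \
      --tag ctc_chunk16 --wavscp data/test/wav.scp --alignment ali.txt \
      --chunk_size 16 --font simhei.ttf --dict data/dict/units.txt \
      --result_dir exp/latency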
#!/usr/bin/env python3
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--segments', default=None, help='segments file')
    parser.add_argument('wav_file', help='wav file')
    parser.add_argument('text_file', help='text file')
    parser.add_argument('output_file', help='output list file')
    args = parser.parse_args()

    wav_table = {}
    with open(args.wav_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            wav_table[arr[0]] = arr[1]

    if args.segments is not None:
        segments_table = {}
        with open(args.segments, 'r', encoding='utf8') as fin:
            for line in fin:
                arr = line.strip().split()
                assert len(arr) == 4
                segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3]))

    with open(args.text_file, 'r', encoding='utf8') as fin, \
            open(args.output_file, 'w', encoding='utf8') as fout:
        for line in fin:
            arr = line.strip().split(maxsplit=1)
            key = arr[0]
            txt = arr[1] if len(arr) > 1 else ''
            if args.segments is None:
                assert key in wav_table
                wav = wav_table[key]
                line = dict(key=key, wav=wav, txt=txt)
            else:
                assert key in segments_table
                wav_key, start, end = segments_table[key]
                wav = wav_table[wav_key]
                line = dict(key=key, wav=wav, txt=txt, start=start, end=end)
            json_line = json.dumps(line, ensure_ascii=False)
            fout.write(json_line + '\n')
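A minimal usage sketch, plus what one output line looks like (keys follow the dict(...) call above; file names and values are illustrative):

    python tools/make_raw_list.py data/wav.scp data/text data/data.list
    head -1 data/data.list
    # {"key": "utt1", "wav": "/path/to/utt1.wav", "txt": "hello world"}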
#!/usr/bin/env python3
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import logging
import os
import tarfile
import time
import multiprocessing
import torch
import torchaudio
import torchaudio.backend.sox_io_backend as sox
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def write_tar_file(data_list,
                   no_segments,
                   tar_file,
                   resample=16000,
                   index=0,
                   total=1):
    logging.info('Processing {} {}/{}'.format(tar_file, index, total))
    read_time = 0.0
    save_time = 0.0
    write_time = 0.0
    with tarfile.open(tar_file, "w") as tar:
        prev_wav = None
        for item in data_list:
            if no_segments:
                key, txt, wav = item
            else:
                key, txt, wav, start, end = item
            suffix = wav.split('.')[-1]
            assert suffix in AUDIO_FORMAT_SETS
            if no_segments:
                ts = time.time()
                with open(wav, 'rb') as fin:
                    data = fin.read()
                read_time += (time.time() - ts)
            else:
                if wav != prev_wav:
                    ts = time.time()
                    waveforms, sample_rate = sox.load(wav, normalize=False)
                    read_time += (time.time() - ts)
                    prev_wav = wav
                start = int(start * sample_rate)
                end = int(end * sample_rate)
                audio = waveforms[:1, start:end]

                # resample
                if sample_rate != resample:
                    if not audio.is_floating_point():
                        # normalize the audio before resample
                        # because resample can't process int audio
                        audio = audio / (1 << 15)
                        audio = torchaudio.transforms.Resample(
                            sample_rate, resample)(audio)
                        audio = (audio * (1 << 15)).short()
                    else:
                        audio = torchaudio.transforms.Resample(
                            sample_rate, resample)(audio)

                ts = time.time()
                f = io.BytesIO()
                sox.save(f, audio, resample, format="wav", bits_per_sample=16)
                # Save to wav for segments file
                suffix = "wav"
                f.seek(0)
                data = f.read()
                save_time += (time.time() - ts)

            assert isinstance(txt, str)
            ts = time.time()
            txt_file = key + '.txt'
            txt = txt.encode('utf8')
            txt_data = io.BytesIO(txt)
            txt_info = tarfile.TarInfo(txt_file)
            txt_info.size = len(txt)
            tar.addfile(txt_info, txt_data)

            wav_file = key + '.' + suffix
            wav_data = io.BytesIO(data)
            wav_info = tarfile.TarInfo(wav_file)
            wav_info.size = len(data)
            tar.addfile(wav_info, wav_data)
            write_time += (time.time() - ts)
        logging.info('read {} save {} write {}'.format(read_time, save_time,
                                                       write_time))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--num_utts_per_shard',
                        type=int,
                        default=1000,
                        help='num utts per shard')
    parser.add_argument('--num_threads',
                        type=int,
                        default=1,
                        help='num threads for make shards')
    parser.add_argument('--prefix',
                        default='shards',
                        help='prefix of shards tar file')
    parser.add_argument('--segments', default=None, help='segments file')
    parser.add_argument('--resample',
                        type=int,
                        default=16000,
                        help='target resample rate')
    parser.add_argument('wav_file', help='wav file')
    parser.add_argument('text_file', help='text file')
    parser.add_argument('shards_dir', help='output shards dir')
    parser.add_argument('shards_list', help='output shards list file')
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')

    torch.set_num_threads(1)
    wav_table = {}
    with open(args.wav_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            wav_table[arr[0]] = arr[1]

    no_segments = True
    segments_table = {}
    if args.segments is not None:
        no_segments = False
        with open(args.segments, 'r', encoding='utf8') as fin:
            for line in fin:
                arr = line.strip().split()
                assert len(arr) == 4
                segments_table[arr[0]] = (arr[1], float(arr[2]),
                                          float(arr[3]))

    data = []
    with open(args.text_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split(maxsplit=1)
            key = arr[0]
            txt = arr[1] if len(arr) > 1 else ''
            if no_segments:
                assert key in wav_table
                wav = wav_table[key]
                data.append((key, txt, wav))
            else:
                wav_key, start, end = segments_table[key]
                wav = wav_table[wav_key]
                data.append((key, txt, wav, start, end))

    num = args.num_utts_per_shard
    chunks = [data[i:i + num] for i in range(0, len(data), num)]
    os.makedirs(args.shards_dir, exist_ok=True)

    # Use a process pool to speed things up
    pool = multiprocessing.Pool(processes=args.num_threads)
    shards_list = []
    tasks_list = []
    num_chunks = len(chunks)
    for i, chunk in enumerate(chunks):
        tar_file = os.path.join(args.shards_dir,
                                '{}_{:09d}.tar'.format(args.prefix, i))
        shards_list.append(tar_file)
        pool.apply_async(
            write_tar_file,
            (chunk, no_segments, tar_file, args.resample, i, num_chunks))

    pool.close()
    pool.join()

    with open(args.shards_list, 'w', encoding='utf8') as fout:
        for name in shards_list:
            fout.write(name + '\n')
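A minimal usage sketch (positional arguments per the parser above; with --segments the audio is cut and re-encoded to 16-bit wav; paths are illustrative):

    python tools/make_shards.py --num_utts_per_shard 1000 --num_threads 8 \
      data/wav.scp data/text shards shards.list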
#!/usr/bin/env python3
# encoding: utf-8
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import codecs
from distutils.util import strtobool
from io import open
import logging
import sys
PY2 = sys.version_info[0] == 2
sys.stdin = codecs.getreader('utf-8')(sys.stdin if PY2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter('utf-8')(
sys.stdout if PY2 else sys.stdout.buffer)
# Special types:
def shape(x):
    """Change str to List[int]

    >>> shape('3,5')
    [3, 5]
    >>> shape(' [3, 5] ')
    [3, 5]

    """
    # x: ' [3, 5] ' -> '3, 5'
    x = x.strip()
    if x[0] == '[':
        x = x[1:]
    if x[-1] == ']':
        x = x[:-1]
    return list(map(int, x.split(',')))
def get_parser():
    parser = argparse.ArgumentParser(
        description='Given each file paths with such format as '
        '<key>:<file>:<type>. <type> can be omitted and the default '
        'is "str". e.g. {} '
        '--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape '
        '--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape '
        '--output-scps text:data/text shape:data/utt2text_shape:shape '
        '--scps utt2spk:data/utt2spk'.format(sys.argv[0]),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input-scps',
                        type=str,
                        nargs='*',
                        action='append',
                        default=[],
                        help='files for the inputs')
    parser.add_argument('--output-scps',
                        type=str,
                        nargs='*',
                        action='append',
                        default=[],
                        help='files for the outputs')
    parser.add_argument('--scps',
                        type=str,
                        nargs='+',
                        default=[],
                        help='The files except for the input and outputs')
    parser.add_argument('--verbose',
                        '-V',
                        default=1,
                        type=int,
                        help='Verbose option')
    parser.add_argument('--allow-one-column',
                        type=strtobool,
                        default=False,
                        help='Allow one column in input scp files. '
                        'In this case, the value will be empty string.')
    parser.add_argument('--out',
                        '-O',
                        type=str,
                        help='The output filename. '
                        'If omitted, then output to sys.stdout')
    return parser
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    args.scps = [args.scps]

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)

    inputs = {}
    assert (len(args.input_scps) == 1)
    for f in args.input_scps[0]:
        arr = f.strip().split(':')
        inputs[arr[0]] = arr[1]
    assert ('feat' in inputs)
    assert ('shape' in inputs)

    outputs = {}
    assert (len(args.output_scps) == 1)
    for f in args.output_scps[0]:
        arr = f.strip().split(':')
        outputs[arr[0]] = arr[1]
    assert ('shape' in outputs)
    assert ('text' in outputs)
    assert ('token' in outputs)
    assert ('tokenid' in outputs)

    files = [
        inputs['feat'], inputs['shape'], outputs['text'], outputs['token'],
        outputs['tokenid'], outputs['shape']
    ]
    fields = ['feat', 'feat_shape', 'text', 'token', 'tokenid', 'token_shape']
    fids = [open(f, 'r', encoding='utf-8') for f in files]
    if args.out is None:
        out = sys.stdout
    else:
        out = open(args.out, 'w', encoding='utf-8')

    done = False
    while not done:
        for i, fid in enumerate(fids):
            line = fid.readline()
            if line == '':
                done = True
                break
            arr = line.strip().split()
            content = ' '.join(arr[1:])
            if i == 0:
                out.write('utt:{}'.format(arr[0]))
            out.write('\t')
            out.write('{}:{}'.format(fields[i], content))
        out.write('\n')

    for f in fids:
        f.close()
    if args.out is not None:
        out.close()
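A minimal usage sketch (script and file names are hypothetical; note the asserts above require feat/shape on the input side and text/token/tokenid/shape on the output side):

    python tools/merge_scp2txt.py \
      --input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape \
      --output-scps text:data/text token:data/token tokenid:data/tokenid \
                    shape:data/utt2token_shape:shape \
      --out data/format.data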
#!/bin/bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config  # source the config file.
  fi
done
###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
               else printf "$help_message\n" 1>&2; fi;
               exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
           exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
         # Next we test whether the variable in question is undefined -- if so
         # it's an invalid option and we die.  Note: $0 evaluates to the name
         # of the enclosing script.
         # The test [ -z ${foo_bar+xxx} ] will return true if the variable
         # foo_bar is undefined.  We then have to wrap this test inside "eval"
         # because foo_bar is itself inside a variable ($name).
         eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

         oldval="`eval echo \\$$name`";
         # Work out whether we seem to be expecting a Boolean argument.
         if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
           was_bool=true;
         else
           was_bool=false;
         fi

         # Set the variable to the right value-- the escaped quotes make it work if
         # the option had spaces, like --cmd "queue.pl -sync y"
         eval $name=\"$2\";

         # Check that Boolean-valued arguments are really Boolean.
         if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
           echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
           exit 1;
         fi
         shift 2;
         ;;
    *) break;
  esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
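A minimal sketch of the usual calling convention (the sourcing script must define its option variables with defaults before sourcing; names here are illustrative):

    #!/bin/bash
    stage=0   # can be overridden as: ./run.sh --stage 2
    nj=4
    . tools/parse_options.sh || exit 1
    echo "stage=$stage nj=$nj"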