../../../wenet/
# Performance Record
## Conformer Result
* Feature info: dither + specaug + speed perturb
* Training info: lr 0.002, warmup_steps 20000, batch size 16, 1 gpu, acc_grad 4, 120 epochs
* Decoding info: average_num 20
| decoding mode | dev93 (CER) | dev93 (WER) |
|:----------------------:|:-------------:|:-------------:|
| ctc_greedy_search | 5.25% | 13.16% |
| ctc_prefix_beam_search | 5.17% | 13.10% |
| attention_rescoring | 5.11% | 12.17% |
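
These results were produced with the recipe's `run.sh` (stages 5 and 6). Below is a minimal sketch of the averaging and scoring steps, assuming `$dir` is your experiment directory and the model has already been trained:

``` sh
# Average the 20 best checkpoints by validation loss (run.sh stage 5).
python wenet/bin/average_model.py \
    --dst_model $dir/avg_20.pt --src_path $dir --num 20 --val_best
# Decode dev93 with one of the modes from the table, then score it (stages 5-6).
python wenet/bin/recognize.py --gpu 0 --mode attention_rescoring \
    --config $dir/train.yaml --data_type raw \
    --test_data data/test_dev93/data.list --checkpoint $dir/avg_20.pt \
    --beam_size 10 --batch_size 1 --penalty 0.0 \
    --dict data/dict/train_si284_units.txt --non_lang_syms data/nlsyms.txt \
    --ctc_weight 0.5 --result_file $dir/test_attention_rescoring/text
python tools/compute-wer.py --char=1 --v=1 \
    data/test_dev93/text $dir/test_attention_rescoring/text > $dir/test_attention_rescoring/wer
```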
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'
# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16
grad_clip: 5
accum_grad: 4
max_epoch: 120
log_interval: 100
optim: adam
optim_conf:
    lr: 0.002
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 20000
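# NOTE (assumption): the warmuplr scheduler is assumed to follow the usual
# Transformer/ESPnet convention, i.e. at optimizer step t
#   lr(t) = lr * warmup_steps^0.5 * min(t^-0.5, t * warmup_steps^-1.5)
# which ramps up linearly over the first 20000 steps and then decays as t^-0.5.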
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program takes on its standard input a list of utterance
# ids, one per line (e.g. 4k0c030a is an utterance id).
# It takes as its command-line argument a file listing the dot files, and
# extracts from the dot files the transcripts for a given
# dataset (represented by a file list).
@ARGV == 1 || die "find_transcripts.pl dot_files_flist < utterance_ids > transcripts";
$dot_flist = shift @ARGV;
open(L, "<$dot_flist") || die "Opening file list of dot files: $dot_flist\n";
while(<L>){
chop;
m:\S+/(\w{6})00.dot: || die "Bad line in dot file list: $_";
$spk = $1;
$spk2dot{$spk} = $_;
}
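# At this point %spk2dot maps each 6-character filename prefix (treated here as
# the "speaker" id) to the corresponding .dot transcript file.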
while(<STDIN>){
chop;
$uttid = $_;
$uttid =~ m:(\w{6})\w\w: || die "Bad utterance id $_";
$spk = $1;
if($spk ne $curspk) {
%utt2trans = ( ); # Clear the hash; don't keep all the transcripts in memory...
$curspk = $spk;
$dotfile = $spk2dot{$spk};
defined $dotfile || die "No dot file for speaker $spk\n";
open(F, "<$dotfile") || die "Error opening dot file $dotfile\n";
while(<F>) {
$_ =~ m:(.+)\((\w{8})\)\s*$: || die "Bad line $_ in dot file $dotfile (line $.)\n";
$trans = $1;
$utt = $2;
$utt2trans{$utt} = $trans;
}
}
if(!defined $utt2trans{$uttid}) {
print STDERR "No transcript for utterance $uttid (current dot file is $dotfile)\n";
} else {
print "$uttid $utt2trans{$uttid}\n";
}
}
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# takes in a file list with lines like
# /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1
# and outputs an scp in kaldi format with lines like
# 4k0c030a /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1
# (the first field is the utterance-id, which is the same as the basename of the file.)
while(<>){
m:^\S+/(\w+)\.[wW][vV]1$: || die "Bad line $_";
$id = $1;
$id =~ tr/A-Z/a-z/; # Necessary because of weirdness on disk 13-16.1 (uppercase filenames)
print "$id $_";
}
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program takes as its standard input an .ndx file from the WSJ corpus that looks
# like this:
#;; File: tr_s_wv1.ndx, updated 04/26/94
#;;
#;; Index for WSJ0 SI-short Sennheiser training data
#;; Data is read WSJ sentences, Sennheiser mic.
#;; Contains 84 speakers X (~100 utts per speaker MIT/SRI and ~50 utts
#;; per speaker TI) = 7236 utts
#;;
#11_1_1:wsj0/si_tr_s/01i/01ic0201.wv1
#11_1_1:wsj0/si_tr_s/01i/01ic0202.wv1
#11_1_1:wsj0/si_tr_s/01i/01ic0203.wv1
#and as command-line arguments it takes the names of the WSJ disk locations, e.g.:
#/mnt/matylda2/data/WSJ0/11-1.1 /mnt/matylda2/data/WSJ0/11-10.1 ... etc.
# It outputs a list of absolute pathnames (it does this by replacing e.g. 11_1_1 with
# /mnt/matylda2/data/WSJ0/11-1.1.
# It also does a slight fix because one of the WSJ disks (WSJ1/13-16.1) was distributed with
# uppercase rather than lower case filenames.
foreach $fn (@ARGV) {
$fn =~ m:.+/([0-9\.\-]+)/?$: || die "Bad command-line argument $fn\n";
$disk_id=$1;
$disk_id =~ tr/-\./__/; # replace - and . with _ so 11-10.1 becomes 11_10_1
$fn =~ s:/$::; # Remove final slash, just in case it is present.
$disk2fn{$disk_id} = $fn;
}
while(<STDIN>){
if(m/^;/){ next; } # Comment. Ignore it.
else {
m/^([0-9_]+):\s*(\S+)$/ || die "Could not parse line $_";
$disk=$1;
if(!defined $disk2fn{$disk}) {
die "Disk id $disk not found";
}
$filename = $2; # as a subdirectory of the distributed disk.
if($disk eq "13_16_1" && `hostname` =~ m/fit.vutbr.cz/) {
# The disk 13-16.1 has been uppercased for some reason, on the
# BUT system. This is a fix specifically for that case.
$filename =~ tr/a-z/A-Z/; # This disk contains all uppercase filenames. Why?
}
print "$disk2fn{$disk}/$filename\n";
}
}
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This takes unnormalized transcripts on standard input, in the format
# 4k2c0308 Of course there isn\'t any guarantee the company will keep its hot hand [misc_noise]
# 4k2c030a [loud_breath] And new hardware such as the set of personal computers I\. B\. M\. introduced last week can lead to unexpected changes in the software business [door_slam]
# and outputs normalized transcripts.
# c.f. /mnt/matylda2/data/WSJ0/11-10.1/wsj0/transcrp/doc/dot_spec.doc
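# For example, with noise_word "<NOISE>" the first example line above is expected
# to come out as (a sketch of the behaviour, not output copied from the corpus):
#   4k2c0308 OF COURSE THERE ISN'T ANY GUARANTEE THE COMPANY WILL KEEP ITS HOT HAND <NOISE>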
@ARGV == 1 || die "usage: normalize_transcript.pl noise_word < transcript > transcript2";
$noise_word = shift @ARGV;
while(<STDIN>) {
$_ =~ m:^(\S+) (.+): || die "bad line $_";
$utt = $1;
$trans = $2;
print "$utt";
foreach $w (split (" ",$trans)) {
$w =~ tr:a-z:A-Z:; # Upcase everything to match the CMU dictionary.
$w =~ s:\\::g; # Remove backslashes. We don't need the quoting.
$w =~ s:^\%PERCENT$:PERCENT:; # Normalization for Nov'93 test transcripts.
$w =~ s:^\.POINT$:POINT:; # Normalization for Nov'93 test transcripts.
if($w =~ m:^\[\<\w+\]$: || # E.g. [<door_slam], this means a door slammed in the preceding word. Delete.
$w =~ m:^\[\w+\>\]$: || # E.g. [door_slam>], this means a door slammed in the next word. Delete.
$w =~ m:\[\w+/\]$: || # E.g. [phone_ring/], which indicates the start of this phenomenon.
$w =~ m:\[\/\w+]$: || # E.g. [/phone_ring], which indicates the end of this phenomenon.
$w eq "~" || # This is used to indicate truncation of an utterance. Not a word.
$w eq ".") { # "." is used to indicate a pause. Silence is optional anyway so not much
# point including this in the transcript.
next; # we won't print this word.
} elsif($w =~ m:\[\w+\]:) { # Other noises, e.g. [loud_breath].
print " $noise_word";
} elsif($w =~ m:^\<([\w\']+)\>$:) {
# e.g. replace <and> with and. (the <> means verbal deletion of a word).. but it's pronounced.
print " $1";
} elsif($w eq "--DASH") {
print " -DASH"; # This is a common issue; the CMU dictionary has it as -DASH.
# } elsif($w =~ m:(.+)\-DASH$:) { # E.g. INCORPORATED-DASH... seems the DASH gets combined with previous word
# print " $1 -DASH";
} else {
print " $w";
}
}
print "\n";
}
#!/usr/bin/env bash
# Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0.
# set -eu
if [ $# -le 3 ]; then
echo "Arguments should be a list of WSJ directories, see ../run.sh for example."
exit 1;
fi
dir=`pwd`/data/local/data
mkdir -p $dir
local=`pwd`/local
cd $dir
# Make directory of links to the WSJ disks such as 11-13.1. This relies on the command
# line arguments being absolute pathnames.
rm -r links/ 2>/dev/null
mkdir links/
ln -s $* links
# Do some basic checks that we have what we expected.
if [ ! -d links/11-13.1 -o ! -d links/13-34.1 -o ! -d links/11-2.1 ]; then
echo "wsj_data_prep.sh: Spot check of command line arguments failed"
echo "Command line arguments must be absolute pathnames to WSJ directories"
echo "with names like 11-13.1."
echo "Note: if you have old-style WSJ distribution,"
echo "local/cstr_wsj_data_prep.sh may work instead, see run.sh for example."
exit 1;
fi
# This version for SI-284
cat links/13-34.1/wsj1/doc/indices/si_tr_s.ndx \
links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \
$local/ndx2flist.pl $* | sort | \
grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si284.flist
nl=`cat train_si284.flist | wc -l`
[ "$nl" -eq 37416 ] || echo "Warning: expected 37416 lines in train_si284.flist, got $nl"
# Nov'92 (333 utts)
# These index files have a slightly different format;
# have to add .wv1, which is done by the awk command below
cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_20.ndx | \
$local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \
sort > test_eval92.flist
# Dev-set for Nov'93 (503 utts)
cat links/13-34.1/wsj1/doc/indices/h1_p0.ndx | \
$local/ndx2flist.pl $* | sort > test_dev93.flist
# Finding the transcript files:
for x in $*; do find -L $x -iname '*.dot'; done > dot_files.flist
# Convert the transcripts into our format (no normalization yet)
for x in train_si284 test_eval92 test_dev93; do
$local/flist2scp.pl $x.flist | sort > ${x}_sph.scp
cat ${x}_sph.scp | awk '{print $1}' | $local/find_transcripts.pl dot_files.flist > $x.trans1
done
# Do some basic normalization steps. At this point we don't remove OOVs--
# that will be done inside the training scripts, as we'd like to make the
# data-preparation stage independent of the specific lexicon used.
noiseword="<NOISE>";
for x in train_si284 test_eval92 test_dev93; do
cat $x.trans1 | $local/normalize_transcript.pl $noiseword | sort > $x.txt || exit 1;
done
# Create scp's with wav's. (the wv1 in the distribution is not really wav, it is sph.)
sph2pipe=/home/lsq/kaldi/tools/sph2pipe_v2.5/sph2pipe
for x in train_si284 test_eval92 test_dev93; do
awk '{printf("%s '$sph2pipe' -f wav %s \n", $1, $2);}' < ${x}_sph.scp > ${x}_wav.scp
done
echo "Data preparation succeeded"
#!/usr/bin/env bash
# Copyright 2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
# 2015 Guoguo Chen
# Apache 2.0
# This script takes data prepared in a corpus-dependent way
# in data/local/, and converts it into the "canonical" form,
# in various subdirectories of data/, e.g. data/lang, data/lang_test_ug,
# data/train_si284, data/train_si84, etc.
# Don't bother doing train_si84 separately (although we have the file lists
# in data/local/) because it's just the first 7138 utterances in train_si284.
# We'll create train_si84 after doing the feature extraction.
echo "$0 $@" # Print the command line for logging
. ./tools/parse_options.sh || exit 1;
. ./path.sh || exit 1;
echo "Preparing train and test data"
srcdir=data/local/data
for x in train_si284 test_eval92 test_dev93; do
mkdir -p data/$x
cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1;
cp $srcdir/$x.txt data/$x/text || exit 1;
done
echo "Succeeded in formatting data."
#!/usr/bin/env bash
set -eu
[ $# -ne 2 ] && echo "Usage: $0 <data-dir> <dump-dir>" && exit 1
data_dir=$1
dump_dir=$2
mkdir -p $dump_dir
num_utts=$(cat $data_dir/wav.scp | wc -l)
echo "Orginal utterances (.wav + .wv1): $num_utts"
# cat $data_dir/wav.scp | grep "sph2pipe" | \
# awk -v dir=$dump_dir '{printf("%s -f wav %s %s/%s.wav\n", $2, $5, dir, $1)}' | bash
awk '{print $1,$5}' $data_dir/wav.scp > $data_dir/raw_wav.scp
find $dump_dir -name "*.wav" | awk -F '/' '{printf("%s %s\n", $NF, $0)}' | \
sed 's:\.wav::' > $data_dir/wav.scp
num_utts=$(cat $data_dir/wav.scp | wc -l)
echo "Wave utterances (.wav): $num_utts"
echo "$0: Generate wav => $dump_dir done"
export WENET_DIR=$PWD/../../..
export BUILD_DIR=${WENET_DIR}/runtime/libtorch/build
export OPENFST_PREFIX_DIR=${BUILD_DIR}/../fc_base/openfst-subbuild/openfst-populate-prefix
export PATH=$PWD:${BUILD_DIR}/bin:${BUILD_DIR}/kaldi:${OPENFST_PREFIX_DIR}/bin:$PATH
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:$PYTHONPATH
#!/bin/bash
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
. ./path.sh || exit 1;
# Use this to control how many GPUs you use. It's single-GPU training if you specify
# just one GPU, otherwise it's multi-GPU training based on DDP in PyTorch.
export CUDA_VISIBLE_DEVICES="0"
# The NCCL_SOCKET_IFNAME variable specifies which IP interface to use for nccl
# communication. More details can be found in
# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
# export NCCL_SOCKET_IFNAME=ens4f1
export NCCL_DEBUG=INFO
stage=0 # start from 0 if you need to start from data preparation
stop_stage=4
# The number of nodes or machines used for multi-machine training.
# Default is 1 for a single machine/node.
# NFS is needed if you want to run multi-machine training.
num_nodes=1
# The rank of each node or machine, ranging from 0 to num_nodes - 1.
# The first node/machine sets node_rank 0, the second sets node_rank 1,
# the third sets node_rank 2, and so on. Default is 0.
node_rank=0
# data
WSJ0=/home/lsq/corpus/WSJ/wsj0
WSJ1=/home/lsq/corpus/WSJ/wsj1
nj=16
train_set=train_si284
valid_set=test_dev93
test_sets="test_dev93"
data_type=raw
num_utts_per_shard=1000  # only used when data_type=shard; 1000 is the default used in other WeNet recipes
# for lm training
other_text=data/local/other_text/text
# Optional train_config
# 1. conf/train_transformer.yaml: Standard transformer
# 2. conf/train_conformer.yaml: Standard conformer
# 3. conf/train_unified_conformer.yaml: Unified dynamic chunk causal conformer
# 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer
# 5. conf/train_conformer_no_pos.yaml: Conformer without relative positional encoding
# 6. conf/train_u2++_conformer.yaml: U2++ conformer
# 7. conf/train_u2++_transformer.yaml: U2++ transformer
train_config=conf/train_conformer.yaml
cmvn=true
dir=/home/lsq/exp_dir/exp_wenet/wsj/conformer_1202
dump_wav_dir=/home/lsq/corpus/wsj_wav
checkpoint=
# Using average_checkpoint will usually give a better result.
average_checkpoint=true
decode_checkpoint=$dir/final.pt
average_num=20
decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring"
. tools/parse_options.sh || exit 1;
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
local/wsj_data_prep.sh ${WSJ0}/??-{?,??}.? ${WSJ1}/??-{?,??}.?
local/wsj_format_data.sh
for x in ${valid_set} ${train_set}; do
{
./local/wsj_gen_wav.sh data/$x $dump_wav_dir/$x
}
done
echo "Prepare text from lng_modl dir: ${WSJ1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z -> ${other_text}"
mkdir -p "$(dirname ${other_text})"
# NOTE(kamo): Give an utterance id to each text line.
zcat ${WSJ1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z | \
grep -v "<" | tr "[:lower:]" "[:upper:]" | \
awk '{ printf("wsj1_lng_%07d %s\n",NR,$0) } ' > ${other_text}
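# ${other_text} then holds one upper-cased LM sentence per line with a synthetic
# utterance id, e.g. (content is illustrative): "wsj1_lng_0000001 THE COMPANY SAID ..."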
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# compute cmvn
tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \
--in_scp data/${train_set}/wav.scp \
--out_cmvn data/${train_set}/global_cmvn
fi
dict=data/dict/${train_set}_units.txt
nlsyms=data/nlsyms.txt
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# Make train dict
echo "Make a dictionary"
mkdir -p $(dirname $dict)
echo "<blank> 0" > ${dict} # 0 will be used for "blank" in CTC
echo "<unk> 1" >> ${dict} # <unk> must be 1
echo "make a non-linguistic symbol list"
cut -f 2- data/${train_set}/text | tr " " "\n" | sort | uniq | grep "<" > ${nlsyms}
cat ${nlsyms}
tools/text2token.py -s 1 -n 1 -l ${nlsyms} --space ▁ data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -v -e '^\s*$' | awk '{print $0 " " NR+1}' >> ${dict}
wc -l ${dict}
num_token=$(cat $dict | wc -l)
echo "<sos/eos> $num_token" >> $dict # <eos>
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare data, prepare required format"
for x in ${valid_set} ${train_set}; do
if [ $data_type == "shard" ]; then
tools/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \
--num_threads 16 data/$x/wav.scp data/$x/text \
$(realpath data/$x/shards) data/$x/data.list
else
tools/make_raw_list.py data/$x/wav.scp data/$x/text \
data/$x/data.list
fi
done
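# With data_type=raw, data.list is expected to contain one JSON object per line,
# roughly of the form {"key": <utt-id>, "wav": <wav-path>, "txt": <transcript>}.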
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
mkdir -p $dir
# You have to rm `INIT_FILE` manually when you resume or restart a
# multi-machine training.
INIT_FILE=$dir/ddp_init
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
# Use "nccl" if it works, otherwise use "gloo"
dist_backend="gloo"
world_size=`expr $num_gpus \* $num_nodes`
echo "total gpus is: $world_size"
cmvn_opts=
$cmvn && cp data/${train_set}/global_cmvn $dir
$cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"
# train.py rewrites $train_config to $dir/train.yaml with the model input
# and output dimensions, and $dir/train.yaml will be used for inference
# and export.
for ((i = 0; i < $num_gpus; ++i)); do
{
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
# Rank of each gpu/process, used to know whether it is
# the master or a worker.
rank=`expr $node_rank \* $num_gpus + $i`
python wenet/bin/train.py --gpu $gpu_id \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/data.list \
--cv_data data/$valid_set/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.init_method $init_method \
--ddp.world_size $world_size \
--ddp.rank $rank \
--ddp.dist_backend $dist_backend \
--num_workers 1 \
$cmvn_opts \
--pin_memory \
--non_lang_syms ${nlsyms}
} &
done
wait
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# Test model, please specify the model you want to test by --checkpoint
if [ ${average_checkpoint} == true ]; then
decode_checkpoint=$dir/avg_${average_num}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python wenet/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path $dir \
--num ${average_num} \
--val_best
fi
# Please specify decoding_chunk_size for unified streaming and
# non-streaming model. The default value is -1, which is full chunk
# for non-streaming inference.
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
for mode in ${decode_modes}; do
{
test_dir=$dir/test_${mode}
result_text=$test_dir/text
mkdir -p $(dirname $result_text)
python wenet/bin/recognize.py --gpu 3 \
--mode $mode \
--config $dir/train.yaml \
--data_type $data_type \
--test_data data/test_dev93/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--non_lang_syms $nlsyms \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file $result_text \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
python tools/compute-wer.py --char=1 --v=1 \
data/test_dev93/text $test_dir/text > $test_dir/wer
} &
done
wait
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# compute wer
for mode in ${decode_modes}; do
for test_set in $test_sets; do
test_dir=$dir/test_${mode}
sed 's:▁: :g' $test_dir/text > $test_dir/text.norm
python tools/compute-wer.py --char=1 --v=1 \
data/$test_set/text $test_dir/text.norm > $test_dir/wer
done
done
fi
../../../tools/
../../../wenet/
#!/usr/bin/env bash
PID_FILE=$2
echo $$ > $PID_FILE
CUR_DIR="$( cd "$(dirname "$0")" ; pwd )"
cd $CUR_DIR/examples/aishell/s0
# Dataset prepare
start=$(date +%s)
start_str=`date '+%Y-%m-%d %H:%M:%S' -d "@$start"`
echo "$start_str Begin dataset prepare"
bash run.sh --stage -1 --stop_stage -1
bash run.sh --stage 0 --stop_stage 0
bash run.sh --stage 1 --stop_stage 1
bash run.sh --stage 2 --stop_stage 2
bash run.sh --stage 3 --stop_stage 3
end=$(date +%s)
end_str=`date '+%Y-%m-%d %H:%M:%S' -d "@$end"`
echo "$end_str Finish dataset prepare"
data_prepare_time=$(($end-$start))
echo "Dataset prepare time: ${data_prepare_time}s"
# running training
bash run_train.sh $1 $2
# Runtime on WeNet
This is the runtime of WeNet.
We are going to support the following platforms:
1. Various deep learning inference engines, such as LibTorch, ONNX, OpenVINO, TVM, and so on.
2. Various OSs, such as Android, iOS, Harmony, and so on.
3. Various AI chips, such as GPUs, Horizon BPU, and so on.
4. Various hardware platforms, such as Raspberry Pi.
5. Various language bindings, such as Python and Go.
Feel free to volunteer if you are interested in trying out some of these items (they do not have to be on the list).
## Introduction
Here is a brief summary of all platforms and OSs. Please note the corresponding supported `OS` and `inference engine`.
| runtime | OS | inference engine | Description |
|-----------------|---------------------|----------------------|--------------------------------------------------------------------------------------------------|
| core | / | / | common core code of all runtimes |
| android | android | libtorch | android demo, [English demo](https://www.youtube.com/shorts/viEnvmZf03s ), [Chinese demo](TODO) |
| binding/python | linux, windows, mac | libtorch | python binding of wenet; Mac M1/M2 is not supported yet |
| gpu | linux | onnxruntime/tensorrt | GPU inference with NV's Triton and TensorRT |
| horizonbpu | linux | bpu runtime | Horizon BPU runtime |
| ios | ios | libtorch | ios demo, [link](TODO) |
| kunlun | linux | xpu runtime | Kunlun XPU runtime |
| libtorch | linux, windows, mac | libtorch | c++ build with libtorch |
| onnxruntime | linux, windows, mac | onnxruntime | c++ build with onnxruntime |
| raspberrypi | linux | onnxruntime | c++ build on raspberrypi with onnxruntime |
| web | linux, windows, mac | libtorch | web demo with gradio and python binding, [link]() |
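
As a quick start for the `libtorch` runtime on Linux/Windows/Mac, a typical out-of-source CMake build looks roughly like the sketch below; the exact options are documented in `runtime/libtorch`, so treat this as an outline rather than the authoritative steps.

``` sh
cd runtime/libtorch
cmake -B build -DCMAKE_BUILD_TYPE=Release  # first run also fetches dependencies (e.g. OpenFST) into fc_base
cmake --build build                        # produces build/bin/decoder_main and related tools
```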
*.iml
.gradle
/local.properties
/.idea/caches
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
/.idea/navEditor.xml
/.idea/assetWizardSettings.xml
.DS_Store
/build
/captures
.externalNativeBuild
.cxx
local.properties
# WeNet On-device ASR Android Demo
This Android demo shows how to run on-device streaming ASR with WeNet. You can download our prebuilt APK or build your own APK from source code.
## Prebuilt APK
* [Chinese ASR Demo APK, with model trained on AIShell data](http://mobvoi-speech-public.ufile.ucloud.cn/public/wenet/aishell/20210202_app.apk)
* [English ASR Demo APK, with model trained on GigaSpeech data](http://mobvoi-speech-public.ufile.ucloud.cn/public/wenet/gigaspeech/20210823_app.apk)
## Build your APK from source code
### 1) Build model
You can use our pretrained model (click the following link to download):
[中文(WenetSpeech)](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/wenetspeech/wenetspeech_u2pp_conformer_libtorch_quant.tar.gz)
| [English(GigaSpeech)](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/gigaspeech/gigaspeech_u2pp_conformer_libtorch_quant.tar.gz)
Or you can train your own model using WeNet training pipeline on your data.
### 2) Build APK
When your model is ready, put `final.zip` and `units.txt` into the Android assets folder (`app/src/main/assets`),
then just build and run the APK. Here is a GIF demo, which shows how our on-device streaming E2E ASR runs with low latency.
Please note that Wi-Fi and mobile data were disabled in the demo, so there is no network connection ^\_^.
![Runtime android demo](../../../../docs/images/runtime_android.gif)
## Compute the RTF
Step 1, connect your Android phone and use the `adb push` command to push your model, wav scp, and wave files to the sdcard.
Step 2, build the binary and the APK with Android Studio directly, or with the commands as follows:
``` sh
cd runtime/android
./gradlew build
```
Step 3, push your binary and the dynamic library to `/data/local/tmp` as follows:
``` sh
adb push app/.cxx/cmake/release/arm64-v8a/decoder_main /data/local/tmp
adb push app/build/pytorch_android-1.10.0.aar/jni/arm64-v8a/* /data/local/tmp
```
Step 4, change to the directory `/data/local/tmp` of your phone, and export the library path by:
``` sh
adb shell
cd /data/local/tmp
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:.
```
Step 5, execute the same command as the [x86 demo](../../../libtorch) to run the binary to decode and compute the RTF.
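For reference, the on-phone run then looks roughly like the sketch below. The flag names are assumptions based on the x86 libtorch demo and may differ between WeNet versions, so check the `decoder_main` usage on your build.

``` sh
# inside `adb shell`, in /data/local/tmp (model/wav paths are placeholders)
./decoder_main \
    --chunk_size 16 \
    --wav_scp wav.scp \
    --model_path final.zip \
    --unit_path units.txt 2>&1 | tee rtf.log
# the end of the log reports the total decoding time and the RTF
```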