Commit 12c90639 authored by change

init

parent 417b607b
import re
import argparse

from tqdm import tqdm
from num2words import num2words


def writefile(filename, lines):
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def main():
    """Normalize English text to character-level output: letters separated by spaces,
    words separated by '|', digits spelled out with num2words."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", required=True, type=str)
    parser.add_argument("--output", "-o", required=True, type=str)
    args = parser.parse_args()

    outlines = []
    with open(args.input, 'r') as f:
        inputs = f.readlines()

    for line in tqdm(inputs):
        line = line.strip().upper()
        # keep only Latin letters, digits and apostrophes
        line = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039\'])", " ", line)
        items = []
        for item in line.split():
            if item.isdigit():
                try:
                    item = num2words(item)
                except Exception as e:
                    print(line)
                    raise e
            items.append(item)
        line = " ".join(items)
        line = line.replace("-", " ")
        line = line.upper()
        line = line.replace("' S", "'S")
        line = line.replace(" ", "|")
        line = " ".join(line) + " |"
        outlines.append(line + '\n')

    writefile(args.output, outlines)


if __name__ == "__main__":
    main()
import re
import argparse

from tqdm import tqdm
from num2words import num2words


def writefile(filename, lines):
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def main():
    """Normalize Spanish text to character-level output: letters separated by spaces,
    words separated by '|', digits spelled out in Spanish with num2words."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", required=True, type=str)
    parser.add_argument("--output", "-o", required=True, type=str)
    args = parser.parse_args()

    outlines = []
    with open(args.input, 'r') as f:
        inputs = f.readlines()

    punc = '~`!#$%^&*()_+-=|\';":/.,?><~.'
    for line in tqdm(inputs):
        line = line.strip()
        # keep only Latin letters, digits, Ñ/ñ and apostrophes
        line = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039\u00d1\u00f1\'])", " ", line)
        items = []
        for item in line.split():
            if item.isdigit():
                try:
                    item = num2words(item, lang='es')
                except Exception as e:
                    print(line)
                    raise e
            items.append(item)
        line = " ".join(items)
        line = re.sub(r"[%s]+" % re.escape(punc), "", line)
        line = line.replace("-", " ")
        line = line.lower()
        line = line.replace("' s", "'s")
        line = line.replace(" ", "|")
        line = " ".join(line) + " |"
        outlines.append(line + '\n')

    writefile(args.output, outlines)


if __name__ == "__main__":
    main()
#####################################
# Hubert ED model #
#####################################
[ $# -lt 1 ] && echo "Usage: $0 <init-model> <gen-set> <src> <tgt> <max_tokens> <world_size> <rank>" && exit 0
#source /mnt/default/v-ziqzhang/.bashrc_sing
model_path=$1
gen_set=$2
tgt=$3
src="ltr"
max_tokens=$4
word_size=$5
rank=$6
outdir=$7
[ -z $tgt ] && tgt="kmu"
[ -z $gen_set ] && gen_set="dev_clean"
[ -z $word_size ] && word_size=1
[ -z $rank ] && rank=0
[ -z $max_tokens ] && max_tokens=16000
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
DATA_DIR=/home/v-kunwei/
[ $gen_set == "test" ] && DATA_DIR=/mnt/output/users/v-kunwei/code/fairseq_mlstku
[ -z $outdir ] && outdir=$DATA_DIR
results_path=$outdir/pseudo_${gen_set}_${rank}
[ ! -d $results_path ] && mkdir -p $results_path
for subset in $gen_set; do
python $FAIRSEQ_ROOT/fairseq_cli/generate_mt_label.py $DATA_DIR \
--path ${model_path} \
--task "translation_from_jst" \
--max-target-positions 18000 \
--gen-subset $subset \
-t $tgt -s "ltr" \
--dataset-impl "raw" \
--max-tokens ${max_tokens} \
--beam 2 \
--max-len-a 3 --max-len-b 100 \
--results-path $results_path \
  --distributed-world-size $word_size --distributed-rank $rank
echo "$model_path" > $results_path/model.record
sleep 1s
done | tee $results_path/decode.log
sleep 2s
lmweight=0
num_gpus=8
python examples/speech_recognition/new/infer.py --config-dir /mnt/output/users/v-kunwei/code/fairseq/examples/speech_recognition/new/conf \
--config-name infer task=audio_finetuning task.data=/home/v-kunwei common.user_dir=/mnt/output/users/v-kunwei/code/fairseq/examples/data2vec \
task.labels=ltr decoding.type=viterbi \
decoding.lexicon=models/es_eval/espeak_dict.txt \
decoding.unique_wer_file=True \
dataset.gen_subset=test \
common_eval.path=/mnt/output/users/v-kunwei/code/fairseq/models/es_eval/espeak_26lang_m10.pt decoding.beam=1500 distributed_training.distributed_world_size=${num_gpus} \
decoding.results_path=/home/v-kunwei
#sclite -h "/home/v-kunwei/hypo.units" -r "/home/v-kunwei/ref.units" -i rm -o all stdout > "./result.txt"
#$subset=test
python examples/speech_recognition/infer.py /home/v-kunwei --task audio_finetuning \
--nbest 1 --path /mnt/output/users/v-kunwei/code/fairseq/models/es_eval/espeak_26lang_m10.pt --gen-subset test --results-path /home/v-kunwei --criterion ctc --labels ltr --max-tokens 4000000 \
--post-process letter
# ####################################
# Hubert ED model #
# ####################################
#source /mnt/default/v-ziqzhang/.bashrc_sing
[ $# -lt 4 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path> <cpt>" && exit 0
world_size=$1
update_freq=$2
w2v_path=$3
cpt=$4
Mount=$5
[ -z $world_size ] && world_size=8
[ -z $update_freq ] && update_freq=3
[ -z $w2v_path ] && echo "you must specify a w2v_path!" && exit 1
[ -z $cpt ] && cpt=030.pt
[ -z $Mount ] && Mount=/mnt/default
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
CONFIG_DIR=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config
DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/fin_enes100"
exp_name=${w2v_path%/*}
exp_name=${exp_name##*/}
MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/finetune/tune_ST_from_eneshu"
exp_name="tune_enes_lr5e-5_from_$cpt"
MODEL_DIR=$MODEL_DIR/$exp_name
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
max_tokens=490000
python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
--config-dir $CONFIG_DIR/finetune_asr \
--config-name base_100h \
\
+task.store_labels=true \
task.labels='["spm"]' \
task.data=$DATA_DIR \
task.label_dir=$DATA_DIR \
task.add_decoder=true \
+task.max_keep_size=490000 \
\
+model.reuse_text_emb=true \
model._name="stbert_st" \
model.w2v_path=${w2v_path} \
model.add_decoder=true \
\
criterion._name="label_smoothed_cross_entropy" \
+criterion.label_smoothing=0.2 \
+criterion.report_accuracy=true \
\
lr_scheduler._name="polynomial_decay" \
+lr_scheduler.warmup_updates=20000 \
\
optimization.lr=[0.0003] \
optimization.max_update=100000 \
checkpoint.best_checkpoint_metric="accuracy" \
checkpoint.maximize_best_checkpoint_metric=true \
checkpoint.save_interval=1 \
\
dataset.train_subset="train" \
dataset.valid_subset="valid" \
dataset.max_tokens=$max_tokens \
optimization.update_freq=[${update_freq}] \
\
distributed_training.distributed_world_size=${world_size} \
distributed_training.distributed_port=-1 \
\
common.log_interval=100 \
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=${exp_name}
sleep 20s
# \
# lr_scheduler._name="polynomial_decay" \
# +lr_scheduler.warmup_updates=5000 \
# /mnt/default/v-ziqzhang/data/stbert-ed/exp/ST_enes/sc2t_base_ende_32gpu_1accum/checkpoint_204_400000.pt
# ####################################
# Hubert ED model #
# ####################################
#source /mnt/default/v-ziqzhang/.bashrc_sing
[ $# -lt 4 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path> <cpt>" && exit 0
world_size=$1
update_freq=$2
w2v_path=$3
cpt=$4
Mount=$5
[ -z $world_size ] && world_size=1
[ -z $update_freq ] && update_freq=1
[ -z $w2v_path ] && echo "you must specify a w2v_path!" && exit 1
[ -z $cpt ] && cpt=030.pt
[ -z $Mount ] && Mount=/mnt/default
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
CONFIG_DIR=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config
DATA_DIR="/mnt/output/users/v-kunwei/data/s2s_data/fin_esen"
exp_name=${w2v_path%/*}
exp_name=${exp_name##*/}
MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/finetune/tune_ST_from_esen"
exp_name="tune_esen_lr5e-5_from_$cpt"
MODEL_DIR=$MODEL_DIR/$exp_name
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
max_tokens=4900
python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
--config-dir $CONFIG_DIR/finetune_asr \
--config-name base_100h \
\
+task.store_labels=true \
task.labels='["spm"]' \
task.data=$DATA_DIR \
task.label_dir=$DATA_DIR \
task.add_decoder=true \
+task.max_keep_size=4900 \
\
+model.reuse_text_emb=true \
model._name="stbert_st" \
model.w2v_path=${w2v_path} \
model.add_decoder=true \
\
criterion._name="label_smoothed_cross_entropy" \
+criterion.label_smoothing=0.2 \
+criterion.report_accuracy=true \
\
lr_scheduler._name="polynomial_decay" \
+lr_scheduler.warmup_updates=20000 \
\
optimization.lr=[0.0002] \
optimization.max_update=100000 \
checkpoint.best_checkpoint_metric="accuracy" \
checkpoint.maximize_best_checkpoint_metric=true \
checkpoint.save_interval=1 \
\
dataset.train_subset="train" \
dataset.valid_subset="valid" \
dataset.max_tokens=$max_tokens \
optimization.update_freq=[${update_freq}] \
\
distributed_training.distributed_world_size=${world_size} \
distributed_training.distributed_port=-1 \
\
common.log_interval=100 \
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=${exp_name}
sleep 20s
# \
# lr_scheduler._name="polynomial_decay" \
# +lr_scheduler.warmup_updates=5000 \
# /mnt/default/v-ziqzhang/data/stbert-ed/exp/ST_enes/sc2t_base_ende_32gpu_1accum/checkpoint_204_400000.pt
#####################################
# Hubert base model #
#####################################
[ $# -lt 1 ] && echo "Usage: $0 <init-model> <gen-set>" && exit 0
model_path=$1
src_dir=${model_path%/*}
cpt=${model_path##*/}
cpt=${cpt%.*}
#beam_size=$2
gen_set=$2
#lang=$4
[ -z $gen_set ] && gen_set="test_et"
[ -z $beam_size ] && beam_size=2
[ -z $lang ] && lang="fr"
#DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/fin_enes
DATA_DIR=/home/v-kunwei
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
for subset in $gen_set; do
results_path=$src_dir/decode_${cpt}_beam${beam_size}/${subset}
[ ! -d $results_path ] && mkdir -p $results_path
python $FAIRSEQ_ROOT/fairseq_cli/generate.py \
$DATA_DIR --label-dir ${DATA_DIR} \
--labels '["spm"]' --gen-subset ${subset} \
--max-tokens 9000000 --task hubert_pretraining \
--add-decoder --fine-tuning --random-crop \
--path ${model_path} --results-path /home/v-kunwei --scoring sacrebleu \
--max-len-a 0 --max-len-b 900 \
--beam 10 --single-target
tail -n 1 /home/v-kunwei/generate-*.txt
sleep 1s
done
#####################################
# Hubert mt model #
#####################################
[ $# -gt 3 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path>" && exit 0
world_size=$1
update_freq=$2
w2v_path=$3
Mount=""
[ -z $world_size ] && world_size=8
[ -z $update_freq ] && update_freq=1
[ -z $w2v_path ] && w2v_path="/mnt/output/users/v-kunwei/data/s2s_data/model_wo_emb_32_1004.pt"
langs="ltr,kmu"
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
CONFIG_ROOT=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config/translation
DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/en_asr_data/
### set save-dir
MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/exp/text2unicode_en"
exp_name="base_pt400k_releaseiter2_${world_size}gpu_${update_freq}accum_lr1e-4_alll"
MODEL_DIR=$MODEL_DIR/$exp_name
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
--config-dir $CONFIG_ROOT \
--config-name text2code \
+task.data=$DATA_DIR \
dataset.dataset_impl="raw" \
+task.source_lang="ltr" +task.target_lang="kmu" \
+task.normalize=false \
\
+criterion.label_smoothing=0.1 \
+criterion.report_accuracy=true \
optimizer.weight_decay=0.00001 \
+lr_scheduler.lr="[0.0001]" \
optimization.max_update=500000 \
\
+model.dropout=0.1 \
+model.attention_dropout=0.1 \
model.activation_dropout=0.1 \
model.decoder_layerdrop=0 \
model.layerdrop=0 \
model.w2v_path=$w2v_path \
+model.text_transformer_encoder_layers=6 \
\
dataset.train_subset="en_train" \
dataset.valid_subset="en_dev" \
optimization.update_freq=[${update_freq}] \
optimization.clip_norm=5 \
\
common.seed=222 \
common.log_interval=100 \
common.log_format="json" \
\
distributed_training.distributed_world_size=${world_size} \
distributed_training.nprocs_per_node=8 \
distributed_training.ddp_backend="legacy_ddp" \
\
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=${exp_name}
sleep 10s
# sleep infinity
#####################################
# Hubert mt model #
#####################################
[ $# -gt 3 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path>" && exit 0
world_size=$1
update_freq=$2
w2v_path=$3
Mount=""
[ -z $world_size ] && world_size=8
[ -z $update_freq ] && update_freq=1
[ -z $w2v_path ] && w2v_path="/mnt/output/users/v-kunwei/data/s2s_data/model_es_emb_90_1004.pt"
langs="ltr,kmu"
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
CONFIG_ROOT=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config/translation
DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_no_data/
### set save-dir
MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/exp/text2unicode_es"
exp_name="base_pt400k_releaseiter2_${world_size}gpu_${update_freq}accum_lr1e-4_no"
MODEL_DIR=$MODEL_DIR/$exp_name
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
--config-dir $CONFIG_ROOT \
--config-name text2code \
+task.data=$DATA_DIR \
dataset.dataset_impl="raw" \
+task.source_lang="ltr" +task.target_lang="kmu" \
+task.normalize=false \
\
+criterion.label_smoothing=0.1 \
+criterion.report_accuracy=true \
optimizer.weight_decay=0.00001 \
+lr_scheduler.lr="[0.0001]" \
optimization.max_update=500000 \
\
+model.dropout=0.1 \
+model.attention_dropout=0.1 \
model.activation_dropout=0.1 \
model.decoder_layerdrop=0 \
model.layerdrop=0 \
model.w2v_path=$w2v_path \
+model.text_transformer_encoder_layers=6 \
\
dataset.train_subset="es_train" \
dataset.valid_subset="es_dev" \
optimization.update_freq=[${update_freq}] \
optimization.clip_norm=5 \
\
common.seed=222 \
common.log_interval=100 \
common.log_format="json" \
\
distributed_training.distributed_world_size=${world_size} \
distributed_training.nprocs_per_node=8 \
distributed_training.ddp_backend="legacy_ddp" \
\
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=${exp_name}
sleep 10s
# sleep infinity
#####################################
# Hubert mt model #
#####################################
[ $# -gt 3 ] && echo "Usage: $0 <world_size> <update_freq> <w2v_path>" && exit 0
world_size=$1
update_freq=$2
w2v_path=$3
Mount=""
[ -z $world_size ] && world_size=8
[ -z $update_freq ] && update_freq=1
[ -z $w2v_path ] && w2v_path="/mnt/output/users/v-kunwei/data/s2s_data/model_es_emb_81_1004.pt"
langs="ltr,kmu"
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
CONFIG_ROOT=/mnt/output/users/v-kunwei/code/stpretrain_scripts/config/translation
DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_asrl_data/
### set save-dir
MODEL_DIR="/mnt/output/users/v-kunwei/data/s2s_data/exp/text2unicode_es"
exp_name="base_pt400k_releaseiter2_${world_size}gpu_${update_freq}accum_lr1e-4_ll"
MODEL_DIR=$MODEL_DIR/$exp_name
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR
python $FAIRSEQ_ROOT/fairseq_cli/hydra_train.py \
--config-dir $CONFIG_ROOT \
--config-name text2code \
+task.data=$DATA_DIR \
dataset.dataset_impl="raw" \
+task.source_lang="ltr" +task.target_lang="kmu" \
+task.normalize=false \
\
+criterion.label_smoothing=0.1 \
+criterion.report_accuracy=true \
optimizer.weight_decay=0.00001 \
+lr_scheduler.lr="[0.0001]" \
optimization.max_update=500000 \
\
+model.dropout=0.1 \
+model.attention_dropout=0.1 \
model.activation_dropout=0.1 \
model.decoder_layerdrop=0 \
model.layerdrop=0 \
model.w2v_path=$w2v_path \
+model.text_transformer_encoder_layers=6 \
\
dataset.train_subset="es_train" \
dataset.valid_subset="es_dev" \
optimization.update_freq=[${update_freq}] \
optimization.clip_norm=5 \
\
common.seed=222 \
common.log_interval=100 \
common.log_format="json" \
\
distributed_training.distributed_world_size=${world_size} \
distributed_training.nprocs_per_node=8 \
distributed_training.ddp_backend="legacy_ddp" \
\
common.tensorboard_logdir=$MODEL_DIR \
checkpoint.save_dir=$MODEL_DIR \
hydra.run.dir=$MODEL_DIR \
hydra.job.name=${exp_name}
sleep 10s
# sleep infinity
#####################################
# Hubert ED model #
#####################################
[ $# -lt 1 ] && echo "Usage: $0 <init-model> <gen-set> <src> <tgt> <max_tokens> <world_size> <rank>" && exit 0
#source /mnt/default/v-ziqzhang/.bashrc_sing
model_path=$1
gen_set=$2
tgt=$3
src="ltr"
max_tokens=$4
word_size=$5
rank=$6
outdir=$7
[ -z $tgt ] && tgt="kmu"
[ -z $gen_set ] && gen_set="dev_clean"
[ -z $word_size ] && word_size=1
[ -z $rank ] && rank=0
[ -z $max_tokens ] && max_tokens=2000
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlst
DATA_DIR=${gen_set%/*}
gen_set=${gen_set##*/}
[ $gen_set == "test" ] && DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/en_asr_data
[ -z $outdir ] && outdir=$DATA_DIR
results_path=$outdir/pseudo_${gen_set}_${rank}
[ ! -d $results_path ] && mkdir -p $results_path
for subset in $gen_set; do
python $FAIRSEQ_ROOT/fairseq_cli/generate_mt_label.py $DATA_DIR \
--path ${model_path} \
--task "translation_from_jst" \
--max-target-positions 3000 \
--gen-subset $subset \
-t $tgt -s "ltr" \
--max-tokens ${max_tokens} \
--dataset-impl "raw" \
--max-len-a 2 --max-len-b 100 \
--results-path $results_path \
--skip-invalid-size-inputs-valid-test \
  --distributed-world-size $word_size --distributed-rank $rank
echo "$model_path" > $results_path/model.record
sleep 1s
done | tee $results_path/decode.log
sleep 2s
#####################################
# Hubert ED model #
#####################################
[ $# -lt 1 ] && echo "Usage: $0 <init-model> <gen-set> <src> <tgt> <max_tokens> <world_size> <rank>" && exit 0
#source /mnt/default/v-ziqzhang/.bashrc_sing
model_path=$1
gen_set=$2
tgt=$3
src="ltr"
max_tokens=$4
word_size=$5
rank=$6
outdir=$7
[ -z $tgt ] && tgt="kmu"
[ -z $gen_set ] && gen_set="dev_clean"
[ -z $word_size ] && word_size=1
[ -z $rank ] && rank=0
[ -z $max_tokens ] && max_tokens=2000
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlstku
DATA_DIR=${gen_set%/*}
gen_set=${gen_set##*/}
[ $gen_set == "test" ] && DATA_DIR=/mnt/output/users/v-kunwei/code/fairseq_mlstku
[ -z $outdir ] && outdir=$DATA_DIR
results_path=$outdir/pseudo_${gen_set}_${rank}
[ ! -d $results_path ] && mkdir -p $results_path
for subset in $gen_set; do
python $FAIRSEQ_ROOT/fairseq_cli/generate_mt_label.py $DATA_DIR \
--path ${model_path} \
--task "translation_from_jst" \
--max-target-positions 3000 \
--gen-subset $subset \
-t $tgt -s "ltr" \
--dataset-impl "raw" \
--max-tokens ${max_tokens} \
--beam 2 \
--max-len-a 2 --max-len-b 100 \
--results-path $results_path \
--skip-invalid-size-inputs-valid-test \
  --distributed-world-size $word_size --distributed-rank $rank
echo "$model_path" > $results_path/model.record
sleep 1s
done | tee $results_path/decode.log
sleep 2s
#####################################
# Hubert ED model #
#####################################
[ $# -lt 1 ] && echo "Usage: $0 <init-model> <gen-set>" && exit 0
model_path=$1
src_dir=${model_path%/*}
cpt=${model_path##*/}
cpt=${cpt%.*}
gen_set=$2
tgt=$3
outdir=$4
src="ltr"
[ -z $tgt ] && tgt="kmu"
[ -z $gen_set ] && gen_set="es_dev"
[ -z $outdir ] && outdir=$src_dir/decode_${cpt}
DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_asr_data/
# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_joint_splitenc_400k/ltr-$tgt
# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_400k/ltr-$tgt
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlst
langs="ltr,$tgt"
for subset in $gen_set; do
results_path=$outdir/${subset}
[ ! -d $results_path ] && mkdir -p $results_path
python $FAIRSEQ_ROOT/fairseq_cli/generate.py $DATA_DIR \
--path ${model_path} \
--task "translation_from_jst" \
--max-target-positions 3000 \
--gen-subset $subset \
-t $tgt -s "ltr" --dataset-impl "raw" \
--batch-size 16 \
--max-len-a 2 --max-len-b 400 \
--results-path $results_path \
--scoring sacrebleu $extra
echo $results_path
tail -n 1 $results_path/generate-*.txt
sleep 1s
done
# --distributed-world-size 1000 --distributed-rank 0 \
sleep 2s
# cat generate-newstest2020_enja.txt | grep "^D-" | cut -d'-' -f 2- | sort -n -k1 | cut -f3 > decode-newstest2020_enja.txt
# sacrebleu -t wmt20 -l en-ja -i decode-newstest2020_enja.txt --tokenize char
#####################################
# Hubert ED model #
#####################################
[ $# -lt 1 ] && echo "Usage: $0 <init-model> <gen-set>" && exit 0
model_path=$1
src_dir=${model_path%/*}
cpt=${model_path##*/}
cpt=${cpt%.*}
gen_set=$2
tgt=$3
outdir=$4
src="ltr"
[ -z $tgt ] && tgt="kmu"
[ -z $gen_set ] && gen_set="en_dev"
[ -z $outdir ] && outdir=$src_dir/decode_${cpt}
# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/hubert_release_iter2_layer9_kmeans/ltr-$tgt
# DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_joint_splitenc_400k/ltr-$tgt
#DATA_DIR=/mnt/default/v-ziqzhang/data/stbert/data/librispeech/speech2c_400k/ltr-$tgt
DATA_DIR=/mnt/output/users/v-kunwei/data/s2s_data/es_asr_data/
FAIRSEQ_ROOT=/mnt/output/users/v-kunwei/code/fairseq_mlst
langs="ltr,$tgt"
for subset in $gen_set; do
results_path=$outdir/${subset}
[ ! -d $results_path ] && mkdir -p $results_path
python $FAIRSEQ_ROOT/fairseq_cli/generate.py $DATA_DIR \
--path ${model_path} \
--task "translation_from_jst" \
--max-target-positions 3000 \
--gen-subset $subset \
-t $tgt -s "ltr" --dataset-impl "raw" \
--batch-size 16 \
--max-len-a 2 --max-len-b 400 \
--results-path $results_path \
--scoring wer
echo $results_path
tail -n 1 $results_path/generate-*.txt
sleep 1s
done
# --distributed-world-size 1000 --distributed-rank 0 \
sleep 2s
# cat generate-newstest2020_enja.txt | grep "^D-" | cut -d'-' -f 2- | sort -n -k1 | cut -f3 > decode-newstest2020_enja.txt
# sacrebleu -t wmt20 -l en-ja -i decode-newstest2020_enja.txt --tokenize char
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
import os
import sys
from typing import Dict, List, Optional, Tuple
from pathlib import Path
import numpy as np
from argparse import Namespace
from collections import OrderedDict
import torch
from dataclasses import dataclass, field
from fairseq.data import (
Dictionary,
encoders,
data_utils,
StripTokenDataset,
PrependTokenDataset,
AppendTokenDataset,
DenoisingDataset,
ConcatDataset,
FairseqDataset,
iterators,
ResamplingDataset,
MaskTokensDataset,
LanguagePairDataset,
)
from fairseq.data.audio.speech_to_text_joint_dataset import S2TJointDataConfig
from fairseq.data.shorten_dataset import maybe_shorten_dataset
# from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from fairseq.dataclass.constants import ChoiceEnum
from omegaconf import MISSING
from speechut.data.multimodal_corpus_dataset import MultiCorpusDataset
from speechut.data.load_langpair_dataset import load_langpair_dataset
from speechut.data.language_trible_dataset import LanguageTripleDataset, load_langtriple_dataset
from speechut.data.hubert_dataset import HubertDataset
logger = logging.getLogger(__name__)
TOKENIZER_CHOICES = ChoiceEnum(["sentencepiece", "hubert_letters", "none"])
def _lang_token(lang: str):
return "<lang:{}>".format(lang)
def _lang_token_index(dic: Dictionary, lang: str):
"""Return language token index."""
idx = dic.index(_lang_token(lang))
assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang)
return idx
class LabelEncoder(object):
def __init__(self, dictionary: Dictionary) -> None:
self.dictionary = dictionary
def __call__(self, label: str) -> List[str]:
return self.dictionary.encode_line(
label, append_eos=False, add_if_not_exist=False,
)
### wrap the original get_whole_word_mask, which needs a bpe_tokenizer;
### here we just assume words are split by "|" or "<SIL>"
def get_whole_word_mask(args, dictionary):
def is_beginning_of_word(i):
if i < dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = dictionary[i]
if tok.startswith("madeupword"):
return True
elif tok in ["<unk>", "<s>", "</s>", "<pad>", "|", "<eps>"]:
return True
else:
return False
mask_whole_words = torch.ByteTensor(
list(map(is_beginning_of_word, range(len(dictionary))))
)
return mask_whole_words
def get_repeative_start(tokens):
"""
tokens: torch.Tensor with repeative tokens
"""
length = len(tokens)
rep_start_id = tokens[:-1] != tokens[1:]
return torch.cat([torch.tensor([True]), rep_start_id])
@dataclass
class TextPretrainingConfig(FairseqDataclass):
### added for joint pretraining
text_data: Optional[str] = field(
default=None,
metadata={
"help": "if set, path to text data directory",
},
)
seed: Optional[int] = field(
default=1,
metadata={
"help": "for ordered_indices in MulticorpusDataset",
},
)
tokens_per_sample: Optional[int] = field(
default=512,
metadata={
"help": "max number of total tokens over all segments per sample for dataset",
},
)
tokens_per_sample_tgt: Optional[int] = field(
default=512,
metadata={
"help": "max number of total tokens over all segments per target sample for dataset",
},
)
sample_break_mode: Optional[str] = field(
default="eos",
metadata={
"help": "mode for breaking sentence",
},
)
mask: Optional[float] = field(
default=0.3,
metadata={
"help": "fraction of words/subwords that will be masked",
},
)
leave_unmasked_prob: float = field(
default=0.1,
metadata={"help": "probability that a masked token is unmasked"},
)
mask_random: Optional[float] = field(
default=0.1,
metadata={
"help": "instead of using [MASK], use random token this often",
},
)
freq_weighted_replacement: bool = field(
default=False,
metadata={"help": "sample random replacement words based on word frequencies"},
)
mask_whole_words: bool = field(
default=True,
metadata={"help": "mask whole words; you may also want to set --bpe"},
)
mask_repeative_tokens: bool = field(
default=True,
metadata={"help": "mask repeative_tokens; if mask_whole_words=False"},
)
mask_multiple_length: int = field(
default=1,
metadata={"help": "repeat the mask indices multiple times"},
)
mask_stdev: float = field(
default=0.0,
metadata={"help": "stdev of the mask length"},
)
shorten_method: Optional[str] = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed tokens_per_sample",
"choices": "none/truncate/random_crop"
},
)
shorten_data_split_list: Optional[str] = field(
default="",
metadata={
"help": "comma_separated list of dataset splits to apply shortening to, e.g., train,valid (default: all dataset splits)",
},
)
### the hyper-parameters below are used in BART-style denoising
insert: Optional[float] = field(
default=0.0,
metadata={
"help": "insert this percentage of additional random tokens",
},
)
permute: Optional[float] = field(
default=0.0,
metadata={
"help": "take this proportion of subwords and permute them",
},
)
rotate: Optional[float] = field(
default=0.0,
metadata={
"help": "rotate this proportion of inputs",
},
)
poisson_lambda: Optional[float] = field(
default=3.5,
metadata={
"help": "randomly shuffle sentences for this proportion of inputs",
},
)
permute_sentences: Optional[float] = field(
default=0.0,
metadata={
"help": "shuffle this proportion of sentences in all inputs",
},
)
mask_length: Optional[str] = field(
default="span-poisson",
metadata={
"help": "mask length to choose",
"choice": "subword/word/span-poisson"
},
)
replace_length: Optional[int] = field(
default=1,
metadata={
"help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
},
)
shuffle_instance: Optional[bool] = field(
default=False,
metadata={"help": "shuffle instance"},
)
max_source_positions: Optional[int] = field(
default=1024,
metadata={"help": "max number of tokens in the source sequence"},
)
max_target_positions: Optional[int] = field(
default=1024,
metadata={"help": "max number of tokens in the target sequence"},
)
bpe: Optional[str] = field(
default="",
metadata={
"help": "will wrapped by the text_data_config yaml",
},
)
data_config: Optional[str] = field(
default=None,
metadata={
"help": "a config yaml specify the bpe model of text data",
},
)
text_maxtokens_ratio: Optional[float] = field(
default=1.0,
metadata={
"help": "for text, max_tokens = max_tokens * text_maxtokens_ratio / 320 ",
},
)
prepend_tgt_lang_tag: bool = field(
default=False,
metadata={"help": "prepend tgt_lang_tag to replace <eos>"},
)
mask_text_ratio: Optional[float] = field(
default=0.0,
metadata={
"help": "mask_text_ratio, for paired data",
},
)
truncate_mono_source: bool = field(
default=True,
metadata={"help": "truncate mono source-side examples that exceed max-positions"},
)
@dataclass
class JointPretrainingConfig(FairseqDataclass):
data: str = field(
default=MISSING, metadata={"help": "path to speech data directory"}
)
fine_tuning: bool = field(
default=False, metadata={"help": "set to true if fine-tuning Hubert"}
)
labels: List[str] = field(
default_factory=lambda: ["ltr"],
metadata={
"help": (
"extension of the label files to load, frame-level labels for"
" pre-training, and sequence-level label for fine-tuning"
)
},
)
label_dir: Optional[str] = field(
default=None,
metadata={
"help": "if set, looks for labels in this directory instead",
},
)
label_rate: int = field(
default=-1,
metadata={"help": "label frame rate. -1 for sequence label"},
)
sample_rate: int = field(
default=16_000,
metadata={
"help": "target sample rate. audio files will be up/down "
"sampled to this rate"
},
)
normalize: bool = field(
default=False,
metadata={
"help": "if set, normalizes input to have 0 mean and unit variance"
},
)
enable_padding: bool = field(
default=False,
metadata={"help": "pad shorter samples instead of cropping"},
)
max_keep_size: Optional[int] = field(
default=None,
metadata={"help": "exclude sample longer than this"},
)
max_sample_size: Optional[int] = field(
default=None,
metadata={"help": "max sample size to crop to for batching"},
)
min_sample_size: Optional[int] = field(
default=None,
metadata={"help": "min sample size to crop to for batching"},
)
single_target: Optional[bool] = field(
default=False,
metadata={
"help": "if set, AddTargetDatasets outputs same keys "
"as AddTargetDataset"
},
)
random_crop: Optional[bool] = field(
default=True,
metadata={"help": "always crop from the beginning if false"},
)
pad_audio: Optional[bool] = field(
default=False,
metadata={"help": "pad audio to the longest one in the batch if true"},
)
store_labels: Optional[bool] = field(
default=True,
metadata={"help": "store spm labels in memory, should be true when fine-tune with bpe"},
)
add_decoder_target: bool = field(
default=False,
metadata={"help": "contral the model architecture, if set True, load reduced unit as target"},
)
split_modality_batch: bool = field(
default=False,
metadata={"help": "whether create all samples of different modalities in a batch"},
)
speech_tgt_lang: str = field(
default="",
metadata={"help": "prepend <tgt-id> to prev_output_tokens to replace <eos>, only used for decoder"},
)
speech_sampling_alpha: float = field(
default=0.2,
metadata={
"help": "Hyper-parameter alpha = 1/T for temperature-based speech resampling."
"(alpha = 1 for no resampling)"
},
)
text_sampling_alpha: float = field(
default=0.2,
metadata={
"help": "Hyper-parameter alpha = 1/T for temperature-based text resampling."
"(alpha = 1 for no resampling)"
},
)
hubert_tokenizer: Optional[TOKENIZER_CHOICES] = field(
default="none",
metadata={"help": "which tokenizer for processing text"},
)
sp_path: Optional[str] = field(
default=None,
metadata={"help": "sentencepiece model path if using bpe tokenizer"},
)
text_cfg: TextPretrainingConfig = TextPretrainingConfig()
# For inference
ctc_weight: float = field(
default=0.0,
metadata={"help": "ctc weight during inference"},
)
lm_dict: Optional[str] = field(
default="dict.txt",
metadata={"help": "dict used for decoding with language model, should be in cfg.data/"},
)
@register_task("joint_sc2t_pretraining", dataclass=JointPretrainingConfig)
class Jsc2tPretrainingTask(FairseqTask):
cfg: JointPretrainingConfig
def __init__(
self,
cfg: JointPretrainingConfig,
load_local_states: bool = True,
) -> None:
super().__init__(cfg)
logger.info(f"current directory is {os.getcwd()}")
logger.info(f"JSTPretrainingTask Config {cfg}")
self.cfg = cfg
self.fine_tuning = cfg.fine_tuning
self.blank_symbol = "<s>"
if load_local_states:
self.state.add_factory("hubert_tokenizer", self.build_tokenizer)
if self.cfg.text_cfg.text_data is not None and os.path.exists(self.cfg.text_cfg.text_data):
self.state.add_factory("text_dictionary", self.load_text_dictionary)
self.state.add_factory("text_src_dictionary", self.load_text_src_dictionary)
if cfg.fine_tuning:
self.state.add_factory("target_dictionary", self.load_dictionaries)
else:
self.state.add_factory("dictionaries", self.load_dictionaries)
if cfg.text_cfg.data_config is not None:
self.text_data_cfg = S2TJointDataConfig(Path(f"{cfg.text_cfg.text_data}/{cfg.text_cfg.data_config}"))
self.cfg.text_cfg.bpe = self.text_data_cfg.bpe_tokenizer["bpe"]
else:
self.text_data_cfg = None
@property
def source_dictionary(self) -> Optional[Dictionary]:
return None
@property
def target_dictionary(self) -> Optional[Dictionary]:
return self.state.target_dictionary
@property
def dictionaries(self) -> List[Dictionary]:
return self.state.dictionaries
@property
def text_dictionary(self) -> Optional[Dictionary]:
return self.state.text_dictionary
@property
def text_src_dictionary(self) -> Optional[Dictionary]:
return self.state.text_src_dictionary
@property
def hubert_tokenizer(self):
return self.state.hubert_tokenizer
def load_dictionaries(self):
label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels]
if not self.cfg.fine_tuning:
for dictionary in dictionaries:
dictionary.add_symbol("<mask>")
return dictionaries[0] if self.cfg.fine_tuning else dictionaries
def load_text_dictionary(self):
tgt_dict_path = f"{self.cfg.text_cfg.text_data}/{self.text_data_cfg.vocab_filename if self.text_data_cfg is not None else 'dict.txt'}"
if not os.path.isfile(tgt_dict_path):
raise FileNotFoundError(f"Dict not found: {tgt_dict_path}")
text_dictionary = Dictionary.load(tgt_dict_path)
self.mask_idx = text_dictionary.add_symbol("<mask>")
return text_dictionary
def load_text_src_dictionary(self):
src_dict_path = f"{self.cfg.text_cfg.text_data}/{self.text_data_cfg.src_vocab_filename if self.text_data_cfg is not None else 'dict.txt'}"
if not os.path.isfile(src_dict_path):
raise FileNotFoundError(f"Dict not found: {src_dict_path}")
src_text_dictionary = Dictionary.load(src_dict_path)
self.mask_idx = src_text_dictionary.add_symbol("<mask>")
return src_text_dictionary
@classmethod
def setup_task(
cls, cfg: JointPretrainingConfig, **kwargs
) -> "Jsc2tPretrainingTask":
load_local_states = kwargs.get("load_local_states", True)
return cls(cfg, load_local_states)
def get_label_dir(self) -> str:
if self.cfg.label_dir is None:
return self.cfg.data
return self.cfg.label_dir
def load_paired_dataset(self, text_split, truncate_source=False):
text_split, lp = text_split.rsplit('.', 1) # e.g. "libritext.ltr-ltr"
if len(lp.split("-")) == 2:
src, tgt = lp.split("-")
if src == tgt:
logger.warn(f"| trying to load monolingual dataset {text_split}.{lp}, please check your task is right.")
paired_dataset = self.load_char_bart_dataset(f"{text_split}.{lp}.{tgt}")
return paired_dataset
paired_dataset = load_langpair_dataset(
self.cfg.text_cfg.text_data,
text_split,
src,
self.text_src_dictionary,
tgt,
self.text_dictionary,
combine=True,
dataset_impl=None,
upsample_primary=1,
left_pad_source=False,
left_pad_target=False,
max_source_positions=self.cfg.text_cfg.tokens_per_sample,
max_target_positions=self.cfg.text_cfg.tokens_per_sample,
truncate_source=truncate_source,
prepend_bos=False,
load_alignments=False,
append_source_id=True if self.cfg.text_cfg.prepend_tgt_lang_tag else False,
lang_format="<lang:{}>" if self.cfg.text_cfg.prepend_tgt_lang_tag else "[{}]",
input_feeding=self.cfg.add_decoder_target,
)
if self.cfg.text_cfg.mask_text_ratio > 0:
# add mask
self.mask_idx = self.text_src_dictionary.index("<mask>")
mask_whole_words = None
if self.cfg.text_cfg.mask_whole_words:
mask_whole_words = get_whole_word_mask(self.cfg.text_cfg, self.text_src_dictionary)
elif self.cfg.text_cfg.mask_repeative_tokens:
mask_whole_words = get_repeative_start
src_dataset, src_unmasked_dataset = MaskTokensDataset.apply_mask(
paired_dataset.src,
self.text_src_dictionary,
pad_idx=self.text_src_dictionary.pad(),
mask_idx=self.mask_idx,
seed=self.cfg.text_cfg.seed,
mask_prob=self.cfg.text_cfg.mask_text_ratio,
leave_unmasked_prob=self.cfg.text_cfg.leave_unmasked_prob,
random_token_prob=self.cfg.text_cfg.mask_random,
freq_weighted_replacement=self.cfg.text_cfg.freq_weighted_replacement,
mask_whole_words=mask_whole_words,
mask_multiple_length=self.cfg.text_cfg.mask_multiple_length,
mask_stdev=self.cfg.text_cfg.mask_stdev,
)
tgt_dataset = paired_dataset.tgt if paired_dataset.tgt is not None else src_unmasked_dataset
paired_dataset = LanguageTripleDataset(
src_dataset,
src_dataset.sizes,
self.text_src_dictionary,
src_unmasked_dataset,
src_unmasked_dataset.sizes,
self.text_src_dictionary,
tgt_dataset,
tgt_dataset.sizes,
self.text_dictionary,
left_pad_source=False,
left_pad_target=False,
align_dataset=None,
eos=None,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
)
else:
src, ref, tgt = lp.split("-")
paired_dataset = load_langtriple_dataset(
self.cfg.text_cfg.text_data,
text_split,
src,
self.text_src_dictionary,
ref,
self.dictionaries[-1],
tgt,
self.text_dictionary,
combine=True,
dataset_impl=None,
upsample_primary=1,
left_pad_source=False,
left_pad_target=False,
max_source_positions=self.cfg.text_cfg.tokens_per_sample,
max_target_positions=self.cfg.text_cfg.tokens_per_sample,
truncate_source=truncate_source,
prepend_bos=False,
load_alignments=False,
append_source_id=True if self.cfg.text_cfg.prepend_tgt_lang_tag else False,
lang_format="<lang:{}>" if self.cfg.text_cfg.prepend_tgt_lang_tag else "[{}]",
)
return paired_dataset
def load_dataset(self, split: str, epoch=1, **kwargs) -> None:
"""
Create a wav dataset for the audio and indexed datasets for the (phonemized) text,
then combine them with MultiCorpusDataset (speechut.data.multimodal_corpus_dataset).
"""
speech_splits = split.split('+')[0].split(',')
### 1st, create a speech dataset using STSpeechDataset (modified from HubertDataset)
dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
pad_list = [dict.pad() for dict in dicts]
eos_list = [dict.eos() for dict in dicts]
procs = [LabelEncoder(dict) for dict in dicts]
if self.cfg.speech_tgt_lang != "":
tgt_lang_idx = _lang_token_index(dicts[0], self.cfg.speech_tgt_lang)
logger.info(f"Will prepend <{tgt_lang_idx}> at the beginning of prev_output_tokens to replace <eos>")
else:
tgt_lang_idx = None
# hubert v1: pad_audio=True, random_crop=False;
speech_datasets = []
for speech_split in speech_splits:
paths = [
f"{self.get_label_dir()}/{speech_split}.{l}" for l in self.cfg.labels
]
speech_datasets.append(
HubertDataset(
f"{self.cfg.data}/{speech_split}.tsv",
sample_rate=self.cfg.sample_rate,
label_paths=paths,
label_rates=self.cfg.label_rate,
pad_list=pad_list,
eos_list=eos_list,
label_processors=procs,
max_keep_sample_size=self.cfg.max_keep_size,
min_keep_sample_size=self.cfg.min_sample_size,
max_sample_size=self.cfg.max_sample_size,
pad_audio=self.cfg.pad_audio,
normalize=self.cfg.normalize,
store_labels=self.cfg.store_labels,
random_crop=self.cfg.random_crop,
single_target=self.cfg.single_target,
tgt_dict=dicts[0],
add_decoder_target=self.cfg.add_decoder_target,
fine_tuning=self.cfg.fine_tuning,
tgt_lang_idx=tgt_lang_idx,
tokenizer=self.hubert_tokenizer,
)
)
if len(speech_datasets) > 1:
speech_dataset = ConcatDataset(speech_datasets)
else:
speech_dataset = speech_datasets[0]
has_text = len(split.split('+')) > 1
if not has_text:
assert speech_dataset is not None
self.datasets[split] = speech_dataset
return
### 2nd, create paired/mono text datasets using Langpairdataset
if split.split('+')[1] != '':
paired_splits = [paired_split for paired_split in split.split('+')[1].split(',') if paired_split != '']
paired_datasets = [self.load_paired_dataset(paired_split) for paired_split in paired_splits]
else:
paired_splits, paired_datasets = [], []
if len(split.split('+')) > 2 and split.split('+')[2] != '':
mono_splits = [mono_split for mono_split in split.split('+')[2].split(',') if mono_split != '']
mono_datasets = [self.load_paired_dataset(mono_split, truncate_source=self.cfg.text_cfg.truncate_mono_source) for mono_split in mono_splits]
else:
mono_splits, mono_datasets = [], []
assert len(mono_datasets + paired_datasets) > 0, f"split {split} has no text! you should check out for that"
### 3rd, if provided, create a supervised dataset with labeled data
if len(split.split('+')) > 3 and split.split('+')[3] != '':
assert len(paired_splits) > 0, f"supervised dataset can not be loaded without text paired dataset!"
tgt = paired_splits[0].rsplit('.', 1)[1].split("-")[1]
sup_split = split.split('+')[3]
sup_dataset = HubertDataset(
f"{self.cfg.data}/{sup_split}.tsv",
sample_rate=self.cfg.sample_rate,
label_paths=[f"{self.get_label_dir()}/{sup_split}.{tgt}"],
label_rates=[-1],
pad_list=[self.text_dictionary.pad()],
eos_list=[self.text_dictionary.eos()],
label_processors=[LabelEncoder(self.text_dictionary)],
max_keep_sample_size=self.cfg.max_keep_size,
min_keep_sample_size=None,
max_sample_size=None,
pad_audio=True,
normalize=self.cfg.normalize,
store_labels=self.cfg.store_labels,
random_crop=False,
single_target=True,
tgt_dict=self.text_dictionary,
add_decoder_target=self.cfg.add_decoder_target,
fine_tuning=True,
tgt_lang_idx=None,
tokenizer=None,
)
else:
sup_dataset = None
### 4th, compose a MultiCorpusDataset
dataset_dict, max_positions_dict, distributions, max_tokens_ratios = self.resample_multi_modality_dataset(
speech_dataset, sup_dataset, mono_datasets, paired_datasets, mono_splits, paired_splits, epoch=epoch,
)
self.datasets[split] = MultiCorpusDataset(
dataset_dict,
max_positions=max_positions_dict,
distribution=distributions,
max_tokens_ratio=max_tokens_ratios,
seed=self.cfg.text_cfg.seed,
sort_indices=True,
)
def max_positions(self) -> Tuple[int, int]:
return (sys.maxsize, sys.maxsize)
def filter_indices_by_size(
self, indices: np.array, *args, **kwargs
) -> np.array:
return indices
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 1).
data_buffer_size (int, optional): number of batches to
preload (default: 0).
disable_iterator_cache (bool, optional): don't cache the
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
(default: False).
skip_remainder_batch (bool, optional): if set, discard the last
batch in each training epoch, as the last batch is often smaller than
local_batch_size * distributed_word_size (default: ``True``).
grouped_shuffling (bool, optional): group batches with each groups
containing num_shards batches and shuffle groups. Reduces difference
between sequence lengths among workers for batches sorted by length.
update_epoch_batch_itr (bool, optional): if true then do not use the cached
batch iterator for the epoch
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
if self.fine_tuning or not isinstance(dataset, MultiCorpusDataset):
return super().get_batch_iterator(
dataset,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
data_buffer_size=data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
skip_remainder_batch=skip_remainder_batch,
grouped_shuffling=grouped_shuffling,
update_epoch_batch_itr=update_epoch_batch_itr,
)
can_reuse_epoch_itr = (
not disable_iterator_cache
and not update_epoch_batch_itr
and self.can_reuse_epoch_itr(dataset)
)
if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
return self.dataset_to_epoch_iter[dataset]
assert isinstance(dataset, FairseqDataset)
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
# get indices ordered by example size
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
# filter examples that are too large
if max_positions is not None:
indices = self.filter_indices_by_size(
indices, dataset, max_positions, ignore_invalid_inputs
)
# create mini-batches with given size constraints
batch_sampler = dataset.get_batch_sampler(
indices,
num_shards,
seed,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
split_modality_batch=self.cfg.split_modality_batch,
)
# return a reusable, sharded iterator
epoch_iter = iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
buffer_size=data_buffer_size,
skip_remainder_batch=skip_remainder_batch,
disable_shuffling=True,
grouped_shuffling=grouped_shuffling,
)
if can_reuse_epoch_itr:
self.dataset_to_epoch_iter[dataset] = epoch_iter
return epoch_iter
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
"""Build ED-CTC generator for finet-tuned ASR model"""
from speechut.squence_generator import SequenceGenerator
extra_gen_cls_kwargs = {
"ctc_weight": self.cfg.ctc_weight,
"lm_dict": Dictionary.load(os.path.join(self.cfg.data, self.cfg.lm_dict)),
**extra_gen_cls_kwargs
}
return super().build_generator(
models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
@classmethod
def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0):
"""Size ratios for temperature-based sampling
(https://arxiv.org/abs/1907.05019)"""
_sizes = np.array(sizes)
prob = _sizes / _sizes.sum()
smoothed_prob = prob ** alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
size_ratio = (smoothed_prob * _sizes.sum()) / _sizes
o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)})
logger.info(f"original sampling probability: {o_str}")
p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)})
logger.info(f"balanced sampling probability: {p_str}")
sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)})
logger.info(f"balanced sampling size ratio: {sr_str}")
return size_ratio.tolist()
def resample_multi_modality_dataset(self, speech_dataset, sup_dataset, mono_datasets, paired_datasets, mono_splits, paired_splits, epoch=1, train=True):
assert len(mono_datasets+paired_datasets) > 0, f"No text data loaded!"
if len(mono_datasets) > 1 and self.cfg.text_sampling_alpha != 1.0:
size_ratios = self._get_size_ratios(
mono_splits, [len(s) for s in mono_datasets], alpha=self.cfg.text_sampling_alpha
)
mono_datasets = [
ResamplingDataset(
d, size_ratio=r, seed=0, epoch=epoch, replace=(r >= 1.0)
) for d, r in zip(mono_datasets, size_ratios)
]
if len(paired_datasets) > 1 and self.cfg.text_sampling_alpha != 1.0:
size_ratios = self._get_size_ratios(
paired_splits, [len(s) for s in paired_datasets], alpha=self.cfg.text_sampling_alpha
)
paired_datasets = [
ResamplingDataset(
d, size_ratio=r, seed=0, epoch=epoch, replace=(r >= 1.0)
) for d, r in zip(paired_datasets, size_ratios)
]
dataset_list = [speech_dataset, sup_dataset]
for datasets in [mono_datasets, paired_datasets]:
if len(datasets) > 1:
dataset_list.append(ConcatDataset(datasets))
elif len(datasets) == 1:
dataset_list.append(datasets[0])
else:
dataset_list.append(None)
### match speech/text datasets according to modality
dataset_dict = OrderedDict((name, d) for name, d in zip(["speech", "speech_sup", "text_mono", "text_paired"], dataset_list) if d is not None)
max_positions_dict = {
"speech": None,
"speech_sup": None,
"text_mono": (self.cfg.text_cfg.tokens_per_sample, self.cfg.text_cfg.tokens_per_sample),
"text_paired": (self.cfg.text_cfg.tokens_per_sample, self.cfg.text_cfg.tokens_per_sample),
}
max_positions_dict = OrderedDict((name, max_positions_dict[name]) for name in dataset_dict.keys())
max_tokens_ratios_dict = {
"speech": 1.0,
"speech_sup": 1.0,
"text_mono": 1.0 / 320 / self.cfg.text_cfg.text_maxtokens_ratio,
"text_paired": 1.0 / 320 / self.cfg.text_cfg.text_maxtokens_ratio,
}
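# Note: for speech, max_tokens counts waveform samples, while for text it counts tokens;
# the 1/320 factor (presumably ~320 samples per label frame at 16 kHz) converts the speech
# budget into a comparable text-token budget, further scaled by text_maxtokens_ratio.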
max_tokens_ratios = [max_tokens_ratios_dict[name] for name in dataset_dict.keys()]
dataset_lens = np.array([len(dataset) for dataset in dataset_dict.values()])
dataset_avg_sample_lens = np.array([
sum([dataset.num_tokens(i) for i in np.random.randint(low=0, high=len(dataset), size=10000)]) / 10000.0
for dataset in dataset_dict.values()
])
if not "speech" in dataset_dict:
distributions = [l / sum(dataset_lens) for l in dataset_lens]
else:
## we keep the number of speech and non-speech batches roughly the same; expand_coef ensures there are fewer speech batches than the others
first_ratio = dataset_lens[0] / sum(dataset_lens)
expand_coef = 1.2 if sup_dataset is None else 1.1 * sum(dataset_lens[0:2]) / dataset_lens[0]
distributions = [expand_coef * max_tokens_ratios[i] * dataset_avg_sample_lens[0] / l for (i, l) in enumerate(dataset_avg_sample_lens)]
distributions[0] = 1.0
if sup_dataset is not None:
distributions[1] = dataset_lens[1] / dataset_lens[0]
distributions = [first_ratio * d for d in distributions]
logging.info(f"Number samples of datasets is {dataset_lens}")
logging.info(f"Avg sample length of datasets is {dataset_avg_sample_lens}")
logging.info(f"Sampling distributions is {distributions}")
logging.info(f"Maxtokens ratio is {max_tokens_ratios}")
return dataset_dict, max_positions_dict, distributions, max_tokens_ratios
def build_tokenizer(self, cfg=None):
logger.info(f"tokenizer: {self.cfg.hubert_tokenizer}")
if self.cfg.hubert_tokenizer != "none":
return encoders.build_bpe(Namespace(**{"bpe": self.cfg.hubert_tokenizer, "sentencepiece_model": self.cfg.sp_path}))
else:
return None
def load_char_bart_dataset(self, split):
mono_dataset = data_utils.load_indexed_dataset(
f"{self.cfg.text_cfg.text_data}/{split}",
self.text_dictionary,
)
mono_dataset = StripTokenDataset(mono_dataset, self.text_dictionary.eos())
mono_dataset = maybe_shorten_dataset(
mono_dataset,
split,
self.cfg.text_cfg.shorten_data_split_list,
self.cfg.text_cfg.shorten_method,
self.cfg.text_cfg.tokens_per_sample - 2,
self.cfg.text_cfg.seed,
)
logger.info("loaded {} samples from: {}".format(len(mono_dataset), mono_dataset))
### prepend bos and eos to dataset
mono_dataset = PrependTokenDataset(mono_dataset, self.text_dictionary.bos())
mono_dataset = AppendTokenDataset(mono_dataset, self.text_dictionary.eos())
mask_whole_words = (
get_whole_word_mask(None, self.text_dictionary)
if self.cfg.text_cfg.mask_whole_words
else None
)
lang=self.cfg.speech_tgt_lang
mono_dataset = DenoisingDataset(
mono_dataset,
mono_dataset.sizes,
self.text_dictionary,
self.mask_idx,
mask_whole_words,
shuffle=self.cfg.text_cfg.shuffle_instance,
seed=self.cfg.text_cfg.seed,
args=self.cfg.text_cfg,
tgt_lang_idx=_lang_token_index(self.text_dictionary, lang) if self.cfg.text_cfg.prepend_tgt_lang_tag else None,
)
return mono_dataset
# SpeechLM
<!--**Pre-trained models for speech related tasks**-->
[**SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data**](https://arxiv.org/abs/2209.15329)
- June 2023: We have corrected the errors in the pre-training data for SpeechLM-P Base models, and new results are updated.
- April 2023: We discovered some errors about the data in the pre-training experiments, which will affect all the results about SpeechLM-P Base models. We are re-conducting the related experiments and will update the paper with the new results.
- (Done) Oct 2022: release the code and models
- Oct 2022: release preprint in [arXiv](https://arxiv.org/abs/2209.15329)
## Pre-Trained and Fine-tuned Models
| Model | Pre-training Dataset | Fine-tuning Dataset | Download |
| :------: | :----------------------------------------------: | :-----------------: | :-----: |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Azure Storage] |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Azure Storage] |
| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1eblW8U8f9t-NTuCNRrNHwr-8BeLAUAmQ/view?usp=sharing) |
| SpeechLM-H Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1vXyO5DolbiWiTYZ6pkkKQsu2wJetaPlv/view?usp=sharing) |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage] |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage] |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage] |
| SpeechLM-P Base | [960 hrs LibriSpeech](http://www.openslr.org/12) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Azure Storage] |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | - | [Google drive](https://drive.google.com/file/d/1QjLIgTJKIylVIp5hUkfSjGPtz8Xo7Lky/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [960 hrs LibriSpeech](http://www.openslr.org/12) | [Google drive](https://drive.google.com/file/d/1YZQDVv096o8Opt0RBnkRiZXYPRDqKZnP/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-De CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1qYygNWSc11TQbBI1OzC4ChlR-dNh8t9S/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ca CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/162U88mwso2aVfzzPkEM2nP_vwTpcb57T/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Ar CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1lbTSRXewEeb2t45URunD6EiJcbniyjWW/view?usp=sharing) |
| SpeechLM-P Large | [60k hrs LibriLight](https://github.com/facebookresearch/libri-light) + [40M Text](http://www.openslr.org/11) | [En-Tr CoVoST-2](https://github.com/facebookresearch/covost) | [Google drive](https://drive.google.com/file/d/1Er4I_jHS175pQQph223yKtiiLQ378VvH/view?usp=sharing) |
## Extract features using pre-trained models
For easier use of our pre-trained models, we merge all inference-related code into [`SpeechLM.py`](SpeechLM.py) and provide cleaned checkpoints [~~`SpeechLM-P Base`~~] [`SpeechLM-H Base`] [`SpeechLM-P Large`] with non-required modules removed. You can directly use the following script to extract speech features:
```python
import torch
import torch.nn.functional as F
from SpeechLM import SpeechLMConfig, SpeechLM
checkpoint = torch.load('path/to/the/cleaned/checkpoint.pt')
cfg = SpeechLMConfig(checkpoint['cfg']['model'])
model = SpeechLM(cfg)
model.load_state_dict(checkpoint['model'])
model.eval()
wav_input_16khz = torch.randn(1,10000)
normalize = checkpoint['cfg']['task']['normalize'] # False for base model, True for large model
if normalize:
wav_input_16khz = F.layer_norm(wav_input_16khz[0], wav_input_16khz[0].shape).unsqueeze(0)
# extract the representation of last layer
rep = model.extract_features(wav_input_16khz)[0]
# extract the representation of each layer
output_layer = model.cfg.encoder_layers + model.cfg.text_transformer.encoder.layers
rep, layer_results = model.extract_features(wav_input_16khz, output_layer=output_layer, ret_layer_results=True)[0]
layer_reps = [x.transpose(0, 1) for x in layer_results]
```
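As a quick follow-up, the frame-level output can be mean-pooled into an utterance-level embedding (our own illustration, not part of the released recipe):
```python
# rep: (batch, frames, hidden) features returned by extract_features above
utt_embedding = rep.mean(dim=1)  # (batch, hidden) utterance-level embedding
print(rep.shape, utt_embedding.shape)
```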
## Setup
To fine-tune or pre-train more models, please follow the instructions below.
```bash
git submodule update --init SpeechLM/fairseq
cd SpeechLM/
pip install --editable fairseq/
pip install sacrebleu==1.5.1
```
## ASR on LibriSpeech
### Data preparation
Please follow the wav2vec 2.0 manifest steps [here](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#prepare-training-data-manifest) to prepare `train.tsv` and `train.ltr`. Make sure the vocabulary [`dict.ltr.txt`](dataset/LibriSpeech/asr/dict.ltr.txt) is the same as that used for the pre-trained model.
Put your prepared data into `$data_dir`; we provide examples in [`dataset/LibriSpeech/asr`](dataset/LibriSpeech/asr/).
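For reference, `train.tsv` lists the audio root on its first line followed by `relative_path<TAB>num_samples` rows, and `train.ltr` holds space-separated letter transcriptions with `|` as the word boundary. Below is a minimal sketch for building the tsv manifest (the paths and the `.flac` extension are assumptions):
```python
import os
import soundfile as sf

root = "path/to/LibriSpeech/train-clean-100"  # assumed audio root
with open("train.tsv", "w") as tsv:
    print(root, file=tsv)  # first line: audio root directory
    for dirpath, _, filenames in os.walk(root):
        for name in sorted(filenames):
            if name.endswith(".flac"):
                path = os.path.join(dirpath, name)
                n_samples = sf.info(path).frames  # number of audio samples
                print(f"{os.path.relpath(path, root)}\t{n_samples}", file=tsv)
```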
### Fine-tune a CTC model
- Fine-tune the base model
```bash
# Usage: speechlm/scripts/tune_speechlm_asr/finetune_base_ctc.sh <model_path> <data_dir> <cpt_tag> [mount=$PWD] [world_size=8] [update_freq=1]
model_path=path/to/your/pre-trained/model
data_dir=dataset/LibriSpeech/asr
bash speechlm/scripts/tune_speechlm_asr/finetune_base_ctc.sh $model_path $data_dir 'tag400k'
```
- Fine-tune the large model
```bash
# Usage: speechlm/scripts/tune_speechlm_asr/finetune_large_ctc.sh <model_path> <data_dir> <cpt_tag> [mount=$PWD] [world_size=8] [update_freq=4]
model_path=path/to/your/pre-trained/model
data_dir=dataset/LibriSpeech/asr
bash speechlm/scripts/tune_speechlm_asr/finetune_large_ctc.sh $model_path $data_dir 'tag400k'
```
### Decode
- Directly decode a CTC model.
```bash
# Usage: speechlm/scripts/tune_speechlm_asr/inference_ctc.sh <model_path> <data_dir> [gen-set=dev_clean,dev_other,test_clean,test_other]
model_path=path/to/your/fine-tuned/model
data_dir=dataset/LibriSpeech/asr
bash speechlm/scripts/tune_speechlm_asr/inference_ctc.sh $model_path $data_dir
# for large models
# bash speechlm/scripts/tune_speechlm_asr/inference_ctc_large.sh $model_path $data_dir
```
- Decode with 4-gram language model using [flashlight](https://github.com/flashlight/flashlight/tree/main/bindings/python) and [kenlm](https://github.com/kpu/kenlm).
> Please put [4-gram.arpa](https://www.openslr.org/resources/11/4-gram.arpa.gz) and the word-to-letter lexicon [librispeech_lexicon.lst](https://drive.google.com/file/d/1q7IbNGqtwXnctjvuvpviQ4ZmepFHQmTO/view?usp=sharing) into `$data_dir`.
```bash
# Usage: speechlm/scripts/tune_speechlm_asr/inference_ctc_kenlm.sh <model_path> <data_dir> [gen-set=dev_clean,dev_other,test_clean,test_other]
model_path=path/to/your/fine-tuned/model
data_dir=dataset/LibriSpeech/asr
bash speechlm/scripts/tune_speechlm_asr/inference_ctc_kenlm.sh $model_path $data_dir
```
- Decode large models with fairseq-lm using [flashlight](https://github.com/flashlight/flashlight/tree/main/bindings/python).
> Please put [lm_librispeech_word_transformer.pt](https://dl.fbaipublicfiles.com/wav2letter/sota/2019/lm/lm_librispeech_word_transformer.pt) and its vocabulary [`dict.txt`](https://dl.fbaipublicfiles.com/wav2letter/sota/2019/lm/lm_librispeech_word_transformer.dict) into `$data_dir/fairseq_word_lm`, and the word-to-letter lexicon [librispeech_lexicon.lst](https://drive.google.com/file/d/1q7IbNGqtwXnctjvuvpviQ4ZmepFHQmTO/view?usp=sharing) into `$data_dir`. Capitalize the words in `dict.txt` to make them compatible with the word-to-letter lexicon (see the snippet after the decode command below).
```bash
# Usage: speechlm/scripts/tune_speechlm_asr/inference_ctc_large_fsqlm.sh <model_path> <data_dir> [gen-set=dev_clean,dev_other,test_clean,test_other]
model_path=path/to/your/fine-tuned/model
data_dir=dataset/LibriSpeech/asr
bash speechlm/scripts/tune_speechlm_asr/inference_ctc_large_fsqlm.sh $model_path $data_dir dev_other
```
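A minimal Python sketch for the capitalization step mentioned above (the file paths are assumptions; adjust them to your `$data_dir`):
```python
# Upper-case the word column of the fairseq LM vocabulary so its entries match
# the (upper-cased) words in librispeech_lexicon.lst.
in_path = "dataset/LibriSpeech/asr/fairseq_word_lm/dict.txt"       # assumed location
out_path = "dataset/LibriSpeech/asr/fairseq_word_lm/dict.upper.txt"

with open(in_path) as fin, open(out_path, "w") as fout:
    for line in fin:
        word, *rest = line.rstrip("\n").split(" ")
        fout.write(" ".join([word.upper(), *rest]) + "\n")
# Then replace the original dict.txt with the upper-cased version.
```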
## ST on CoVoST-2
### Data Preparation
1. Download [Common Voice audio clips](https://commonvoice.mozilla.org/en/datasets) (version 4) for English into `$cv_root/en`.
2. Get the data manifest. The following script converts the mp3 files to waveforms, creates tsv files containing speech/translation pairs, and creates the data config files.
```bash
lang=de # ca,ar,tr
cv_root=dataset/CommonVoice/v4
bash speechlm/data_process/prepare_covost2_enxx.sh $lang $cv_root
```
We provide examples in [`dataset/CommonVoice/v4/en/en-de`](dataset/CommonVoice/v4/en/en-de).
### Fine-tune an encoder-decoder model
- Fine-tune the Base model (fine-tuned models will be stored in `$mount/exp/finetune_covost`).
```bash
model_path=path/to/your/pre-trained/model
lang=de # ca,ar,tr
data_dir=dataset/CommonVoice/v4/en/en-${lang}
# Usage (Base model): speechlm/scripts/tune_speechlm_st/ft_base_covost_enxx.sh <model_path> <data_dir> <lang> <cpt-tag> [mount=$PWD] [world_size=8] [update_freq=2]
bash speechlm/scripts/tune_speechlm_st/ft_base_covost_enxx.sh $model_path $data_dir $lang 'tag400k'
```
- Fine-tune the Large model (fine-tuned models will be stored in `$mount/exp/finetune_covost`).
```bash
# Usage (Large model): speechlm/scripts/tune_speechlm_st/ft_large_covost_enxx.sh <model_path> <data_dir> <lang> <cpt-tag> [mount=$PWD] [world_size=8] [update_freq=4]
bash speechlm/scripts/tune_speechlm_st/ft_large_covost_enxx.sh $model_path $data_dir $lang 'tag400k'
```
### Decode
- Decode the base model
```bash
# Usage: speechlm/scripts/tune_speechlm_st/inference_base.sh <model_path> <data_dir> <lang> [gen-set=dev] [beam_size=5]
model_path=path/to/your/fine-tuned/model
lang=de # ca,ar,tr
data_dir=dataset/CommonVoice/v4/en/en-${lang}
bash speechlm/scripts/tune_speechlm_st/inference_base.sh $model_path $data_dir $lang dev
```
- Decode the large model
```bash
# Usage: speechlm/scripts/tune_speechlm_st/inference_large.sh <model_path> <data_dir> <lang> [gen-set=dev] [beam_size=5]
bash speechlm/scripts/tune_speechlm_st/inference_large.sh $model_path $data_dir $lang dev
```
## Universal Representation Evaluation on SUPERB
Please refer to [**SUPERB**](https://superbbenchmark.org/) for the downstream tasks.
## Pre-train
Please follow the instructions in the [Tokenizers](README.md#Tokenizers) section to prepare the pre-training data. We provide examples in [`dataset`](dataset).
- SpeechLM-P Base model
Models will be stored in `$mount/pretrain`.
```bash
data_dir=dataset/LibriSpeech/phone_unit # should contain train_960.{tsv,phn}
text_data_dir=dataset/LibriLM/phone_unit/bin-idx # should contain train_text.phn-ltr.{phn,ltr}.{bin,idx}
# Usage: speechlm/scripts/pretrain_speechlm/base_speechlmp.sh <data_dir> <text_data_dir> [mount=$PWD] [world_size=32] [update_freq=1]
bash speechlm/scripts/pretrain_speechlm/base_speechlmp.sh $data_dir $text_data_dir
```
- SpeechLM-H Base model
```bash
data_dir=dataset/LibriSpeech/hidden_unit # should contain train_960.{tsv,km}
text_data_dir=dataset/LibriLM/km-ltr/bin-idx # should contain train_text.km-ltr.{km,ltr}.{bin,idx}
# Usage: speechlm/scripts/pretrain_speechlm/base_speechlmh.sh <data_dir> <text_data_dir> [mount=$PWD] [world_size=32] [update_freq=1]
bash speechlm/scripts/pretrain_speechlm/base_speechlmh.sh $data_dir $text_data_dir
```
- SpeechLM-P Large model
```bash
data_dir=dataset/LibriSpeech/phone_unit # should contain train_960.{tsv,phn}
text_data_dir=dataset/LibriLM/phone_unit/bin-idx # should contain train_text.phn-ltr.{phn,ltr}.{bin,idx}
# Usage: speechlm/scripts/pretrain_speechlm/large_speechlmp.sh <data_dir> <text_data_dir> [mount=$PWD] [world_size=32] [update_freq=1]
bash speechlm/scripts/pretrain_speechlm/large_speechlmp.sh $data_dir $text_data_dir
```
## Tokenizers
### Phoneme-unit Tokenizer for Speech
This tokenizer produces frame-aligned phonemes for unlabeled speech; it is essentially a hybrid HMM ASR model.
In the Base setting, we use the 100h labeled subset of LibriSpeech to train the HMM model with a Kaldi recipe, then decode the unpaired speech and obtain the aligned phonemes from the lattices.
We provide the processed phonemes of the 960h speech here: [`train_960.tsv`](https://drive.google.com/file/d/1rxlikMglL2kEsF4NfqekZRoA02klY7CE/view?usp=sharing), [`train_960.phn`](), [`dev_clean.tsv`](https://drive.google.com/file/d/1NuVwe687jLBFkDLRy1EV2A2uXyV_kBo2/view?usp=sharing), [`dev_clean.phn`](https://drive.google.com/file/d/1cq_gbS-UgCALOoaE5QmhWrhkTdXuc_Uc/view?usp=sharing). Note that the label rate is 100 (one label every 10 ms).
> The phoneme inventory is 300+ word-position-dependent phones including silence phones.
### Phoneme-unit Tokenizer for Text
This tokenizer phonemizes the unpaired text data into (phoneme, letter) paired data, following a `words -> phonemes -> upsampled phones` pipeline (a toy sketch follows the script below).
The following script will download LibriSpeech LM corpus and produce the required data: `train_text.phn-ltr.phn.{idx,bin}` and `train_text.phn-ltr.ltr.{idx,bin}`.
> Before running it, make sure you have our provided [`dict.phn.txt`](dataset/LibriLM/phone_unit/bin-idx/dict.phn.txt) and [`dict.ltr.txt`](dataset/LibriLM/phone_unit/bin-idx/dict.ltr.txt) in the output dir `dataset/LibriLM/phone_unit/bin-idx/`.
> The phoneme inventory is 300+ word-position-dependent phones including silence phones.
```bash
# data will be in dataset/LibriLM/phone_unit/
bash speechlm/data_process/prepare_phn2ltr_librilm.sh
```
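To make the `words -> phonemes -> upsampled phones` idea concrete, here is a toy sketch (the lexicon entries and the fixed repeat factor are hypothetical placeholders; the script above is what actually produces the training data):
```python
# Illustrative only: map words to phones via a lexicon, then repeat each phone
# to mimic the frame-level "upsampled phones" used as text-side input.
toy_lexicon = {"HELLO": ["HH", "AH0", "L", "OW1"], "WORLD": ["W", "ER1", "L", "D"]}
repeat = 4  # hypothetical fixed factor; the real pipeline derives durations differently

def words_to_upsampled_phones(sentence: str):
    phones = []
    for word in sentence.upper().split():
        phones.extend(toy_lexicon.get(word, ["<UNK>"]))
    return [p for p in phones for _ in range(repeat)]

print(words_to_upsampled_phones("hello world"))
```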
### Hidden-unit Tokenizer for Speech
Please follow the steps of data preparation for HuBERT [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#data-preparation) to prepare 1) wav recordings [`train.tsv`](dataset/LibriSpeech/hidden_unit/train_sample100.tsv) and 2) corresponding hidden-units [`train.km`](dataset/LibriSpeech/hidden_unit/train_sample100.km), and 3) unit vocabulary [`dict.km.txt`](dataset/LibriSpeech/hidden_unit/dict.km.txt).
### Hidden-unit Tokenizer for Text
This tokenizer produces speech-style hidden units from unpaired text.
We train a [FastSpeech](https://arxiv.org/abs/2006.04558)-like model as the tokenizer on a small amount of ASR data ([100 hrs LibriSpeech](http://www.openslr.org/12)); instead of generating continuous spectrograms as in the original paper, it generates discrete units.
Train:
1. Convert the ASR transcripts to phoneme sequences with duration information.
2. Extract hidden-units from speech, using the [Hidden-unit Tokenizer for Speech](#hidden-unit-tokenizer-for-speech).
3. Train the [model](speechlm/models/fasttext2unit.py) on the paired data:
```bash
data_dir=dataset/LibriSpeech/fast_phone2unit
bash speechlm/scripts/tokenizer_fastT2U/train_s_5e-4.sh $data_dir
```
> The phoneme inventory is 41 mono phones including silence phones.
Inference:
4. Convert the text data to phoneme sequences with the [`lexicon`](https://drive.google.com/file/d/1dh9NEx_cCF9_Aa0UcKyl9j00GXs6LmLQ/view?usp=sharing).
5. [Generate](speechlm/scripts/tokenizer_fastT2U/generate.sh) hidden units for a large text corpus:
```bash
model_path=path/to/the/trained/fastT2U/model
gen_set=dataset/LibriSpeech/fast_phone2unit/genset_examples
bash speechlm/scripts/tokenizer_fastT2U/generate.sh $model_path $gen_set
```
We provide train/generate data examples in [`dataset/LibriSpeech/fast_phone2unit`](dataset/LibriSpeech/fast_phone2unit), and the model checkpoint [here](https://drive.google.com/file/d/1e-aYf8hPXuly8DEvNg5SISOlcUxsgED0/view?usp=sharing).
## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
Portions of the source code are based on [FAIRSEQ](https://github.com/pytorch/fairseq).
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
## Reference
If you find our work useful in your research, please cite the following paper:
```bibtex
@article{zhang2022speechlm,
title = {SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data},
author = {Zhang, Ziqiang and Chen, Sanyuan and Zhou, Long and Wu, Yu and Ren, Shuo and Liu, Shujie and Yao, Zhuoyuan and Gong, Xun and Dai, Lirong and Li, Jinyu and Wei, Furu},
eprint={2209.15329},
archivePrefix={arXiv},
primaryClass={cs.CL},
year={2022}
}
```
### Contact Information
For help or issues using SpeechLM models, please submit a GitHub issue.
For other communications related to SpeechLM, please contact Long Zhou (`lozhou@microsoft.com`).
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import copy
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from modules import (
compute_mask_indices,
LayerNorm,
ConvFeatureExtractionModel,
GradMultiply,
TransformerEncoder,
TransformerEncoderBase,
)
# from fairseq.models.transformer import TransformerConfig
logger = logging.getLogger(__name__)
class DictConfig:
def __init__(self, cfg=None):
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
class TransformerConfig:
def __init__(self, cfg=None):
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
if 'encoder' in cfg:
self.encoder = DictConfig(cfg['encoder'])
del cfg['encoder']
if 'quant_noise' in cfg:
self.quant_noise = DictConfig(cfg['quant_noise'])
del cfg['quant_noise']
if 'decoder' in cfg:
del cfg['decoder']
self.__dict__.update(cfg)
class SpeechLMConfig:
def __init__(self, cfg=None):
self.label_rate: int = 50
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_type: str = "transformer" # layer type in encoder
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
self.final_dim: int = 256 # project final representations and targets to this many dimensions
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
self.conv_bias: bool = False # include bias in conv encoder
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
# masking
self.mask_length: int = 10 # mask length
self.mask_prob: float = 0.65 # probability of replacing a token with mask
self.mask_selection: str = "static" # how to choose mask length
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
self.no_mask_overlap: bool = False # whether to allow masks to overlap
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
# channel masking
self.mask_channel_length: int = 10 # length of the mask for features (channels)
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# loss computation
self.skip_masked: bool = False # skip computing losses over masked frames
self.skip_nomask: bool = False # skip computing losses over unmasked frames
self.checkpoint_activations: bool = False # recompute activations and save memory for extra compute
# FP16 optimization
self.required_seq_len_multiple: int = 2 # pad the input to encoder such that the sequence length is divisible by multiple
# Custom
self.use_rel_pos_enc: bool = False # whether to use relative positional encoding
self.scaling_for_att: float = 1.0 # scaling for attention weights to prevent overflow issue (for large model)
# unit encoder-decoder
self.add_unit_encoder: bool = False # add unit encoder
# embedding mixing
self.mix_with_unit: bool = True # mix with the unit embeddings
self.use_pred_unit: bool = False # use the embeddings of predicted units
self.l2_embedding: bool = False # compute l2 loss between unit embedding and unit hidden state
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
model_cfg = copy.deepcopy(cfg)
self.text_transformer = TransformerConfig(model_cfg['text_transformer'])
del model_cfg['text_transformer']
self.__dict__.update(model_cfg)
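# Typical usage (see the README feature-extraction example): build the config from
# the 'model' sub-dict of a cleaned checkpoint, e.g.
#   cfg = SpeechLMConfig(checkpoint['cfg']['model'])
#   model = SpeechLM(cfg)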
class SpeechLM(nn.Module):
def __init__(
self,
cfg: SpeechLMConfig,
) -> None:
super().__init__()
self.cfg = cfg
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
sample_rate = 16000
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / sample_rate
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
self.final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
self.final_proj_list = nn.ModuleList([
nn.Linear(cfg.encoder_embed_dim, self.final_dim) for _ in range(2)
])
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
### build unit encoder:
self.mask_u2t = cfg.mask_u2t
self.compute_mum = cfg.compute_mum
self.add_text_ctc = cfg.add_text_ctc
self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
self.padding_idx = 1
self.add_unit_encoder = cfg.add_unit_encoder
self.mix_with_unit = cfg.mix_with_unit
self.use_pred_unit = cfg.use_pred_unit
self.l2_embedding = cfg.l2_embedding
if self.add_unit_encoder:
self.unit_embed_tokens = None
### build unit encoder
self.unit_encoder = TransformerEncoderBase(
cfg.text_transformer,
dictionary=None,
embed_tokens=self.unit_embed_tokens,
use_rel_pos_enc=cfg.use_rel_pos_enc,
scaling_for_att=cfg.scaling_for_att,
)
### build unit2text decoder, not available for now
self.add_decoder = cfg.add_decoder
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions."""
super().upgrade_state_dict_named(state_dict, name)
return state_dict
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
return features
def forward_targets(
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_inds += np.random.choice(int(self.feat2tar_ratio))
target_list = [t[:, target_inds.long()] for t in target_list]
return features, target_list
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
def downsample_ctc_padding_mask(self, padding_mask):
"""
padding_mask: (B, T)
"""
stride = self.text_ctc_conv_kernel // 2
return padding_mask[:, ::stride]
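# compute_pred: HuBERT-style prediction logits, i.e. cosine similarity between the
# projected features and the label embeddings, scaled by the logit temperature.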
def compute_pred(self, proj_x, label_embs):
if self.target_glu:
label_embs = self.target_glu(label_embs)
x = F.normalize(proj_x.float(), dim=-1) # (S, D)
label_embs = F.normalize(label_embs.float(), dim=-1) # (C, D)
logits = torch.matmul(x, label_embs.T).type_as(proj_x) # (S, C)
logits /= self.logit_temp
return logits
def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = proj(x[masked_indices])
logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
else:
logit_m_list = [None]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = proj(x[nomask_indices])
logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
else:
logit_u_list = [None]
return logit_m_list, logit_u_list
def convert_embeddings(self,
x,
padding_mask,
target=None,
mask_indices=None,
mix_with_unit=False,
use_pred_unit=False,
l2_embedding=False,
remask=False
):
"""
1. Mix with units if needed (default: True)
2. Prepare for unit_encoder inputs
Inputs:
x, (B, T, D)
Return:
src_tokens, (B, T)
soft_embeddings, (B, T, D)
l2_loss, a loss
"""
soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
if padding_mask is None:
padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
if use_pred_unit:
src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
src_tokens[padding_mask] = self.padding_idx
elif target is not None:
src_tokens = target
else:
src_tokens = padding_mask.long()
if l2_embedding | mix_with_unit:
unit_embeddings = self.unit_embed_tokens(src_tokens) # (B, T, D)
l2_loss = 0
if l2_embedding:
if mask_indices is not None:
l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
else:
l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
scale = unit_embeddings.float().pow(2).sum(dim=-1)
l2_loss = (l2_loss / scale).mean()
if mix_with_unit:
B, T, D = x.shape
selected_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob / 2,
self.mask_length // 2,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
selected_indices = torch.from_numpy(selected_indices).to(x.device)
if mask_indices is not None:
if remask:
remask_indices = torch.logical_and(selected_indices, mask_indices)
soft_embeddings[remask_indices] = self.mask_emb
swap_indices = torch.logical_and(selected_indices, ~mask_indices)
else:
swap_indices = selected_indices
soft_embeddings[swap_indices] = unit_embeddings[swap_indices]
soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
return src_tokens, soft_embeddings, l2_loss
def forward(
self,
source: torch.Tensor = None,
src_tokens: torch.Tensor = None,
src_lengths: torch.Tensor = None,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
assert source is not None or src_tokens is not None
if source is not None:
return self.forward_speech(
source=source,
target_list=target_list,
padding_mask=padding_mask,
mask=mask,
features_only=features_only,
output_layer=output_layer,
)
else:
return self.forward_text(
src_tokens=src_tokens,
src_lengths=src_lengths,
mask=self.mask_u2t,
output_layer=output_layer,
)
def forward_speech(
self,
source: torch.Tensor = None,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
logit_m_list, logit_u_list = self.compute_hubert_logits(
x,
target_list[0],
self.final_proj_list[0],
self.label_embs_list[0],
padding_mask,
mask_indices,
)
result = {
"logit_m_list": logit_m_list,
"logit_u_list": logit_u_list,
"padding_mask": padding_mask,
"features_pen": features_pen,
}
if self.add_unit_encoder:
src_tokens, x_emb, l2_loss = self.convert_embeddings(
x,
padding_mask, target_list[0],
mask_indices=mask_indices,
mix_with_unit=self.mix_with_unit,
use_pred_unit=self.use_pred_unit,
l2_embedding=self.l2_embedding,
)
encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)
result['encoder_out'] = encoder_out['encoder_out'] # [(T, B, D)]
result['encoder_padding_mask'] = encoder_out['encoder_padding_mask'] # [(B, T)]
if self.l2_embedding:
result['embedding_l2_loss'] = l2_loss
code_logit_m_list, code_logit_u_list = self.compute_hubert_logits(
encoder_out['encoder_out'][0].transpose(0, 1),
target_list[-1],
self.final_proj_list[-1],
self.label_embs_list[-1],
padding_mask,
mask_indices,
)
result['logit_m_list'] += code_logit_m_list
result['logit_u_list'] += code_logit_u_list
return result
def forward_text(
self,
src_tokens: torch.Tensor = None,
src_lengths: torch.Tensor = None,
target_list: Optional[List[torch.Tensor]] = None,
mask: bool = True,
output_layer: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
assert self.add_unit_encoder, "Cannot forward the unit-text branch without unit_encoder!"
padding_mask = src_tokens == self.padding_idx
unit_embeddings = self.unit_embed_tokens(src_tokens)
if mask:
unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])
else:
### If the mask was already applied to src_tokens, then target_list should contain many padding_idx entries
mask_indices = target_list[-1] != self.padding_idx
unit_embeddings[mask_indices] = self.mask_emb
encoder_out = self.unit_encoder(
src_tokens,
token_embeddings=unit_embeddings,
return_all_hiddens=output_layer is not None,
)
result = {}
result["encoder_out"] = encoder_out["encoder_out"]
result["encoder_states"] = encoder_out["encoder_states"]
result["padding_mask"] = padding_mask
if self.compute_mum:
code_logit_m_list, code_logit_u_list = self.compute_hubert_logits(
encoder_out["encoder_out"].transpose(0, 1),
target_list[-1],
self.final_proj_list[-1],
self.label_embs_list[-1],
padding_mask,
mask_indices,
)
result["logit_m_list"] = code_logit_m_list
result["logit_u_list"] = code_logit_u_list
if self.add_text_ctc:
result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
result["encoder_padding_mask"] = [
self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
]
return result
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
ret_layer_results: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract features for only speech input"""
with torch.no_grad():
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
# {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
x = res["x"] # B x T x D
padding_mask = res["padding_mask"]
if self.add_unit_encoder and (output_layer is None or output_layer > self.cfg.encoder_layers):
src_tokens, x, _ = self.convert_embeddings(
x,
padding_mask,
mix_with_unit=False,
use_pred_unit=False,
)
return_all_hiddens=output_layer is not None and output_layer > self.cfg.encoder_layers
encoder_out = self.unit_encoder(
src_tokens,
token_embeddings=x,
return_all_hiddens=return_all_hiddens,
)
res["x"] = encoder_out['encoder_out'][0].transpose(0, 1) # (B, T, D)
if return_all_hiddens:
res["layer_results"] += encoder_out['encoder_states'][1:1+output_layer-len(res["layer_results"])]
feature = res["features"] if ret_conv else res["x"]
if ret_layer_results:
feature = (feature, res["layer_results"])
return feature, padding_mask
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [x[0].float() for x in logits_list if x is not None]
return logits_list
def get_targets(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
targets_list = [x[1].long() for x in logits_list if x is not None]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
if "embedding_l2_loss" in net_output:
extra_losses.append(net_output["embedding_l2_loss"])
names.append("embedding_l2_loss")
return extra_losses, names
def remove_pretraining_modules(self, step2=False):
self.target_glu = None
bpe_tokenizer:
bpe: sentencepiece
sentencepiece_model: spm_char_st_en_de.model
shuffle: false
use_audio_input: true
use_sample_rate: 16000
standardize_audio: false
vocab_filename: spm_char_st_en_de.txt
# required by speech_to_text task but never used
input_channels: 1
input_feat_per_channel: 1
bpe_tokenizer:
bpe: sentencepiece
sentencepiece_model: spm_char_st_en_de.model
shuffle: false
use_audio_input: true
use_sample_rate: 16000
standardize_audio: true
vocab_filename: spm_char_st_en_de.txt
# required by speech_to_text task but never used
input_channels: 1
input_feat_per_channel: 1