#!/bin/bash

# Default subset names passed to fairseq-train via --train-subset and
# --valid-subset.
TRAIN_SET=train
VALID_SET=valid

#######################################
# Print the command-line help text.
# Globals:   none
# Arguments: none
# Outputs:   writes the usage message to stdout
#######################################
usage() {
  # "--" keeps basename from treating a script name that starts with "-"
  # as one of its own options.
  cat <<EOF
Usage:
  $(basename -- "$0") [options]

Options:
  -h                                  - This help
  --dcu <number>                      - The number of dcu
  --log <log directory>               - The directory of training log files
  --td <train dir>                    - The directory of (train.tsv, valid.tsv)
  --res <result dir>                  - The directory of (xxx.pt)
  --lab <label dir>                   - The directory of (train.txt, valid.txt)
  --token <tokenizer path>            - The path of BPE_TOKENIZER
  --speecht5 <speecht5 path>          - The path of speecht5
  --checkpoint <checkpoint path>      - The path of speecht5_base.pt
  --epoch <number>                    - The maximum number of training epochs
EOF
}

#######################################
# Interpret a single long option and store its value in a global variable.
# Accepts both "--opt=value" (one word) and "--opt value" (two words).
# Globals:   dcu, logdir, DATA_ROOT, SAVE_DIR, LABEL_DIR, BPE_TOKENIZER,
#            USER_DIR, PT_CHECKPOINT_PATH, epoch (one of them is written)
# Arguments: $1 - the option word, optionally carrying "=value"
#            $2 - the following command-line word; used as the value only
#                 when $1 contains no "="
# Outputs:   diagnostics and usage on stderr for invalid input;
#            usage on stdout when help is requested
# Returns:   0 on success; exits 1 on a missing value or unknown option,
#            exits 0 after printing help
# Note:      this function cannot shift the caller's positional parameters;
#            the caller decides how many words were consumed.
#######################################
process_long_option() {
  local arg="$1"
  local key="${arg%%=*}"
  local value="${arg#*=}"

  # Help takes no value, so it must be answered before the value check below
  # would reject it.
  case "$key" in
    -h|--help)
      usage
      exit 0
      ;;
  esac

  if [[ "$key" == "$arg" ]]; then
    # No "=" in the option word: the next word is the value.
    value="${2-}"
    if [[ -z "$value" || "$value" == -* ]]; then
      echo "Option $key requires a value." >&2
      usage >&2
      exit 1
    fi
  fi

  case "$key" in
    --dcu)
      dcu="$value"
      ;;
    --log)
      logdir="$value"
      ;;
    --td)
      DATA_ROOT="$value"
      ;;
    --res)
      SAVE_DIR="$value"
      ;;
    --lab)
      LABEL_DIR="$value"
      ;;
    --token)
      BPE_TOKENIZER="$value"
      ;;
    --speecht5)
      USER_DIR="$value"
      ;;
    --checkpoint)
      PT_CHECKPOINT_PATH="$value"
      ;;
    --epoch)
      epoch="$value"
      ;;
    *)
      echo "Unknown option: $key" >&2
      usage >&2
      exit 1
      ;;
  esac
}

# Parse command-line arguments. Parsing stops at the first word that does not
# start with "-"; that word and anything after it stay in "$@".
while [[ $# -gt 0 ]]; do
  case "$1" in
    -h|--help)
      usage
      exit 0
      ;;
    --*=*)
      # "--opt=value": the value is part of the same word, so exactly one
      # positional parameter is consumed. Shifting two here would silently
      # drop the next option.
      process_long_option "$1"
      shift
      ;;
    --*)
      # "--opt value": the option and the following word are consumed.
      process_long_option "$1" "${2-}"
      # "shift 2" fails without shifting when only one word is left, which
      # would make this loop spin forever; fall back to a single shift.
      if [[ $# -ge 2 ]]; then
        shift 2
      else
        shift
      fi
      ;;
    -*)
      echo "Invalid option: $1" >&2
      usage >&2
      exit 1
      ;;
    *)
      break
      ;;
  esac
done

# Refuse to continue when an option is missing: every value below is passed to
# fairseq-train, and an empty log directory would place the tee log at
# "/all-...-log.log".
missing_opts=()
[[ -n "${dcu:-}" ]] || missing_opts+=("--dcu")
[[ -n "${logdir:-}" ]] || missing_opts+=("--log")
[[ -n "${DATA_ROOT:-}" ]] || missing_opts+=("--td")
[[ -n "${SAVE_DIR:-}" ]] || missing_opts+=("--res")
[[ -n "${LABEL_DIR:-}" ]] || missing_opts+=("--lab")
[[ -n "${BPE_TOKENIZER:-}" ]] || missing_opts+=("--token")
[[ -n "${USER_DIR:-}" ]] || missing_opts+=("--speecht5")
[[ -n "${PT_CHECKPOINT_PATH:-}" ]] || missing_opts+=("--checkpoint")
[[ -n "${epoch:-}" ]] || missing_opts+=("--epoch")
if (( ${#missing_opts[@]} > 0 )); then
  echo "Missing required option(s): ${missing_opts[*]}" >&2
  usage >&2
  exit 1
fi

# Create the log directory.
if ! mkdir -p -- "$logdir"; then
  echo "Cannot create log directory: $logdir" >&2
  exit 1
fi

# Create the checkpoint/save directory.
if ! mkdir -p -- "$SAVE_DIR"; then
  echo "Cannot create save directory: $SAVE_DIR" >&2
  exit 1
fi

# Log files.
all_log="$logdir/all-${dcu}-log.log"
# NOTE(review): benchmark_log is not referenced anywhere else in this script;
# kept in case another tool expects the variable — confirm before removing.
benchmark_log="$logdir/train-log.log"

# Echo the parsed settings.
echo "dcu: $dcu"
echo "log: $logdir"
echo "DATA_ROOT: $DATA_ROOT"
echo "SAVE_DIR: $SAVE_DIR"
echo "LABEL_DIR: $LABEL_DIR"
echo "BPE_TOKENIZER: $BPE_TOKENIZER"
echo "USER_DIR: $USER_DIR"
echo "PT_CHECKPOINT_PATH: $PT_CHECKPOINT_PATH"
echo "epoch: $epoch"

# Without pipefail the pipeline below reports tee's exit status, so a failed
# training run would still look successful to the caller.
set -o pipefail

# Run fairseq-train with the parsed settings; stdout and stderr are shown on
# the terminal and copied to $all_log.
fairseq-train "$DATA_ROOT" \
  --save-dir "$SAVE_DIR" \
  --tensorboard-logdir "$SAVE_DIR" \
  --train-subset "$TRAIN_SET" \
  --valid-subset "$VALID_SET" \
  --hubert-label-dir "$LABEL_DIR" \
  --distributed-world-size "$dcu" \
  --distributed-port 0 \
  --ddp-backend legacy_ddp \
  --user-dir "$USER_DIR" \
  --log-format json \
  --seed 1 \
  \
  --task speecht5 \
  --t5-task s2t \
  --sample-rate 16000 \
  --num-workers 0 \
  --max-tokens 1600000 \
  --update-freq 2 \
  --bpe-tokenizer "$BPE_TOKENIZER" \
  \
  --criterion speecht5 \
  --report-accuracy \
  --zero-infinity \
  --ce-weight 0.5 \
  --ctc-weight 0.5 \
  --sentence-avg \
  \
  --optimizer adam \
  --adam-betas "(0.9, 0.98)" \
  --adam-eps 1e-08 \
  --weight-decay 0.1 \
  --clip-norm 25.0 \
  --lr 0.00006 \
  --lr-scheduler tri_stage \
  --phase-ratio "[0.1, 0.4, 0.5]" \
  --final-lr-scale 0.05 \
  \
  --max-epoch "$epoch" \
  --max-update 80000 \
  --max-text-positions 600 \
  --required-batch-size-multiple 1 \
  --save-interval-updates 3000 \
  --skip-invalid-size-inputs-valid-test \
  \
  --arch t5_transformer_base_asr \
  --share-input-output-embed \
  --find-unused-parameters \
  --bert-init \
  --relative-position-embedding \
  --freeze-encoder-updates 13000 \
  \
  --keep-last-epochs 10 \
  --feature-grad-mult 1.0 \
  --best-checkpoint-metric s2t_accuracy \
  --maximize-best-checkpoint-metric \
  --finetune-from-model "$PT_CHECKPOINT_PATH" 2>&1 | tee "$all_log"