import os
import subprocess
from typing import List

import uvicorn
from fastapi import FastAPI, File, HTTPException, Path, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse


from wav_to_label import wav_to_label
from kantts.bin.train_sambert import train as train_sambert
from text_to_wav_trans import text_to_wav as text_to_wav_trans
from text_to_wav_onnx import text_to_wav_onnx


# Directory that /uploadwavs writes uploaded WAVs into; /wav2label processes it
# and /cleardir empties it.
wav_dir = "./Data/ptts_spk0_wav"
# Directory that /uploadtxt writes into; the synthesis endpoints pass
# "./Data/test.txt" (a file inside this directory) to their TTS functions.
txt_dir = "./Data"
# Output directory handed to text_to_wav_onnx by /text2wav and /oneclickstart.
# NOTE(review): /downloadfile serves from "res/ptts_syn/res_wavs", presumably a
# subdirectory text_to_wav_onnx creates here — confirm.
output_dir = "./res/ptts_syn"


# Create the FastAPI application instance; every route below is registered on it.
app = FastAPI()

@app.get("/")
def get_root():
    """Handle the root path by returning a short welcome message."""
    greeting = "Welcome to try: Personal Text To Speech !"
    return {"message": greeting}


@app.post("/uploadwavs")
async def uploadwavs(files: List[UploadFile] = File(...)):
    """Save one or more uploaded WAV files into ``wav_dir``.

    Each upload is stored under the final component of its client-supplied
    filename only. Names such as ``../../x.wav`` or absolute paths therefore
    cannot write outside ``wav_dir``. All names are validated before anything
    is written, so one bad name does not leave a half-finished batch.

    Args:
        files: The uploaded files (multipart form field ``files``).

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 400 if any upload has an empty or unusable filename.
    """
    # Resolve and validate every target path first (path = directory + name).
    targets = []
    for upload in files:
        name = os.path.basename(upload.filename or "")
        if name in ("", ".", ".."):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid upload filename: {upload.filename!r}",
            )
        targets.append((upload, os.path.join(wav_dir, name)))

    os.makedirs(wav_dir, exist_ok=True)
    for upload, path in targets:
        with open(path, "wb") as f:
            f.write(await upload.read())
    return {"msg": "File upload success in directory 'Data/ptts_spk0_wav/'"}

@app.post("/uploadtxt")
async def uploadtxt(file: UploadFile = File(...)):
    """Save an uploaded text file into ``txt_dir``.

    Only the final component of the client-supplied filename is used. A name
    like ``../../x`` or an absolute path therefore cannot write outside
    ``txt_dir``.

    Args:
        file: The uploaded file (multipart form field ``file``).

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 400 if the upload has an empty or unusable filename.
    """
    name = os.path.basename(file.filename or "")
    if name in ("", ".", ".."):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid upload filename: {file.filename!r}",
        )
    # Save the file into the target directory; path = directory + file name.
    os.makedirs(txt_dir, exist_ok=True)
    with open(os.path.join(txt_dir, name), "wb") as f:
        f.write(await file.read())
    return {"msg": "File upload success in directory 'Data/'"}


@app.get("/listfiles/{dirpath:path}")
async def listfiles(dirpath: str):
    """List the entries of a directory inside the server's working directory.

    ``dirpath`` comes straight from the URL. It is resolved against the current
    working directory and rejected if it points outside it (e.g. ``../..`` or
    ``/etc``). Normalisation is lexical (``abspath``), not ``realpath``, so
    symlinked data directories under the working directory keep working.

    Args:
        dirpath: Directory path, relative to the working directory.

    Returns:
        ``{"files": [...]}`` with the entry names as returned by ``os.listdir``.

    Raises:
        HTTPException: 403 if the path escapes the working directory;
            404 if it does not exist or is not a directory.
    """
    root = os.path.abspath(os.getcwd())
    target = os.path.abspath(os.path.join(root, dirpath))
    try:
        inside = os.path.commonpath([root, target]) == root
    except ValueError:
        # Paths on different drives (Windows) cannot be inside the root.
        inside = False
    if not inside:
        raise HTTPException(status_code=403, detail="Directory is outside the service root")
    try:
        entries = os.listdir(target)
    except (FileNotFoundError, NotADirectoryError):
        raise HTTPException(status_code=404, detail=f"Directory not found: {dirpath}") from None
    return {"files": entries}


@app.get("/cleardir")
async def cleardir():
    """Delete every regular file directly inside ``wav_dir``.

    Subdirectories are left untouched. If ``wav_dir`` does not exist yet
    (nothing uploaded), there is nothing to clear and the call succeeds
    instead of failing with a 500. A file that cannot be removed is reported
    on stdout and skipped.

    Returns:
        A confirmation message.
    """
    # NOTE(review): this deletes files via GET; consider POST/DELETE once
    # clients can be updated.
    try:
        names = os.listdir(wav_dir)
    except FileNotFoundError:
        names = []
    for filename in names:
        file_path = os.path.join(wav_dir, filename)
        if not os.path.isfile(file_path):
            continue
        try:
            os.remove(file_path)
        except OSError as e:
            print(f"Error: {e}")
    return {"msg": "Directory cleared and empty"}

@app.get("/deletetxt")
async def deletetxt():
    """Delete ``test.txt`` from ``txt_dir``.

    This is the text file the synthesis endpoints pass to their TTS functions.

    Returns:
        A message saying whether the file was removed or did not exist.
    """
    # Path of the file to delete (equivalent to "./Data/test.txt").
    file_path = os.path.join(txt_dir, "test.txt")
    if not os.path.isfile(file_path):
        return {"msg": "File not exist"}
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Removed concurrently between the check above and this call.
        return {"msg": "File not exist"}
    return {"msg": "File removed"}

@app.get("/downloadfile/{filename}")
async def downloadfile(filename: str):
    """Serve a file from ``res/ptts_syn/res_wavs``.

    Only plain file names are accepted. Anything containing a path separator,
    or equal to ``.``/``..``, is treated as missing, so the endpoint cannot read
    outside that directory. Directories are also treated as missing, because
    they cannot be streamed back as a file.

    Args:
        filename: Name of the file inside ``res/ptts_syn/res_wavs``.

    Returns:
        A ``FileResponse`` for the file, or ``{"msg": "File not exist"}``.
    """
    file_path = os.path.join("res/ptts_syn/res_wavs", filename)
    is_plain_name = (
        os.path.basename(filename) == filename and filename not in ("", ".", "..")
    )
    if is_plain_name and os.path.isfile(file_path):
        return FileResponse(file_path)
    return {"msg": "File not exist"}


@app.get("/wav2label")
async def wav2label():
    """Run ``wav_to_label`` over ``wav_dir`` and return its report.

    ``wav_to_label`` is a synchronous call. It runs in FastAPI's worker thread
    pool so the event loop keeps serving other requests meanwhile.
    NOTE(review): this also means overlapping requests can now run it
    concurrently; avoid triggering it twice at once.
    """
    return await run_in_threadpool(wav_to_label, wav_dir)

@app.get("/featsextract")
async def featsextract():
    """Run ``./feats_extract.sh`` to completion and return its standard output.

    The script runs in FastAPI's worker thread pool so the event loop is not
    blocked while it runs. Its stderr is inherited by the server process, as
    before.

    Returns:
        The script's stdout as text.

    Raises:
        HTTPException: 500 if the script exits with a non-zero status.
            Previously the failure was silently ignored.
    """
    # Execute the shell script. The command is a fixed string (no user input),
    # so shell=True carries no injection risk and keeps os.popen's semantics.
    # subprocess.run waits for the process and closes the pipe, unlike the
    # never-closed os.popen handle.
    result = await run_in_threadpool(
        subprocess.run,
        "./feats_extract.sh",
        shell=True,
        stdout=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"feats_extract.sh exited with status {result.returncode}: {result.stdout}",
        )
    return result.stdout

@app.get("/trainsambert")
async def trainsambert():
    """Run ``train_sambert`` with the fixed config, data and checkpoint paths.

    Arguments, judging by their paths (TODO confirm against
    kantts.bin.train_sambert.train): the base SAMBERT config, the extracted
    feature directory, the output checkpoint directory, and the base
    checkpoint to start from.

    Training is a long synchronous call. It runs in FastAPI's worker thread
    pool so the server stays responsive meanwhile.
    NOTE(review): overlapping requests can now train concurrently into the
    same checkpoint directory; avoid triggering it twice at once.
    """
    await run_in_threadpool(
        train_sambert,
        "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/config.yaml",
        "training_stage/ptts_feats",
        "training_stage/ptts_sambert_ckpt",
        "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/basemodel_16k/sambert/ckpt/checkpoint_2400000.pth",
    )
    return {"msg": "Training finished"}


@app.get("/modeltransform")
async def modeltransform():
    """Run ``text_to_wav_trans`` on ``./Data/test.txt`` with the trained checkpoint.

    Output goes to ``res/ptts_syn_one``. The call is synchronous. It runs in
    FastAPI's worker thread pool so the event loop is not blocked meanwhile.
    """
    # NOTE(review): checkpoint_2402200.pth is hard-coded; presumably it is the
    # checkpoint /trainsambert leaves in training_stage/ptts_sambert_ckpt/ckpt
    # (base step 2400000 plus the fine-tuning steps). Confirm it still matches
    # if the training config changes.
    await run_in_threadpool(
        text_to_wav_trans,
        "./Data/test.txt",
        "res/ptts_syn_one",
        "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip",
        "training_stage/ptts_sambert_ckpt/ckpt/checkpoint_2402200.pth",
        "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/ckpt/checkpoint_2400000.pth",
        "F7",
        "training_stage/ptts_feats/se/se.npy",
    )

    return {"msg": "Model transform finished"}


# from enum import Enum

# class TargetRate(Enum):
#     rate0 = 1.0
#     rate1 = 0.5
#     rate2 = 0.75   
#     rate3 = 1.25
#     rate4 = 1.5
#     rate5= 1.75
#     rate6 = 2.0

@app.get("/text2wav/{targetrate}")
async def text2wav(targetrate: float=1.0):
    """Synthesize ``./Data/test.txt`` into ``output_dir`` with the ONNX models.

    Args:
        targetrate: Forwarded unchanged to ``text_to_wav_onnx``. Presumably a
            speaking-rate factor (see the commented TargetRate values 0.5–2.0);
            TODO confirm.

    The synthesis call is synchronous. It runs in FastAPI's worker thread pool
    so the event loop is not blocked meanwhile.
    """
    await run_in_threadpool(
        text_to_wav_onnx,
        "./Data/test.txt",
        output_dir,
        "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k/resource.zip",
        "sambert_onnx/text_encoder.onnx",
        "sambert_onnx/variance_adaptor_dict.pt",
        "sambert_onnx/mel_decoder_dict.pt",
        "sambert_onnx/mel_postnet.onnx",
        "training_stage/ptts_sambert_ckpt/config.yaml",
        "hifigan_onnx/hifigan.onnx",
        "speech_personal_sambert-hifigan_nsf_tts_zh-cn_pretrain_16k//basemodel_16k/hifigan/config.yaml",
        targetrate,
        "F7",
        "training_stage/ptts_feats/se/se.npy",
    )

    return {"msg": "Text to wav finished"}



@app.get("/oneclickstart/{targetrate}")
async def oneclickstart(targetrate: float=1.0):
    """Run the whole pipeline: label, extract features, train, convert, synthesize.

    Each stage is delegated to the endpoint function in this module that makes
    exactly the same call with the same arguments. Each stage is awaited before
    the next one starts.

    Args:
        targetrate: Forwarded to the final synthesis stage (``text2wav``).

    Returns:
        ``{"msg": "Text to wav finished"}`` once every stage has completed.
    """
    await wav2label()
    # Feature extraction must finish before training reads its output.
    # The previous os.popen call only launched feats_extract.sh and never
    # waited, so training raced the script. The script could also block on a
    # full, never-read stdout pipe.
    await featsextract()
    await trainsambert()
    await modeltransform()
    await text2wav(targetrate)

    return {"msg": "Text to wav finished"}



if __name__ == "__main__":
    # Start the uvicorn server directly when this file is executed as a script,
    # listening on every interface at port 8000.
    # NOTE(review): the endpoints have no authentication and some delete or
    # overwrite files, so binding to 0.0.0.0 exposes them to the network.
    server_config = dict(host="0.0.0.0", port=8000)
    uvicorn.run(app, **server_config)
