import os
import asyncio
import io
import traceback
from fastapi import FastAPI, Request, Response, File, UploadFile, Form
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import argparse
import json
import time
import soundfile as sf
from typing import List, Optional, Union
from loguru import logger

logger.add("logs/api_server_v2.log", rotation="10 MB", retention=10, level="DEBUG", enqueue=True)

from indextts.infer_vllm_v2 import IndexTTS2

tts = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the IndexTTS2 model once at startup."""
    global tts
    tts = IndexTTS2(
        model_dir=args.model_dir,
        is_fp16=args.is_fp16,
        gpu_memory_utilization=args.gpu_memory_utilization,
        qwenemo_gpu_memory_utilization=args.qwenemo_gpu_memory_utilization,
    )
    yield


app = FastAPI(lifespan=lifespan)

# Add CORS middleware configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allows all origins; restrict in production for security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    if tts is None:
        return JSONResponse(
            status_code=503,
            content={
                "status": "unhealthy",
                "message": "TTS model not initialized",
            },
        )
    return JSONResponse(
        status_code=200,
        content={
            "status": "healthy",
            "message": "Service is running",
            "timestamp": time.time(),
        },
    )


@app.post("/tts_url", responses={
    200: {"content": {"application/octet-stream": {}}},
    500: {"content": {"application/json": {}}},
})
async def tts_api_url(request: Request):
    try:
        data = await request.json()
        emo_control_method = data.get("emo_control_method", 0)
        text = data["text"]
        spk_audio_path = data["spk_audio_path"]
        emo_ref_path = data.get("emo_ref_path", None)
        emo_weight = data.get("emo_weight", 1.0)
        emo_vec = data.get("emo_vec", [0] * 8)
        emo_text = data.get("emo_text", None)
        emo_random = data.get("emo_random", False)
        max_text_tokens_per_sentence = data.get("max_text_tokens_per_sentence", 120)

        global tts
        if not isinstance(emo_control_method, int):
            emo_control_method = emo_control_method.value

        if emo_control_method == 0:
            # Mode 0: derive emotion from the speaker prompt; ignore the emotion reference
            emo_ref_path = None
            emo_weight = 1.0
        # Mode 1: emotion reference audio, weighted by emo_weight as provided
        if emo_control_method == 2:
            # Mode 2: explicit emotion vector
            vec = emo_vec
            if sum(vec) > 1.5:
                return JSONResponse(
                    status_code=500,
                    content={
                        "status": "error",
                        "error": "The emotion vector components must not sum to more than 1.5; please adjust and retry.",
                    },
                )
        else:
            vec = None

        # logger.info(f"Emo control mode:{emo_control_method}, vec:{vec}")

        sr, wav = await tts.infer(
            spk_audio_prompt=spk_audio_path,
            text=text,
            output_path=None,
            emo_audio_prompt=emo_ref_path,
            emo_alpha=emo_weight,
            emo_vector=vec,
            use_emo_text=(emo_control_method == 3),  # Mode 3: infer emotion from text
            emo_text=emo_text,
            use_random=emo_random,
            max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
        )

        with io.BytesIO() as wav_buffer:
            sf.write(wav_buffer, wav, sr, format="WAV")
            wav_bytes = wav_buffer.getvalue()

        return Response(content=wav_bytes, media_type="audio/wav")
    except Exception as ex:
        tb_str = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "error": str(tb_str),
            },
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=6006)
    parser.add_argument("--model_dir", type=str, default="checkpoints/IndexTTS-2-vLLM",
                        help="Model checkpoints directory")
    parser.add_argument("--is_fp16", action="store_true", default=False, help="FP16 inference")
parser.add_argument("--gpu_memory_utilization", type=float, default=0.25) parser.add_argument("--qwenemo_gpu_memory_utilization", type=float, default=0.10) parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode") args = parser.parse_args() if not os.path.exists("outputs"): os.makedirs("outputs") uvicorn.run(app=app, host=args.host, port=args.port)