Commit ab1b2790 authored by LiangLiu's avatar LiangLiu Committed by GitHub
Browse files

Deploy server and worker (#284)



* Init deploy: not ok

* Test data_manager & task_manager

* pipeline is no need for worker

* Update worker text_encoder

* deploy: submit task

* add apis

* Test pipelineRunner

* Fix pipeline

* Tidy worker & test PipelineWorker ok

* Tidy code

* Fix multi_stage for wan2.1 t2v & i2v

* api query task, get result & report subtasks failed when workers stop

* Add model list functionality to Pipeline and API

* Add task cancel and task resume  to API

* Add RabbitMQ queue manager

* update local task manager atomic

* support postgreSQL task manager, add lifespan async init

* worker print -> logger

* Add S3 data manager, delete temp objects after finished.

* fix worker

* release fetch queue msg when closed, run stuck worker in another thread, stop worker when process down.

* DiTWorker run with thread & tidy logger print

* Init monitor without test

* fix monitor

* github OAuth and jwt token access & static demo html page

* Add user to task, ok for local task manager & update demo ui

* sql task manager support users

* task list with pages

* merge main fix

* Add proxy for auth request

* support wan audio

* worker ping subtask and ping life, fix rabbitmq async get,

* s3 data manager with async api & tidy monitor config

* fix merge main & update req.txt & fix html view video error

* Fix distributed worker

* LImit user visit freq

* Tidy

* Fix only rank save

* Fix audio input

* Fix worker fetch None

* index.html abs path to rel path

* Fix dist worker stuck

* support publish output video to rtmp & graceful stop running dit step or segment step

* Add VAReader

* Enhance VAReader with torch dist

* Fix audio stream input

* fix merge refractor main, support stream input_audio and output_video

* fix audio read with prev frames & fix end take frames & tidy worker end

* split audio model to 4 workers & fix audio end frame

* fix ping subtask with queue

* Fix audio worker put block & add whep, whip without test ok

* Tidy va recorder & va reader log, thread canel within 30s

* Fix dist worker stuck: broadcast stop signal

* Tidy

* record task active_elapse & subtask status_elapse

* Design prometheus metrics

* Tidy prometheus metrics

* Fix merge main

* send sigint to ffmpeg process

* Fix gstreamer pull audio by whep & Dockerfile for gstreamer & check params when submitting

* Fix merge main

* Query task with more info & va_reader buffer size = 1

* Fix va_recorder

* Add config for prev_frames

* update frontend

* update frontend

* update frontend

* update frontend
merge

* update frontend & partial backend

* Different rank for va_recorder and va_reader

* Fix mem leak: only one rank publish video, other rank should pop gen vids

* fix task category

* va_reader pre-alloc tensor & va_recorder send frames all & fix dist cancel infer

* Fix prev_frame_length

* Tidy

* Tidy

* update frontend & backend

* Fix lint error

* recover some files

* Tidy

* lint code

---------
Co-authored-by: default avatarliuliang1 <liuliang1@sensetime.com>
Co-authored-by: default avatarunknown <qinxinyi@sensetime.com>
parent acacd26f
# For rtc whep, build gstreamer whith whepsrc plugin
FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90 AS gstreamer-base
RUN apt update -y \
&& apt update -y \
&& apt install -y libssl-dev flex bison \
libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
libnice-dev libopus-dev libvpx-dev libx264-dev \
libsrtp2-dev libglib2.0-dev libdrm-dev
RUN cd /opt \
&& wget https://mirrors.tuna.tsinghua.edu.cn/gnu//libiconv/libiconv-1.15.tar.gz \
&& tar zxvf libiconv-1.15.tar.gz \
&& cd libiconv-1.15 \
&& ./configure \
&& make \
&& make install
RUN pip install meson
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
ENV PATH=/root/.cargo/bin:$PATH
RUN cd /opt \
&& git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \
&& cd gstreamer \
&& meson setup builddir \
&& meson compile -C builddir \
&& meson install -C builddir \
&& ldconfig
RUN cd /opt \
&& git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \
&& cd gst-plugins-rs \
&& cargo build --package gst-plugin-webrtchttp --release \
&& install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/
# Lightx2v deploy image
FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90
RUN mkdir /workspace/lightx2v
WORKDIR /workspace/lightx2v
ENV PYTHONPATH=/workspace/lightx2v
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
RUN conda install conda-forge::ffmpeg=8.0.0 -y
RUN rm /usr/bin/ffmpeg && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg
RUN apt update -y \
&& apt install -y libssl-dev \
libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
libnice-dev libopus-dev libvpx-dev libx264-dev \
libsrtp2-dev libglib2.0-dev libdrm-dev
ENV LBDIR=/usr/local/lib/x86_64-linux-gnu
COPY --from=gstreamer-base /usr/local/bin/gst-* /usr/local/bin/
COPY --from=gstreamer-base $LBDIR $LBDIR
RUN ldconfig
ENV LD_LIBRARY_PATH=$LBDIR:$LD_LIBRARY_PATH
RUN gst-launch-1.0 --version
RUN gst-inspect-1.0 whepsrc
COPY assets assets
COPY configs configs
COPY lightx2v lightx2v
COPY lightx2v_kernel lightx2v_kernel
{
"data":
{
"t2v": {
"wan2.1-1.3B": {
"single_stage": {
"pipeline": {
"inputs": [],
"outputs": ["output_video"]
}
},
"multi_stage": {
"text_encoder": {
"inputs": [],
"outputs": ["text_encoder_output"]
},
"dit": {
"inputs": ["text_encoder_output"],
"outputs": ["latents"]
},
"vae_decoder": {
"inputs": ["latents"],
"outputs": ["output_video"]
}
}
}
},
"i2v": {
"wan2.1-14B-480P": {
"single_stage": {
"pipeline": {
"inputs": ["input_image"],
"outputs": ["output_video"]
}
},
"multi_stage": {
"text_encoder": {
"inputs": ["input_image"],
"outputs": ["text_encoder_output"]
},
"image_encoder": {
"inputs": ["input_image"],
"outputs": ["clip_encoder_output"]
},
"vae_encoder": {
"inputs": ["input_image"],
"outputs": ["vae_encoder_output"]
},
"dit": {
"inputs": [
"clip_encoder_output",
"vae_encoder_output",
"text_encoder_output"
],
"outputs": ["latents"]
},
"vae_decoder": {
"inputs": ["latents"],
"outputs": ["output_video"]
}
}
},
"SekoTalk-Distill": {
"single_stage": {
"pipeline": {
"inputs": ["input_image", "input_audio"],
"outputs": ["output_video"]
}
},
"multi_stage": {
"text_encoder": {
"inputs": ["input_image"],
"outputs": ["text_encoder_output"]
},
"image_encoder": {
"inputs": ["input_image"],
"outputs": ["clip_encoder_output"]
},
"vae_encoder": {
"inputs": ["input_image"],
"outputs": ["vae_encoder_output"]
},
"segment_dit": {
"inputs": [
"input_audio",
"clip_encoder_output",
"vae_encoder_output",
"text_encoder_output"
],
"outputs": ["output_video"]
}
}
}
}
},
"meta": {
"special_types": {
"input_image": "IMAGE",
"input_audio": "AUDIO",
"latents": "TENSOR",
"output_video": "VIDEO"
},
"monitor": {
"subtask_created_timeout": 1800,
"subtask_pending_timeout": 1800,
"subtask_running_timeouts": {
"t2v-wan2.1-1.3B-multi_stage-dit": 300,
"t2v-wan2.1-1.3B-single_stage-pipeline": 300,
"i2v-wan2.1-14B-480P-multi_stage-dit": 600,
"i2v-wan2.1-14B-480P-single_stage-pipeline": 600,
"i2v-SekoTalk-Distill-single_stage-pipeline": 3600,
"i2v-SekoTalk-Distill-multi_stage-segment_dit": 3600
},
"worker_avg_window": 20,
"worker_offline_timeout": 5,
"worker_min_capacity": 20,
"worker_min_cnt": 1,
"worker_max_cnt": 10,
"task_timeout": 3600,
"schedule_ratio_high": 0.25,
"schedule_ratio_low": 0.02,
"ping_timeout": 30,
"user_max_active_tasks": 3,
"user_max_daily_tasks": 100,
"user_visit_frequency": 0.05
}
}
}
import json
import sys
from loguru import logger
class Pipeline:
def __init__(self, pipeline_json_file):
self.pipeline_json_file = pipeline_json_file
x = json.load(open(pipeline_json_file))
self.data = x["data"]
self.meta = x["meta"]
self.inputs = {}
self.outputs = {}
self.temps = {}
self.model_lists = []
self.types = {}
self.queues = set()
self.tidy_pipeline()
def init_dict(self, base, task, model_cls):
if task not in base:
base[task] = {}
if model_cls not in base[task]:
base[task][model_cls] = {}
# tidy each task item eg, ['t2v', 'wan2.1', 'multi_stage']
def tidy_task(self, task, model_cls, stage, v3):
out2worker = {}
out2num = {}
cur_inps = set()
cur_temps = set()
cur_types = {}
for worker_name, worker_item in v3.items():
prevs = []
for inp in worker_item["inputs"]:
cur_types[inp] = self.get_type(inp)
if inp in out2worker:
prevs.append(out2worker[inp])
out2num[inp] -= 1
if out2num[inp] <= 0:
cur_temps.add(inp)
else:
cur_inps.add(inp)
worker_item["previous"] = prevs
for out in worker_item["outputs"]:
cur_types[out] = self.get_type(out)
out2worker[out] = worker_name
if out not in out2num:
out2num[out] = 0
out2num[out] += 1
if "queue" not in worker_item:
worker_item["queue"] = "-".join([task, model_cls, stage, worker_name])
self.queues.add(worker_item["queue"])
cur_outs = [out for out, num in out2num.items() if num > 0]
self.inputs[task][model_cls][stage] = list(cur_inps)
self.outputs[task][model_cls][stage] = cur_outs
self.temps[task][model_cls][stage] = list(cur_temps)
self.types[task][model_cls][stage] = cur_types
# tidy previous dependence workers and queue name
def tidy_pipeline(self):
for task, v1 in self.data.items():
for model_cls, v2 in v1.items():
for stage, v3 in v2.items():
self.init_dict(self.inputs, task, model_cls)
self.init_dict(self.outputs, task, model_cls)
self.init_dict(self.temps, task, model_cls)
self.init_dict(self.types, task, model_cls)
self.tidy_task(task, model_cls, stage, v3)
self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage})
logger.info(f"pipelines: {json.dumps(self.data, indent=4)}")
logger.info(f"inputs: {self.inputs}")
logger.info(f"outputs: {self.outputs}")
logger.info(f"temps: {self.temps}")
logger.info(f"types: {self.types}")
logger.info(f"model_lists: {self.model_lists}")
logger.info(f"queues: {self.queues}")
def get_item_by_keys(self, keys):
item = self.data
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in {self.pipeline_json_file}!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder']
def get_worker(self, keys):
return self.get_item_by_keys(keys)
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_workers(self, keys):
return self.get_item_by_keys(keys)
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_inputs(self, keys):
item = self.inputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in inputs!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_outputs(self, keys):
item = self.outputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in outputs!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_temps(self, keys):
item = self.temps
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in temps!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_types(self, keys):
item = self.types
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in types!")
item = item[k]
return item
def get_model_lists(self):
return self.model_lists
def get_type(self, name):
return self.meta["special_types"].get(name, "OBJECT")
def get_monitor_config(self):
return self.meta["monitor"]
def get_queues(self):
return self.queues
if __name__ == "__main__":
pipeline = Pipeline(sys.argv[1])
print(pipeline.get_workers(["t2v", "wan2.1", "multi_stage"]))
print(pipeline.get_worker(["i2v", "wan2.1", "multi_stage", "dit"]))
import base64
import io
import os
import time
import traceback
from datetime import datetime
import httpx
import torchaudio
from PIL import Image
from loguru import logger
FMT = "%Y-%m-%d %H:%M:%S"
def current_time():
return datetime.now().timestamp()
def time2str(t):
d = datetime.fromtimestamp(t)
return d.strftime(FMT)
def str2time(s):
d = datetime.strptime(s, FMT)
return d.timestamp()
def try_catch(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception:
logger.error(f"Error in {func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch(func):
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch_async(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
def data_name(x, task_id):
if x == "input_image":
x = x + ".png"
elif x == "output_video":
x = x + ".mp4"
return f"{task_id}-{x}"
async def fetch_resource(url, timeout):
logger.info(f"Begin to download resource from url: {url}")
t0 = time.time()
async with httpx.AsyncClient() as client:
async with client.stream("GET", url, timeout=timeout) as response:
response.raise_for_status()
ans_bytes = []
async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
ans_bytes.append(chunk)
if len(ans_bytes) > 128:
raise Exception(f"url {url} recv data is too big")
content = b"".join(ans_bytes)
logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds")
return content
async def preload_data(inp, inp_type, typ, val):
try:
if typ == "url":
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
data = await fetch_resource(val, timeout=timeout)
elif typ == "base64":
data = base64.b64decode(val)
elif typ == "stream":
# no bytes data need to be saved by data_manager
data = None
else:
raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!")
# check if valid image bytes
if inp_type == "IMAGE":
image = Image.open(io.BytesIO(data))
logger.info(f"load image: {image.size}")
assert image.size[0] > 0 and image.size[1] > 0, "image is empty"
elif inp_type == "AUDIO":
if typ != "stream":
try:
waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
logger.info(f"load audio: {waveform.size()}, {sample_rate}")
assert waveform.size(0) > 0, "audio is empty"
assert sample_rate > 0, "audio sample rate is not valid"
except Exception as e:
logger.warning(f"torchaudio failed to load audio, trying alternative method: {e}")
# 尝试使用其他方法验证音频文件
# 检查文件头是否为有效的音频格式
if len(data) < 4:
raise ValueError("Audio file too short")
# 检查常见的音频文件头
audio_headers = [b"RIFF", b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2", b"OggS"]
if not any(data.startswith(header) for header in audio_headers):
logger.warning("Audio file doesn't have recognized header, but continuing...")
logger.info(f"Audio validation passed (alternative method), size: {len(data)} bytes")
else:
raise Exception(f"cannot parse inp_type={inp_type} data")
return data
except Exception as e:
raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!")
async def load_inputs(params, raw_inputs, types):
inputs_data = {}
for inp in raw_inputs:
item = params.pop(inp)
bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
if bytes_data is not None:
inputs_data[inp] = bytes_data
else:
params[inp] = item
return inputs_data
def check_params(params, raw_inputs, raw_outputs, types):
stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
for x in raw_inputs + raw_outputs:
if x in params and "type" in params[x] and params[x]["type"] == "stream":
if types[x] == "AUDIO":
assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
elif types[x] == "VIDEO":
assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
import os
import queue
import signal
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
class VAReader:
def __init__(
self,
rank: int,
world_size: int,
stream_url: str,
segment_duration: float = 5.0,
sample_rate: int = 16000,
audio_channels: int = 1,
buffer_size: int = 1,
prev_duration: float = 0.3125,
target_rank: int = 0,
):
self.rank = rank
self.world_size = world_size
self.stream_url = stream_url
self.segment_duration = segment_duration
self.sample_rate = sample_rate
self.audio_channels = audio_channels
self.prev_duration = prev_duration
# int16 = 2 bytes
self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.prev_size = int(self.prev_duration * self.sample_rate) * 2
self.prev_chunk = None
self.buffer_size = buffer_size
self.audio_queue = queue.Queue(maxsize=self.buffer_size)
self.audio_thread = None
self.ffmpeg_process = None
self.bytes_buffer = bytearray()
self.target_rank = target_rank % self.world_size
self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")
logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
def start(self):
if self.rank == self.target_rank:
if self.stream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.stream_url.startswith("http"):
self.start_ffmpeg_process_whep()
else:
raise Exception(f"Unsupported stream URL: {self.stream_url}")
self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
self.audio_thread.start()
logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
else:
logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
if self.world_size > 1:
logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
dist.barrier()
logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process read audio from stream"""
ffmpeg_cmd = [
"/opt/conda/bin/ffmpeg",
"-i",
self.stream_url,
"-vn",
# "-acodec",
# "pcm_s16le",
"-ar",
str(self.sample_rate),
"-ac",
str(self.audio_channels),
"-f",
"s16le",
"-",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def start_ffmpeg_process_whep(self):
"""Start gstream process read audio from stream"""
ffmpeg_cmd = [
"gst-launch-1.0",
"-q",
"whepsrc",
f"whep-endpoint={self.stream_url}",
"video-caps=none",
"!rtpopusdepay",
"!opusdec",
"plc=false",
"!audioconvert",
"!audioresample",
f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
"!fdsink",
"fd=1",
]
try:
self.ffmpeg_process = subprocess.Popen(
ffmpeg_cmd,
stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
bufsize=0,
)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def audio_worker(self):
logger.info("Audio pull worker thread started")
try:
while True:
if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
logger.warning("FFmpeg process exited, audio worker thread stopped")
break
self.fetch_audio_data()
time.sleep(0.01)
except: # noqa
logger.error(f"Audio pull worker error: {traceback.format_exc()}")
finally:
logger.warning("Audio pull worker thread stopped")
def fetch_audio_data(self):
"""Fetch audio data from ffmpeg process"""
try:
audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
if not audio_bytes:
return
self.bytes_buffer.extend(audio_bytes)
# logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes")
if len(self.bytes_buffer) >= self.chunk_size:
audio_data = self.bytes_buffer[: self.chunk_size]
self.bytes_buffer = self.bytes_buffer[self.chunk_size :]
# first chunk, read original 81 frames
# for other chunks, read 81 - 5 = 76 frames, concat with previous 5 frames
if self.prev_chunk is None:
logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
self.chunk_size -= self.prev_size
else:
audio_data = self.prev_chunk + audio_data
self.prev_chunk = audio_data[-self.prev_size :]
try:
self.audio_queue.put_nowait(audio_data)
except queue.Full:
logger.warning(f"Audio queue full:{self.audio_queue.qsize()}, discarded oldest chunk")
self.audio_queue.get_nowait()
self.audio_queue.put_nowait(audio_data)
logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size:{self.chunk_size}")
except: # noqa
logger.error(f"Fetch audio data error: {traceback.format_exc()}")
def braodcast_audio_data(self, audio_data):
if self.rank == self.target_rank:
if audio_data is None:
self.flag_tensor.fill_(0)
else:
self.flag_tensor.fill_(1)
self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
dist.broadcast(self.flag_tensor, src=self.target_rank)
if self.flag_tensor.item() == 0:
return None
dist.broadcast(self.audio_tensor, src=self.target_rank)
if self.rank != self.target_rank:
logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
audio_data = self.audio_tensor.cpu().numpy().tobytes()
return audio_data
def bytes_to_ndarray(self, audio_data):
if audio_data is None:
return None
audio_data = np.frombuffer(audio_data, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
return audio_data
def get_audio_segment(self, timeout: float = 1.0):
audio_data = None
if self.rank == self.target_rank:
try:
audio_data = self.audio_queue.get(timeout=timeout)
except: # noqa
logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
if self.world_size > 1:
audio_data = self.braodcast_audio_data(audio_data)
audio_data = self.bytes_to_ndarray(audio_data)
return audio_data
def stop(self):
# Stop ffmpeg process
if self.ffmpeg_process:
self.ffmpeg_process.send_signal(signal.SIGINT)
try:
self.ffmpeg_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.ffmpeg_process.kill()
logger.warning("FFmpeg reader process stopped")
# Wait for threads to finish
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=5)
if self.audio_thread.is_alive():
logger.error("Audio pull thread did not stop gracefully")
while self.audio_queue and self.audio_queue.qsize() > 0:
self.audio_queue.get_nowait()
self.audio_queue = None
logger.warning("Audio pull queue cleaned")
def __del__(self):
self.stop()
if __name__ == "__main__":
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
reader = VAReader(
RANK,
WORLD_SIZE,
# "rtmp://localhost/live/test_audio",
"https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000",
segment_duration=1.0,
sample_rate=16000,
audio_channels=1,
prev_duration=1 / 16,
)
reader.start()
fail_count = 0
max_fail_count = 2
try:
while True:
audio_data = reader.get_audio_segment(timeout=2)
if audio_data is not None:
# logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
fail_count = 0
else:
fail_count += 1
if fail_count > max_fail_count:
logger.warning("Failed to get audio chunk, stop reader")
reader.stop()
break
time.sleep(0.95)
finally:
reader.stop()
import queue
import signal
import socket
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torchaudio as ta
from loguru import logger
class VARecorder:
def __init__(
self,
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
audio_port: int = 30200,
video_port: int = 30201,
):
self.livestream_url = livestream_url
self.fps = fps
self.sample_rate = sample_rate
self.audio_port = audio_port
self.video_port = video_port
self.width = None
self.height = None
self.stoppable_t = None
# ffmpeg process for mix video and audio data and push to livestream
self.ffmpeg_process = None
# TCP connection objects
self.audio_socket = None
self.video_socket = None
self.audio_conn = None
self.video_conn = None
self.audio_thread = None
self.video_thread = None
# queue for send data to ffmpeg process
self.audio_queue = queue.Queue()
self.video_queue = queue.Queue()
def init_sockets(self):
# TCP socket for send and recv video and audio data
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.video_socket.bind(("127.0.0.1", self.video_port))
self.video_socket.listen(1)
self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.audio_socket.bind(("127.0.0.1", self.audio_port))
self.audio_socket.listen(1)
def audio_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to audio socket...")
self.audio_conn, _ = self.audio_socket.accept()
logger.info(f"Audio connection established from {self.audio_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
while True:
try:
if self.audio_queue is None:
break
data = self.audio_queue.get()
if data is None:
logger.info("Audio thread received stop signal")
break
# Convert audio data to 16-bit integer format
audios = np.clip(np.round(data * 32767), -32768, 32767).astype(np.int16)
self.audio_conn.send(audios.tobytes())
fail_time = 0
except: # noqa
logger.error(f"Send audio data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
break
except: # noqa
logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Audio push worker thread stopped")
def video_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to video socket...")
self.video_conn, _ = self.video_socket.accept()
logger.info(f"Video connection established from {self.video_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
while True:
try:
if self.video_queue is None:
break
data = self.video_queue.get()
if data is None:
logger.info("Video thread received stop signal")
break
# Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg
frames = (data * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
self.video_conn.send(frames.tobytes())
fail_time = 0
except: # noqa
logger.error(f"Send video data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
break
except: # noqa
logger.error(f"Video push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Video push worker thread stopped")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"/opt/conda/bin/ffmpeg",
"-re",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"44100",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"flv",
self.livestream_url,
"-y",
"-loglevel",
"info",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_whip(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"/opt/conda/bin/ffmpeg",
"-re",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"48000",
"-c:a",
"libopus",
"-ac",
"2",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-threads",
"1",
"-bf",
"0",
"-f",
"whip",
self.livestream_url,
"-y",
"-loglevel",
"info",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def set_video_size(self, width: int, height: int):
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, "Video size already set"
return
self.width = width
self.height = height
self.init_sockets()
if self.livestream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.livestream_url.startswith("http"):
self.start_ffmpeg_process_whip()
else:
raise Exception(f"Unsupported livestream URL: {self.livestream_url}")
self.audio_thread = threading.Thread(target=self.audio_worker)
self.video_thread = threading.Thread(target=self.video_worker)
self.audio_thread.start()
self.video_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: np.ndarray):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
self.audio_queue.put(audios)
self.video_queue.put(images)
logger.info(f"Published {N} frames and {M} audio samples")
self.stoppable_t = time.time() + M / self.sample_rate + 3
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
# Send stop signals to queues
if self.audio_queue:
self.audio_queue.put(None)
if self.video_queue:
self.video_queue.put(None)
# Wait for threads to finish
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=5)
if self.audio_thread.is_alive():
logger.warning("Audio push thread did not stop gracefully")
if self.video_thread and self.video_thread.is_alive():
self.video_thread.join(timeout=5)
if self.video_thread.is_alive():
logger.warning("Video push thread did not stop gracefully")
# Close TCP connections, sockets
if self.audio_conn:
self.audio_conn.close()
if self.video_conn:
self.video_conn.close()
if self.audio_socket:
self.audio_socket.close()
if self.video_socket:
self.video_socket.close()
while self.audio_queue and self.audio_queue.qsize() > 0:
self.audio_queue.get_nowait()
while self.video_queue and self.video_queue.qsize() > 0:
self.video_queue.get_nowait()
self.audio_queue = None
self.video_queue = None
logger.warning("Cleaned audio and video queues")
# Stop ffmpeg process
if self.ffmpeg_process:
self.ffmpeg_process.send_signal(signal.SIGINT)
try:
self.ffmpeg_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.ffmpeg_process.kill()
logger.warning("FFmpeg recorder process stopped")
def __del__(self):
self.stop(wait=False)
def create_simple_video(frames=10, height=480, width=640):
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
[1.0, 0.0, 0.0], # 红色
[0.0, 1.0, 0.0], # 绿色
[0.0, 0.0, 1.0], # 蓝色
[1.0, 1.0, 0.0], # 黄色
[1.0, 0.0, 1.0], # 洋红
[0.0, 1.0, 1.0], # 青色
[1.0, 1.0, 1.0], # 白色
[0.5, 0.5, 0.5], # 灰色
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
sample_rate = 16000
fps = 16
width = 640
height = 480
recorder = VARecorder(
# livestream_url="rtmp://localhost/live/test",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
fps=fps,
sample_rate=sample_rate,
)
audio_path = "/mtc/liuliang1/lightx2v/test_deploy/media_test/test_b_2min.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
audio_array = audio_array.numpy().reshape(-1)
secs = audio_array.shape[0] // sample_rate
interval = 1
for i in range(0, secs, interval):
logger.info(f"{i} / {secs} s")
start = i * sample_rate
end = (i + interval) * sample_rate
cur_audio_array = audio_array[start:end]
logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}")
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
recorder.pub_livestream(images, cur_audio_array)
time.sleep(interval)
recorder.stop()
import io
import json
import torch
from PIL import Image
from lightx2v.deploy.common.utils import class_try_catch_async
class BaseDataManager:
def __init__(self):
pass
async def init(self):
pass
async def close(self):
pass
def to_device(self, data, device):
if isinstance(data, dict):
return {key: self.to_device(value, device) for key, value in data.items()}
elif isinstance(data, list):
return [self.to_device(item, device) for item in data]
elif isinstance(data, torch.Tensor):
return data.to(device)
else:
return data
async def save_bytes(self, bytes_data, filename):
raise NotImplementedError
async def load_bytes(self, filename):
raise NotImplementedError
async def delete_bytes(self, filename):
raise NotImplementedError
async def recurrent_save(self, data, prefix):
if isinstance(data, dict):
return {k: await self.recurrent_save(v, f"{prefix}-{k}") for k, v in data.items()}
elif isinstance(data, list):
return [await self.recurrent_save(v, f"{prefix}-{idx}") for idx, v in enumerate(data)]
elif isinstance(data, torch.Tensor):
save_path = prefix + ".pt"
await self.save_tensor(data, save_path)
return save_path
elif isinstance(data, Image.Image):
save_path = prefix + ".png"
await self.save_image(data, save_path)
return save_path
else:
return data
async def recurrent_load(self, data, device, prefix):
if isinstance(data, dict):
return {k: await self.recurrent_load(v, device, f"{prefix}-{k}") for k, v in data.items()}
elif isinstance(data, list):
return [await self.recurrent_load(v, device, f"{prefix}-{idx}") for idx, v in enumerate(data)]
elif isinstance(data, str) and data == prefix + ".pt":
return await self.load_tensor(data, device)
elif isinstance(data, str) and data == prefix + ".png":
return await self.load_image(data)
else:
return data
async def recurrent_delete(self, data, prefix):
if isinstance(data, dict):
return {k: await self.recurrent_delete(v, f"{prefix}-{k}") for k, v in data.items()}
elif isinstance(data, list):
return [await self.recurrent_delete(v, f"{prefix}-{idx}") for idx, v in enumerate(data)]
elif isinstance(data, str) and data == prefix + ".pt":
await self.delete_bytes(data)
elif isinstance(data, str) and data == prefix + ".png":
await self.delete_bytes(data)
@class_try_catch_async
async def save_object(self, data, filename):
data = await self.recurrent_save(data, filename)
bytes_data = json.dumps(data, ensure_ascii=False).encode("utf-8")
await self.save_bytes(bytes_data, filename)
@class_try_catch_async
async def load_object(self, filename, device):
bytes_data = await self.load_bytes(filename)
data = json.loads(bytes_data.decode("utf-8"))
data = await self.recurrent_load(data, device, filename)
return data
@class_try_catch_async
async def delete_object(self, filename):
bytes_data = await self.load_bytes(filename)
data = json.loads(bytes_data.decode("utf-8"))
await self.recurrent_delete(data, filename)
await self.delete_bytes(filename)
@class_try_catch_async
async def save_tensor(self, data: torch.Tensor, filename):
buffer = io.BytesIO()
torch.save(data.to("cpu"), buffer)
await self.save_bytes(buffer.getvalue(), filename)
@class_try_catch_async
async def load_tensor(self, filename, device):
bytes_data = await self.load_bytes(filename)
buffer = io.BytesIO(bytes_data)
t = torch.load(io.BytesIO(bytes_data))
t = t.to(device)
return t
@class_try_catch_async
async def save_image(self, data: Image.Image, filename):
buffer = io.BytesIO()
data.save(buffer, format="PNG")
await self.save_bytes(buffer.getvalue(), filename)
@class_try_catch_async
async def load_image(self, filename):
bytes_data = await self.load_bytes(filename)
buffer = io.BytesIO(bytes_data)
img = Image.open(buffer).convert("RGB")
return img
def get_delete_func(self, type):
maps = {
"TENSOR": self.delete_bytes,
"IMAGE": self.delete_bytes,
"OBJECT": self.delete_object,
"VIDEO": self.delete_bytes,
}
return maps[type]
# Import data manager implementations
from .local_data_manager import LocalDataManager # noqa
from .s3_data_manager import S3DataManager # noqa
__all__ = ["BaseDataManager", "LocalDataManager", "S3DataManager"]
import asyncio
import os
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.data_manager import BaseDataManager
class LocalDataManager(BaseDataManager):
def __init__(self, local_dir):
self.local_dir = local_dir
self.name = "local"
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
@class_try_catch_async
async def save_bytes(self, bytes_data, filename):
out_path = os.path.join(self.local_dir, filename)
with open(out_path, "wb") as fout:
fout.write(bytes_data)
return True
@class_try_catch_async
async def load_bytes(self, filename):
inp_path = os.path.join(self.local_dir, filename)
with open(inp_path, "rb") as fin:
return fin.read()
@class_try_catch_async
async def delete_bytes(self, filename):
inp_path = os.path.join(self.local_dir, filename)
os.remove(inp_path)
logger.info(f"deleted local file {filename}")
return True
async def test():
import torch
from PIL import Image
m = LocalDataManager("/data/nvme1/liuliang1/lightx2v/local_data")
await m.init()
img = Image.open("/data/nvme1/liuliang1/lightx2v/assets/img_lightx2v.png")
tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0")
await m.save_image(img, "test_img.png")
print(await m.load_image("test_img.png"))
await m.save_tensor(tensor, "test_tensor.pt")
print(await m.load_tensor("test_tensor.pt", "cuda:0"))
await m.save_object(
{
"images": [img, img],
"tensor": tensor,
"list": [
[2, 0, 5, 5],
{
"1": "hello world",
"2": "world",
"3": img,
"t": tensor,
},
"0609",
],
},
"test_object.json",
)
print(await m.load_object("test_object.json", "cuda:0"))
await m.get_delete_func("OBJECT")("test_object.json")
await m.get_delete_func("TENSOR")("test_tensor.pt")
await m.get_delete_func("IMAGE")("test_img.png")
if __name__ == "__main__":
asyncio.run(test())
import asyncio
import hashlib
import json
import os
import aioboto3
from botocore.client import Config
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.data_manager import BaseDataManager
class S3DataManager(BaseDataManager):
def __init__(self, config_string, max_retries=3):
self.name = "s3"
self.config = json.loads(config_string)
self.max_retries = max_retries
self.bucket_name = self.config["bucket_name"]
self.aws_access_key_id = self.config["aws_access_key_id"]
self.aws_secret_access_key = self.config["aws_secret_access_key"]
self.endpoint_url = self.config["endpoint_url"]
self.base_path = self.config["base_path"]
self.connect_timeout = self.config.get("connect_timeout", 60)
self.read_timeout = self.config.get("read_timeout", 10)
self.write_timeout = self.config.get("write_timeout", 10)
self.session = None
self.s3_client = None
async def init(self):
for i in range(self.max_retries):
try:
logger.info(f"S3DataManager init with config: {self.config} (attempt {i + 1}/{self.max_retries}) ...")
self.session = aioboto3.Session()
self.s3_client = await self.session.client(
"s3",
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
endpoint_url=self.endpoint_url,
config=Config(
signature_version="s3v4",
s3={"payload_signing_enabled": True},
connect_timeout=self.connect_timeout,
read_timeout=self.read_timeout,
parameter_validation=False,
max_pool_connections=50,
),
).__aenter__()
try:
await self.s3_client.head_bucket(Bucket=self.bucket_name)
logger.info(f"check bucket {self.bucket_name} success")
except Exception as e:
logger.info(f"check bucket {self.bucket_name} error: {e}, try to create it...")
await self.s3_client.create_bucket(Bucket=self.bucket_name)
logger.info(f"Successfully init S3 bucket: {self.bucket_name} with timeouts - connect: {self.connect_timeout}s, read: {self.read_timeout}s, write: {self.write_timeout}s")
return
except Exception as e:
logger.warning(f"Failed to connect to S3: {e}")
await asyncio.sleep(1)
async def close(self):
if self.s3_client:
await self.s3_client.__aexit__(None, None, None)
if self.session:
self.session = None
@class_try_catch_async
async def save_bytes(self, bytes_data, filename):
filename = os.path.join(self.base_path, filename)
content_sha256 = hashlib.sha256(bytes_data).hexdigest()
await self.s3_client.put_object(
Bucket=self.bucket_name,
Key=filename,
Body=bytes_data,
ChecksumSHA256=content_sha256,
ContentType="application/octet-stream",
)
return True
@class_try_catch_async
async def load_bytes(self, filename):
filename = os.path.join(self.base_path, filename)
response = await self.s3_client.get_object(Bucket=self.bucket_name, Key=filename)
return await response["Body"].read()
@class_try_catch_async
async def delete_bytes(self, filename):
filename = os.path.join(self.base_path, filename)
await self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
logger.info(f"deleted s3 file {filename}")
return True
async def file_exists(self, filename):
filename = os.path.join(self.base_path, filename)
try:
await self.s3_client.head_object(Bucket=self.bucket_name, Key=filename)
return True
except Exception:
return False
async def list_files(self, prefix=""):
prefix = os.path.join(self.base_path, prefix)
response = await self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
files = []
if "Contents" in response:
for obj in response["Contents"]:
files.append(obj["Key"])
return files
async def test():
import torch
from PIL import Image
s3_config = {
"aws_access_key_id": "xxx",
"aws_secret_access_key": "xx",
"endpoint_url": "xxx",
"bucket_name": "xxx",
"base_path": "xxx",
"connect_timeout": 10,
"read_timeout": 10,
"write_timeout": 10,
}
m = S3DataManager(json.dumps(s3_config))
await m.init()
img = Image.open("../../../assets/img_lightx2v.png")
tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0")
await m.save_image(img, "test_img.png")
print(await m.load_image("test_img.png"))
await m.save_tensor(tensor, "test_tensor.pt")
print(await m.load_tensor("test_tensor.pt", "cuda:0"))
await m.save_object(
{
"images": [img, img],
"tensor": tensor,
"list": [
[2, 0, 5, 5],
{
"1": "hello world",
"2": "world",
"3": img,
"t": tensor,
},
"0609",
],
},
"test_object.json",
)
print(await m.load_object("test_object.json", "cuda:0"))
print("all files:", await m.list_files())
await m.get_delete_func("OBJECT")("test_object.json")
await m.get_delete_func("TENSOR")("test_tensor.pt")
await m.get_delete_func("IMAGE")("test_img.png")
print("after delete all files", await m.list_files())
await m.close()
if __name__ == "__main__":
asyncio.run(test())
class BaseQueueManager:
def __init__(self):
pass
async def init(self):
pass
async def close(self):
pass
async def put_subtask(self, subtask):
raise NotImplementedError
async def get_subtasks(self, queue, max_batch, timeout):
raise NotImplementedError
async def pending_num(self, queue):
raise NotImplementedError
# Import queue manager implementations
from .local_queue_manager import LocalQueueManager # noqa
from .rabbitmq_queue_manager import RabbitMQQueueManager # noqa
__all__ = ["BaseQueueManager", "LocalQueueManager", "RabbitMQQueueManager"]
import asyncio
import json
import os
import time
import traceback
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.queue_manager import BaseQueueManager
class LocalQueueManager(BaseQueueManager):
def __init__(self, local_dir):
self.local_dir = local_dir
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
async def get_conn(self):
pass
async def del_conn(self):
pass
async def declare_queue(self, queue):
pass
@class_try_catch_async
async def put_subtask(self, subtask):
out_name = self.get_filename(subtask["queue"])
keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"]
msg = json.dumps({k: subtask[k] for k in keys}) + "\n"
logger.info(f"Local published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {subtask['queue']}")
with open(out_name, "a") as fout:
fout.write(msg)
return True
def read_first_line(self, queue):
out_name = self.get_filename(queue)
if not os.path.exists(out_name):
return None
lines = []
with open(out_name) as fin:
lines = fin.readlines()
if len(lines) <= 0:
return None
subtask = json.loads(lines[0])
msgs = "".join(lines[1:])
fout = open(out_name, "w")
fout.write(msgs)
fout.close()
return subtask
@class_try_catch_async
async def get_subtasks(self, queue, max_batch, timeout):
try:
t0 = time.time()
subtasks = []
while True:
subtask = self.read_first_line(queue)
if subtask:
subtasks.append(subtask)
if len(subtasks) >= max_batch:
return subtasks
else:
continue
else:
if len(subtasks) > 0:
return subtasks
if time.time() - t0 > timeout:
return None
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.warning(f"local queue get_subtasks for {queue} cancelled")
return None
except: # noqa
logger.warning(f"local queue get_subtasks for {queue} failed: {traceback.format_exc()}")
return None
def get_filename(self, queue):
return os.path.join(self.local_dir, f"{queue}.jsonl")
@class_try_catch_async
async def pending_num(self, queue):
out_name = self.get_filename(queue)
if not os.path.exists(out_name):
return 0
lines = []
with open(out_name) as fin:
lines = fin.readlines()
return len(lines)
async def test():
q = LocalQueueManager("/data/nvme1/liuliang1/lightx2v/local_queue")
await q.init()
subtask = {
"task_id": "test-subtask-id",
"queue": "test_queue",
"worker_name": "test_worker",
"inputs": {},
"outputs": {},
"params": {},
}
await q.put_subtask(subtask)
await asyncio.sleep(5)
for i in range(2):
subtask = await q.get_subtasks("test_queue", 3, 5)
print("get subtask:", subtask)
if __name__ == "__main__":
asyncio.run(test())
import asyncio
import json
import traceback
import aio_pika
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.queue_manager import BaseQueueManager
class RabbitMQQueueManager(BaseQueueManager):
def __init__(self, conn_url, max_retries=3):
self.conn_url = conn_url
self.max_retries = max_retries
self.conn = None
self.chan = None
self.queues = set()
async def init(self):
await self.get_conn()
async def close(self):
await self.del_conn()
async def get_conn(self):
if self.chan and self.conn:
return
for i in range(self.max_retries):
try:
logger.info(f"Connect to RabbitMQ (attempt {i + 1}/{self.max_retries}..)")
self.conn = await aio_pika.connect_robust(self.conn_url)
self.chan = await self.conn.channel()
self.queues = set()
await self.chan.set_qos(prefetch_count=10)
logger.info("Successfully connected to RabbitMQ")
return
except Exception as e:
logger.warning(f"Failed to connect to RabbitMQ: {e}")
if i < self.max_retries - 1:
await asyncio.sleep(1)
else:
raise
async def declare_queue(self, queue):
if queue not in self.queues:
await self.get_conn()
await self.chan.declare_queue(queue, durable=True)
self.queues.add(queue)
return await self.chan.get_queue(queue)
@class_try_catch_async
async def put_subtask(self, subtask):
queue = subtask["queue"]
await self.declare_queue(queue)
keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"]
msg = json.dumps({k: subtask[k] for k in keys}).encode("utf-8")
message = aio_pika.Message(body=msg, delivery_mode=aio_pika.DeliveryMode.PERSISTENT, content_type="application/json")
await self.chan.default_exchange.publish(message, routing_key=queue)
logger.info(f"Rabbitmq published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {queue}")
return True
async def get_subtasks(self, queue, max_batch, timeout):
try:
q = await self.declare_queue(queue)
subtasks = []
async with q.iterator() as qiter:
async for message in qiter:
await message.ack()
subtask = json.loads(message.body.decode("utf-8"))
subtasks.append(subtask)
if len(subtasks) >= max_batch:
return subtasks
while True:
message = await q.get(no_ack=False, fail=False)
if message:
await message.ack()
subtask = json.loads(message.body.decode("utf-8"))
subtasks.append(subtask)
if len(subtasks) >= max_batch:
return subtasks
else:
return subtasks
except asyncio.CancelledError:
logger.warning(f"rabbitmq get_subtasks for {queue} cancelled")
return None
except: # noqa
logger.warning(f"rabbitmq get_subtasks for {queue} failed: {traceback.format_exc()}")
return None
@class_try_catch_async
async def pending_num(self, queue):
q = await self.declare_queue(queue)
return q.declaration_result.message_count
async def del_conn(self):
if self.chan:
await self.chan.close()
if self.conn:
await self.conn.close()
async def test():
conn_url = "amqp://username:password@127.0.0.1:5672"
q = RabbitMQQueueManager(conn_url)
await q.init()
subtask = {
"task_id": "test-subtask-id",
"queue": "test_queue",
"worker_name": "test_worker",
"inputs": {},
"outputs": {},
"params": {},
}
await q.put_subtask(subtask)
await asyncio.sleep(5)
for i in range(2):
subtask = await q.get_subtasks("test_queue", 3, 5)
print("get subtask:", subtask)
await q.close()
if __name__ == "__main__":
asyncio.run(test())
import argparse
import asyncio
import mimetypes
import os
import traceback
from contextlib import asynccontextmanager
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, Response
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from loguru import logger
from lightx2v.deploy.common.pipeline import Pipeline
from lightx2v.deploy.common.utils import check_params, data_name, load_inputs
from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager
from lightx2v.deploy.queue_manager import LocalQueueManager, RabbitMQQueueManager
from lightx2v.deploy.server.auth import AuthManager
from lightx2v.deploy.server.metrics import MetricMonitor
from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus
from lightx2v.deploy.task_manager import LocalTaskManager, PostgresSQLTaskManager, TaskStatus
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.service_utils import ProcessManager
# =========================
# FastAPI Related Code
# =========================
model_pipelines = None
task_manager = None
data_manager = None
queue_manager = None
server_monitor = None
auth_manager = None
metrics_monitor = MetricMonitor()
@asynccontextmanager
async def lifespan(app: FastAPI):
await task_manager.init()
await task_manager.mark_server_restart()
await data_manager.init()
await queue_manager.init()
asyncio.create_task(server_monitor.init())
yield
await server_monitor.close()
await queue_manager.close()
await data_manager.close()
await task_manager.close()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
logger.error(f"HTTP Exception: {exc.status_code} - {exc.detail} for {request.url}")
return JSONResponse(status_code=exc.status_code, content={"message": exc.detail})
static_dir = os.path.join(os.path.dirname(__file__), "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")
security = HTTPBearer()
async def verify_user_access(credentials: HTTPAuthorizationCredentials = Depends(security)):
token = credentials.credentials
payload = auth_manager.verify_jwt_token(token)
user_id = payload.get("user_id", None)
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user")
user = await task_manager.query_user(user_id)
# logger.info(f"Verfiy user access: {payload}")
if user is None or user["user_id"] != user_id:
raise HTTPException(status_code=401, detail="Invalid user")
return user
async def verify_user_access_from_query(request: Request):
"""从查询参数中验证用户访问权限"""
# 首先尝试从 Authorization 头部获取 token
auth_header = request.headers.get("Authorization")
token = None
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[7:] # 移除 "Bearer " 前缀
else:
# 如果没有 Authorization 头部,尝试从查询参数获取
token = request.query_params.get("token")
payload = auth_manager.verify_jwt_token(token)
user_id = payload.get("user_id", None)
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user")
user = await task_manager.query_user(user_id)
if user is None or user["user_id"] != user_id:
raise HTTPException(status_code=401, detail="Invalid user")
return user
async def verify_worker_access(credentials: HTTPAuthorizationCredentials = Depends(security)):
token = credentials.credentials
if not auth_manager.verify_worker_token(token):
raise HTTPException(status_code=403, detail="Invalid worker token")
return True
def error_response(e, code):
return JSONResponse({"message": f"error: {e}!"}, status_code=code)
@app.get("/", response_class=HTMLResponse)
async def root():
with open(os.path.join(static_dir, "index.html"), "r", encoding="utf-8") as f:
return HTMLResponse(content=f.read())
@app.get("/auth/login/github")
async def github_auth(request: Request):
client_id = auth_manager.github_client_id
redirect_uri = f"{request.base_url}"
auth_url = f"https://github.com/login/oauth/authorize?client_id={client_id}&redirect_uri={redirect_uri}"
return {"auth_url": auth_url}
@app.get("/auth/callback/github")
async def github_callback(request: Request):
try:
code = request.query_params.get("code")
if not code:
return error_response("Missing authorization code", 400)
user_info = await auth_manager.auth_github(code)
user_id = await task_manager.create_user(user_info)
user_info["user_id"] = user_id
access_token = auth_manager.create_jwt_token(user_info)
logger.info(f"GitHub callback: user_info: {user_info}, access_token: {access_token}")
return {"access_token": access_token, "user_info": user_info}
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
async def prepare_subtasks(task_id):
# schedule next subtasks and pend, put to message queue
subtasks = await task_manager.next_subtasks(task_id)
for sub in subtasks:
logger.info(f"Prepare ready subtask: ({task_id}, {sub['worker_name']})")
r = await queue_manager.put_subtask(sub)
assert r, "put subtask to queue error"
@app.get("/api/v1/model/list")
async def api_v1_model_list(user=Depends(verify_user_access)):
try:
msg = await server_monitor.check_user_busy(user["user_id"])
if msg is not True:
return error_response(msg, 400)
return {"models": model_pipelines.get_model_lists()}
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.post("/api/v1/task/submit")
async def api_v1_task_submit(request: Request, user=Depends(verify_user_access)):
task_id = None
try:
msg = await server_monitor.check_user_busy(user["user_id"], active_new_task=True)
if msg is not True:
return error_response(msg, 400)
params = await request.json()
keys = [params.pop("task"), params.pop("model_cls"), params.pop("stage")]
assert len(params["prompt"]) > 0, "valid prompt is required"
# get worker infos, model input names
workers = model_pipelines.get_workers(keys)
inputs = model_pipelines.get_inputs(keys)
outputs = model_pipelines.get_outputs(keys)
types = model_pipelines.get_types(keys)
check_params(params, inputs, outputs, types)
# check if task can be published to queues
queues = [v["queue"] for v in workers.values()]
wait_time = await server_monitor.check_queue_busy(keys, queues)
if wait_time is None:
return error_response(f"Queue busy, please try again later", 500)
# process multimodal inputs data
inputs_data = await load_inputs(params, inputs, types)
# init task
task_id = await task_manager.create_task(keys, workers, params, inputs, outputs, user["user_id"])
logger.info(f"Submit task: {task_id} {params}")
# save multimodal inputs data
for inp, data in inputs_data.items():
await data_manager.save_bytes(data, data_name(inp, task_id))
await prepare_subtasks(task_id)
return {"task_id": task_id, "workers": workers, "params": params, "wait_time": wait_time}
except Exception as e:
traceback.print_exc()
if task_id:
await task_manager.finish_subtasks(task_id, TaskStatus.FAILED, fail_msg=f"submit failed: {e}")
return error_response(str(e), 500)
@app.get("/api/v1/task/query")
async def api_v1_task_query(request: Request, user=Depends(verify_user_access)):
try:
msg = await server_monitor.check_user_busy(user["user_id"])
if msg is not True:
return error_response(msg, 400)
task_id = request.query_params["task_id"]
task, subtasks = await task_manager.query_task(task_id, user["user_id"], only_task=False)
if task is None:
return error_response(f"Task {task_id} not found", 404)
for sub in subtasks:
sub["status"] = sub["status"].name
task["subtasks"] = subtasks
task["status"] = task["status"].name
return task
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/api/v1/task/list")
async def api_v1_task_list(request: Request, user=Depends(verify_user_access)):
try:
user_id = user["user_id"]
msg = await server_monitor.check_user_busy(user_id)
if msg is not True:
return error_response(msg, 400)
page = int(request.query_params.get("page", 1))
page_size = int(request.query_params.get("page_size", 10))
assert page > 0 and page_size > 0, "page and page_size must be greater than 0"
status_filter = request.query_params.get("status", None)
query_params = {"user_id": user_id}
if status_filter and status_filter != "ALL":
query_params["status"] = TaskStatus[status_filter.upper()]
total_tasks = await task_manager.list_tasks(count=True, **query_params)
total_pages = (total_tasks + page_size - 1) // page_size
page_info = {"page": page, "page_size": page_size, "total": total_tasks, "total_pages": total_pages}
if page > total_pages:
return {"tasks": [], "pagination": page_info}
query_params["offset"] = (page - 1) * page_size
query_params["limit"] = page_size
tasks = await task_manager.list_tasks(**query_params)
for task in tasks:
task["status"] = task["status"].name
return {"tasks": tasks, "pagination": page_info}
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/api/v1/task/result")
async def api_v1_task_result(request: Request, user=Depends(verify_user_access_from_query)):
try:
msg = await server_monitor.check_user_busy(user["user_id"])
if msg is not True:
return error_response(msg, 400)
name = request.query_params["name"]
task_id = request.query_params["task_id"]
task = await task_manager.query_task(task_id, user_id=user["user_id"])
if task is None:
return error_response(f"Task {task_id} not found", 404)
if task["status"] != TaskStatus.SUCCEED:
return error_response(f"Task {task_id} not succeed", 400)
assert name in task["outputs"], f"Output {name} not found in task {task_id}"
if name in task["params"]:
return error_response(f"Output {name} is a stream", 400)
data = await data_manager.load_bytes(task["outputs"][name])
# set correct Content-Type
content_type, _ = mimetypes.guess_type(name)
if content_type is None:
content_type = "application/octet-stream"
headers = {"Content-Disposition": f'attachment; filename="{name}"'}
return Response(content=data, media_type=content_type, headers=headers)
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/api/v1/task/input")
async def api_v1_task_input(request: Request, user=Depends(verify_user_access_from_query)):
try:
# msg = await server_monitor.check_user_busy(user['user_id'])
# if msg is not True:
# return error_response(msg, 400)
name = request.query_params["name"]
task_id = request.query_params["task_id"]
task = await task_manager.query_task(task_id, user_id=user["user_id"])
if task is None:
return error_response(f"Task {task_id} not found", 404)
if name not in task["inputs"]:
return error_response(f"Input {name} not found in task {task_id}", 404)
if name in task["params"]:
return error_response(f"Input {name} is a stream", 400)
data = await data_manager.load_bytes(task["inputs"][name])
# set correct Content-Type
content_type, _ = mimetypes.guess_type(name)
if content_type is None:
content_type = "application/octet-stream"
headers = {"Content-Disposition": f'attachment; filename="{name}"'}
return Response(content=data, media_type=content_type, headers=headers)
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/api/v1/task/thumbnails")
async def api_v1_task_thumbnails(request: Request, user=Depends(verify_user_access)):
"""一次性获取所有任务的缩略图"""
try:
user_id = user["user_id"]
msg = await server_monitor.check_user_busy(user_id)
if msg is not True:
return error_response(msg, 400)
# 获取所有任务
tasks = await task_manager.list_tasks(user_id=user_id, limit=1000) # 限制最多1000个任务
logger.info(f"获取到 {len(tasks)} 个任务")
# 转换任务状态为字符串
for task in tasks:
task["status"] = task["status"].name
thumbnails = {}
for task in tasks:
task_id = task["task_id"]
if task.get("inputs"):
# 查找输入中的图片文件
image_inputs = []
for key, value in task["inputs"].items():
if key.lower().find("image") != -1 or str(value).lower().endswith((".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp")):
image_inputs.append(key)
if image_inputs:
# 使用第一个图片作为缩略图
first_image_key = image_inputs[0]
try:
# 获取图片数据
image_data = await data_manager.load_bytes(task["inputs"][first_image_key])
# 转换为base64
import base64
image_base64 = base64.b64encode(image_data).decode("utf-8")
# 根据文件扩展名确定MIME类型
content_type, _ = mimetypes.guess_type(first_image_key)
if content_type is None:
content_type = "image/jpeg" # 默认类型
# 构建data URL
data_url = f"data:{content_type};base64,{image_base64}"
thumbnails[task_id] = data_url
except Exception as e:
logger.warning(f"Failed to load thumbnail for task {task_id}: {e}")
# 如果加载失败,不添加到结果中
result = {"thumbnails": thumbnails, "total_tasks": len(tasks), "total_thumbnails": len(thumbnails)}
return result
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/api/v1/task/cancel")
async def api_v1_task_cancel(request: Request, user=Depends(verify_user_access)):
try:
msg = await server_monitor.check_user_busy(user["user_id"])
if msg is not True:
return error_response(msg, 400)
task_id = request.query_params["task_id"]
ret = await task_manager.cancel_task(task_id, user_id=user["user_id"])
logger.warning(f"Task {task_id} cancelled: {ret}")
if ret is True:
return {"msg": "Task cancelled successfully"}
else:
return error_response({"error": f"Task {task_id} cancel failed: {ret}"}, 400)
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/api/v1/task/resume")
async def api_v1_task_resume(request: Request, user=Depends(verify_user_access)):
try:
msg = await server_monitor.check_user_busy(user["user_id"], active_new_task=True)
if msg is not True:
return error_response(msg, 400)
task_id = request.query_params["task_id"]
ret = await task_manager.resume_task(task_id, user_id=user["user_id"], all_subtask=True)
if ret:
await prepare_subtasks(task_id)
return {"msg": "ok"}
else:
return error_response(f"Task {task_id} resume failed", 400)
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.post("/api/v1/worker/fetch")
async def api_v1_worker_fetch(request: Request, valid=Depends(verify_worker_access)):
try:
params = await request.json()
logger.info(f"Worker fetching: {params}")
keys = params.pop("worker_keys")
identity = params.pop("worker_identity")
max_batch = params.get("max_batch", 1)
timeout = params.get("timeout", 5)
# check client disconnected
async def check_client(request, fetch_task, identity, queue):
while True:
msg = await request.receive()
if msg["type"] == "http.disconnect":
logger.warning(f"Worker {identity} {queue} disconnected, req: {request.client}, {msg}")
fetch_task.cancel()
await server_monitor.worker_update(queue, identity, WorkerStatus.DISCONNECT)
return
await asyncio.sleep(1)
# get worker info
worker = model_pipelines.get_worker(keys)
await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.FETCHING)
fetch_task = asyncio.create_task(queue_manager.get_subtasks(worker["queue"], max_batch, timeout))
check_task = asyncio.create_task(check_client(request, fetch_task, identity, worker["queue"]))
try:
subtasks = await asyncio.wait_for(fetch_task, timeout=timeout)
except asyncio.TimeoutError:
subtasks = []
fetch_task.cancel()
check_task.cancel()
subtasks = [] if subtasks is None else subtasks
valid_subtasks = await task_manager.run_subtasks(subtasks, identity)
valids = [sub["task_id"] for sub in valid_subtasks]
if len(valid_subtasks) > 0:
await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.FETCHED)
logger.info(f"Worker {identity} {keys} {request.client} fetched {valids}")
else:
await server_monitor.worker_update(worker["queue"], identity, WorkerStatus.DISCONNECT)
return {"subtasks": valid_subtasks}
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.post("/api/v1/worker/report")
async def api_v1_worker_report(request: Request, valid=Depends(verify_worker_access)):
try:
params = await request.json()
logger.info(f"{params}")
task_id = params.pop("task_id")
worker_name = params.pop("worker_name")
status = TaskStatus[params.pop("status")]
identity = params.pop("worker_identity")
queue = params.pop("queue")
fail_msg = params.pop("fail_msg", None)
await server_monitor.worker_update(queue, identity, WorkerStatus.REPORT)
ret = await task_manager.finish_subtasks(task_id, status, worker_identity=identity, worker_name=worker_name, fail_msg=fail_msg, should_running=True)
# not all subtasks finished, prepare new ready subtasks
if ret not in [TaskStatus.SUCCEED, TaskStatus.FAILED]:
await prepare_subtasks(task_id)
# all subtasks succeed, delete temp data
elif ret == TaskStatus.SUCCEED:
logger.info(f"Task {task_id} succeed")
task = await task_manager.query_task(task_id)
keys = [task["task_type"], task["model_cls"], task["stage"]]
temps = model_pipelines.get_temps(keys)
for temp in temps:
type = model_pipelines.get_type(temp)
name = data_name(temp, task_id)
await data_manager.get_delete_func(type)(name)
elif ret == TaskStatus.FAILED:
logger.warning(f"Task {task_id} failed")
return {"msg": "ok"}
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.post("/api/v1/worker/ping/subtask")
async def api_v1_worker_ping_subtask(request: Request, valid=Depends(verify_worker_access)):
try:
params = await request.json()
logger.info(f"{params}")
task_id = params.pop("task_id")
worker_name = params.pop("worker_name")
identity = params.pop("worker_identity")
queue = params.pop("queue")
task = await task_manager.query_task(task_id)
if task["status"] != TaskStatus.RUNNING:
return {"msg": "delete"}
assert await task_manager.ping_subtask(task_id, worker_name, identity)
await server_monitor.worker_update(queue, identity, WorkerStatus.PING)
return {"msg": "ok"}
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.post("/api/v1/worker/ping/life")
async def api_v1_worker_ping_life(request: Request, valid=Depends(verify_worker_access)):
try:
params = await request.json()
logger.info(f"{params}")
identity = params.pop("worker_identity")
keys = params.pop("worker_keys")
worker = model_pipelines.get_worker(keys)
# worker lost, init it again
queue = server_monitor.identity_to_queue.get(identity, None)
if queue is None:
queue = worker["queue"]
logger.warning(f"worker {identity} lost, refetching it")
await server_monitor.worker_update(queue, identity, WorkerStatus.FETCHING)
else:
assert queue == worker["queue"], f"worker {identity} queue not matched: {queue} vs {worker['queue']}"
metrics = await server_monitor.cal_metrics()
ret = {"queue": queue, "metrics": metrics[queue]}
if identity in metrics[queue]["del_worker_identities"]:
ret["msg"] = "delete"
else:
ret["msg"] = "ok"
return ret
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/metrics")
async def api_v1_monitor_metrics():
try:
return Response(content=metrics_monitor.get_metrics(), media_type="text/plain")
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
# Template API endpoints
@app.get("/api/v1/template/{template_type}/{filename}")
async def api_v1_template(template_type: str, filename: str):
"""获取模板文件"""
try:
import os
template_dir = os.path.join(os.path.dirname(__file__), "..", "template")
file_path = os.path.join(template_dir, template_type, filename)
# 安全检查:确保文件在template目录内
if not os.path.exists(file_path) or not file_path.startswith(template_dir):
return error_response(f"Template file not found", 404)
with open(file_path, "rb") as f:
data = f.read()
# 根据文件类型设置媒体类型
if template_type == "images":
if filename.lower().endswith(".png"):
media_type = "image/png"
elif filename.lower().endswith((".jpg", ".jpeg")):
media_type = "image/jpeg"
else:
media_type = "application/octet-stream"
elif template_type == "audios":
if filename.lower().endswith(".mp3"):
media_type = "audio/mpeg"
elif filename.lower().endswith(".wav"):
media_type = "audio/wav"
else:
media_type = "application/octet-stream"
else:
media_type = "application/octet-stream"
return Response(content=data, media_type=media_type)
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
@app.get("/api/v1/template/list")
async def api_v1_template_list():
"""获取模板文件列表"""
try:
import glob
import os
template_dir = os.path.join(os.path.dirname(__file__), "..", "template")
templates = {"images": [], "audios": []}
# 获取图片模板
image_dir = os.path.join(template_dir, "images")
if os.path.exists(image_dir):
for file_path in glob.glob(os.path.join(image_dir, "*")):
if os.path.isfile(file_path):
filename = os.path.basename(file_path)
templates["images"].append({"filename": filename, "url": f"/api/v1/template/images/{filename}"})
# 获取音频模板
audio_dir = os.path.join(template_dir, "audios")
if os.path.exists(audio_dir):
for file_path in glob.glob(os.path.join(audio_dir, "*")):
if os.path.isfile(file_path):
filename = os.path.basename(file_path)
templates["audios"].append({"filename": filename, "url": f"/api/v1/template/audios/{filename}"})
return {"templates": templates}
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
cur_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(cur_dir, "../../.."))
dft_pipeline_json = os.path.join(base_dir, "configs/model_pipeline.json")
dft_task_url = os.path.join(base_dir, "local_task")
dft_data_url = os.path.join(base_dir, "local_data")
dft_queue_url = os.path.join(base_dir, "local_queue")
parser.add_argument("--pipeline_json", type=str, default=dft_pipeline_json)
parser.add_argument("--task_url", type=str, default=dft_task_url)
parser.add_argument("--data_url", type=str, default=dft_data_url)
parser.add_argument("--queue_url", type=str, default=dft_queue_url)
parser.add_argument("--ip", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=8080)
args = parser.parse_args()
logger.info(f"args: {args}")
with ProfilingContext("Init Server Cost"):
model_pipelines = Pipeline(args.pipeline_json)
auth_manager = AuthManager()
if args.task_url.startswith("/"):
task_manager = LocalTaskManager(args.task_url, metrics_monitor)
elif args.task_url.startswith("postgresql://"):
task_manager = PostgresSQLTaskManager(args.task_url, metrics_monitor)
else:
raise NotImplementedError
if args.data_url.startswith("/"):
data_manager = LocalDataManager(args.data_url)
elif args.data_url.startswith("{"):
data_manager = S3DataManager(args.data_url)
else:
raise NotImplementedError
if args.queue_url.startswith("/"):
queue_manager = LocalQueueManager(args.queue_url)
elif args.queue_url.startswith("amqp://"):
queue_manager = RabbitMQQueueManager(args.queue_url)
else:
raise NotImplementedError
server_monitor = ServerMonitor(model_pipelines, task_manager, queue_manager)
uvicorn.run(app, host=args.ip, port=args.port, reload=False, workers=1)
import os
import time
import aiohttp
import jwt
from fastapi import HTTPException
from loguru import logger
class AuthManager:
def __init__(self):
# Worker access token
self.worker_secret_key = os.getenv("WORKER_SECRET_KEY", "worker-secret-key-change-in-production")
# GitHub OAuth
self.github_client_id = os.getenv("GITHUB_CLIENT_ID", "")
self.github_client_secret = os.getenv("GITHUB_CLIENT_SECRET", "")
self.jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256")
self.jwt_expiration_hours = os.getenv("JWT_EXPIRATION_HOURS", 24)
self.jwt_secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production")
logger.info(f"AuthManager: GITHUB_CLIENT_ID: {self.github_client_id}")
logger.info(f"AuthManager: GITHUB_CLIENT_SECRET: {self.github_client_secret}")
logger.info(f"AuthManager: JWT_SECRET_KEY: {self.jwt_secret_key}")
logger.info(f"AuthManager: WORKER_SECRET_KEY: {self.worker_secret_key}")
def create_jwt_token(self, data):
data2 = {
"user_id": data["user_id"],
"username": data["username"],
"email": data["email"],
"homepage": data["homepage"],
}
expire = time.time() + (self.jwt_expiration_hours * 3600)
data2.update({"exp": expire})
return jwt.encode(data2, self.jwt_secret_key, algorithm=self.jwt_algorithm)
async def auth_github(self, code):
try:
logger.info(f"GitHub OAuth code: {code}")
token_url = "https://github.com/login/oauth/access_token"
token_data = {"client_id": self.github_client_id, "client_secret": self.github_client_secret, "code": code}
headers = {"Accept": "application/json"}
proxy = os.getenv("auth_https_proxy", None)
if proxy:
logger.info(f"auth_github use proxy: {proxy}")
async with aiohttp.ClientSession() as session:
async with session.post(token_url, data=token_data, headers=headers, proxy=proxy) as response:
response.raise_for_status()
token_info = await response.json()
if "error" in token_info:
raise HTTPException(status_code=400, detail=f"GitHub OAuth error: {token_info['error']}")
access_token = token_info.get("access_token")
if not access_token:
raise HTTPException(status_code=400, detail="Failed to get access token")
user_url = "https://api.github.com/user"
user_headers = {"Authorization": f"token {access_token}", "Accept": "application/vnd.github.v3+json"}
async with aiohttp.ClientSession() as session:
async with session.get(user_url, headers=user_headers, proxy=proxy) as response:
response.raise_for_status()
user_info = await response.json()
return {
"source": "github",
"id": str(user_info["id"]),
"username": user_info["login"],
"email": user_info.get("email", ""),
"homepage": user_info.get("html_url", ""),
"avatar_url": user_info.get("avatar_url", ""),
}
except aiohttp.ClientError as e:
logger.error(f"GitHub API request failed: {e}")
raise HTTPException(status_code=500, detail="Failed to authenticate with GitHub")
except Exception as e:
logger.error(f"Authentication error: {e}")
raise HTTPException(status_code=500, detail="Authentication failed")
def verify_jwt_token(self, token):
try:
payload = jwt.decode(token, self.jwt_secret_key, algorithms=[self.jwt_algorithm])
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired")
except Exception as e:
logger.error(f"verify_jwt_token error: {e}")
raise HTTPException(status_code=401, detail="Could not validate credentials")
def verify_worker_token(self, token):
return token == self.worker_secret_key
from loguru import logger
from prometheus_client import Counter, Gauge, Summary, generate_latest
from prometheus_client.core import CollectorRegistry
from lightx2v.deploy.task_manager import ActiveStatus, FinishedStatus, TaskStatus
REGISTRY = CollectorRegistry()
class MetricMonitor:
def __init__(self):
self.task_all = Counter("task_all_total", "Total count of all tasks", ["task_type", "model_cls", "stage"], registry=REGISTRY)
self.task_end = Counter("task_end_total", "Total count of ended tasks", ["task_type", "model_cls", "stage", "status"], registry=REGISTRY)
self.task_active = Gauge("task_active_size", "Current count of active tasks", ["task_type", "model_cls", "stage"], registry=REGISTRY)
self.task_elapse = Summary("task_elapse_seconds", "Elapse time of tasks", ["task_type", "model_cls", "stage", "end_status"], registry=REGISTRY)
self.subtask_all = Counter("subtask_all_total", "Total count of all subtasks", ["queue"], registry=REGISTRY)
self.subtask_end = Counter("subtask_end_total", "Total count of ended subtasks", ["queue", "status"], registry=REGISTRY)
self.subtask_active = Gauge("subtask_active_size", "Current count of active subtasks", ["queue", "status"], registry=REGISTRY)
self.subtask_elapse = Summary("subtask_elapse_seconds", "Elapse time of subtasks", ["queue", "elapse_key"], registry=REGISTRY)
def record_task_start(self, task):
self.task_all.labels(task["task_type"], task["model_cls"], task["stage"]).inc()
self.task_active.labels(task["task_type"], task["model_cls"], task["stage"]).inc()
logger.info(f"Metrics task_all + 1, task_active +1")
def record_task_end(self, task, status, elapse):
self.task_end.labels(task["task_type"], task["model_cls"], task["stage"], status.name).inc()
self.task_active.labels(task["task_type"], task["model_cls"], task["stage"]).dec()
self.task_elapse.labels(task["task_type"], task["model_cls"], task["stage"], status.name).observe(elapse)
logger.info(f"Metrics task_end + 1, task_active -1, task_elapse observe {elapse}")
def record_subtask_change(self, subtask, old_status, new_status, elapse_key, elapse):
if old_status in ActiveStatus and new_status in FinishedStatus:
self.subtask_end.labels(subtask["queue"], elapse_key).inc()
logger.info(f"Metrics subtask_end + 1")
if old_status in ActiveStatus:
self.subtask_active.labels(subtask["queue"], old_status.name).dec()
logger.info(f"Metrics subtask_active {old_status.name} -1")
if new_status in ActiveStatus:
self.subtask_active.labels(subtask["queue"], new_status.name).inc()
logger.info(f"Metrics subtask_active {new_status.name} + 1")
if new_status == TaskStatus.CREATED:
self.subtask_all.labels(subtask["queue"]).inc()
logger.info(f"Metrics subtask_all + 1")
if elapse and elapse_key:
self.subtask_elapse.labels(subtask["queue"], elapse_key).observe(elapse)
logger.info(f"Metrics subtask_elapse observe {elapse}")
# restart server, we should recover active tasks in data_manager
def record_task_recover(self, tasks):
for task in tasks:
if task["status"] in ActiveStatus:
self.record_task_start(task)
# restart server, we should recover active tasks in data_manager
def record_subtask_recover(self, subtasks):
for subtask in subtasks:
if subtask["status"] in ActiveStatus:
self.subtask_all.labels(subtask["queue"]).inc()
self.subtask_active.labels(subtask["queue"], subtask["status"].name).inc()
logger.info(f"Metrics subtask_active {subtask['status'].name} + 1")
logger.info(f"Metrics subtask_all + 1")
def get_metrics(self):
return generate_latest(REGISTRY)
import asyncio
import time
from enum import Enum
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.task_manager import TaskStatus
class WorkerStatus(Enum):
FETCHING = 1
FETCHED = 2
DISCONNECT = 3
REPORT = 4
PING = 5
class CostWindow:
def __init__(self, window):
self.window = window
self.costs = []
self.avg = None
def append(self, cost):
self.costs.append(cost)
if len(self.costs) > self.window:
self.costs.pop(0)
self.avg = sum(self.costs) / len(self.costs)
class WorkerClient:
def __init__(self, queue, identity, infer_timeout, offline_timeout, avg_window, ping_timeout):
self.queue = queue
self.identity = identity
self.status = None
self.update_t = time.time()
self.fetched_t = None
self.infer_cost = CostWindow(avg_window)
self.offline_cost = CostWindow(avg_window)
self.infer_timeout = infer_timeout
self.offline_timeout = offline_timeout
self.ping_timeout = ping_timeout
# FETCHING -> FETCHED -> PING * n -> REPORT -> FETCHING
# FETCHING -> DISCONNECT -> FETCHING
def update(self, status: WorkerStatus):
pre_status = self.status
pre_t = self.update_t
self.status = status
self.update_t = time.time()
if status == WorkerStatus.FETCHING:
if pre_status in [WorkerStatus.DISCONNECT, WorkerStatus.REPORT] and pre_t is not None:
cur_cost = self.update_t - pre_t
if cur_cost < self.offline_timeout:
self.offline_cost.append(max(cur_cost, 1))
elif status == WorkerStatus.REPORT:
if self.fetched_t is not None:
cur_cost = self.update_t - self.fetched_t
self.fetched_t = None
if cur_cost < self.infer_timeout:
self.infer_cost.append(max(cur_cost, 1))
elif status == WorkerStatus.FETCHED:
self.fetched_t = time.time()
def check(self):
# infer too long
if self.fetched_t is not None:
elapse = time.time() - self.fetched_t
if self.infer_cost.avg is not None and elapse > self.infer_cost.avg * 5:
logger.warning(f"Worker {self.identity} {self.queue} infer timeout: {elapse:.2f} s")
return False
if elapse > self.infer_timeout:
logger.warning(f"Worker {self.identity} {self.queue} infer timeout2: {elapse:.2f} s")
return False
elapse = time.time() - self.update_t
# no ping too long
if self.status in [WorkerStatus.FETCHED, WorkerStatus.PING]:
if elapse > self.ping_timeout:
logger.warning(f"Worker {self.identity} {self.queue} ping timeout: {elapse:.2f} s")
return False
# offline too long
elif self.status in [WorkerStatus.DISCONNECT, WorkerStatus.REPORT]:
if self.offline_cost.avg is not None and elapse > self.offline_cost.avg * 5:
logger.warning(f"Worker {self.identity} {self.queue} offline timeout: {elapse:.2f} s")
return False
if elapse > self.offline_timeout:
logger.warning(f"Worker {self.identity} {self.queue} offline timeout2: {elapse:.2f} s")
return False
return True
class ServerMonitor:
def __init__(self, model_pipelines, task_manager, queue_manager, interval=1):
self.model_pipelines = model_pipelines
self.task_manager = task_manager
self.queue_manager = queue_manager
self.interval = interval
self.stop = False
self.worker_clients = {}
self.identity_to_queue = {}
self.subtask_run_timeouts = {}
self.all_queues = self.model_pipelines.get_queues()
self.config = self.model_pipelines.get_monitor_config()
for queue in self.all_queues:
self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 60)
self.subtask_created_timeout = self.config["subtask_created_timeout"]
self.subtask_pending_timeout = self.config["subtask_pending_timeout"]
self.worker_avg_window = self.config["worker_avg_window"]
self.worker_offline_timeout = self.config["worker_offline_timeout"]
self.worker_min_capacity = self.config["worker_min_capacity"]
self.worker_min_cnt = self.config["worker_min_cnt"]
self.worker_max_cnt = self.config["worker_max_cnt"]
self.task_timeout = self.config["task_timeout"]
self.schedule_ratio_high = self.config["schedule_ratio_high"]
self.schedule_ratio_low = self.config["schedule_ratio_low"]
self.ping_timeout = self.config["ping_timeout"]
self.user_visits = {} # user_id -> last_visit_t
self.user_max_active_tasks = self.config["user_max_active_tasks"]
self.user_max_daily_tasks = self.config["user_max_daily_tasks"]
self.user_visit_frequency = self.config["user_visit_frequency"]
assert self.worker_avg_window > 0
assert self.worker_offline_timeout > 0
assert self.worker_min_capacity > 0
assert self.worker_min_cnt > 0
assert self.worker_max_cnt > 0
assert self.worker_min_cnt <= self.worker_max_cnt
assert self.task_timeout > 0
assert self.schedule_ratio_high > 0 and self.schedule_ratio_high < 1
assert self.schedule_ratio_low > 0 and self.schedule_ratio_low < 1
assert self.schedule_ratio_high >= self.schedule_ratio_low
assert self.ping_timeout > 0
assert self.user_max_active_tasks > 0
assert self.user_max_daily_tasks > 0
assert self.user_visit_frequency > 0
async def init(self):
while True:
if self.stop:
break
await self.clean_workers()
await self.clean_subtasks()
await asyncio.sleep(self.interval)
logger.info("ServerMonitor stopped")
async def close(self):
self.stop = True
self.model_pipelines = None
self.task_manager = None
self.queue_manager = None
self.worker_clients = None
def init_worker(self, queue, identity):
if queue not in self.worker_clients:
self.worker_clients[queue] = {}
if identity not in self.worker_clients[queue]:
infer_timeout = self.subtask_run_timeouts[queue]
self.worker_clients[queue][identity] = WorkerClient(queue, identity, infer_timeout, self.worker_offline_timeout, self.worker_avg_window, self.ping_timeout)
self.identity_to_queue[identity] = queue
return self.worker_clients[queue][identity]
@class_try_catch_async
async def worker_update(self, queue, identity, status):
if queue is None:
queue = self.identity_to_queue[identity]
worker = self.init_worker(queue, identity)
worker.update(status)
logger.info(f"Worker {identity} {queue} update [{status}]")
async def clean_workers(self):
qs = list(self.worker_clients.keys())
for queue in qs:
idens = list(self.worker_clients[queue].keys())
for identity in idens:
if not self.worker_clients[queue][identity].check():
self.worker_clients[queue].pop(identity)
self.identity_to_queue.pop(identity)
logger.warning(f"Worker {queue} {identity} out of contact removed, remain {self.worker_clients[queue]}")
async def clean_subtasks(self):
created_end_t = time.time() - self.subtask_created_timeout
pending_end_t = time.time() - self.subtask_pending_timeout
ping_end_t = time.time() - self.ping_timeout
fails = set()
created_tasks = await self.task_manager.list_tasks(status=TaskStatus.CREATED, subtasks=True, end_updated_t=created_end_t)
pending_tasks = await self.task_manager.list_tasks(status=TaskStatus.PENDING, subtasks=True, end_updated_t=pending_end_t)
def fmt_subtask(t):
return f"({t['task_id']}, {t['worker_name']}, {t['queue']}, {t['worker_identity']})"
for t in created_tasks + pending_tasks:
if t["task_id"] in fails:
continue
elapse = time.time() - t["update_t"]
logger.warning(f"Subtask {fmt_subtask(t)} CREATED / PENDING timeout: {elapse:.2f} s")
await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"CREATED / PENDING timeout: {elapse:.2f} s")
fails.add(t["task_id"])
running_tasks = await self.task_manager.list_tasks(status=TaskStatus.RUNNING, subtasks=True)
for t in running_tasks:
if t["task_id"] in fails:
continue
if t["ping_t"] > 0:
ping_elapse = time.time() - t["ping_t"]
if ping_elapse >= self.ping_timeout:
logger.warning(f"Subtask {fmt_subtask(t)} PING timeout: {ping_elapse:.2f} s")
await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"PING timeout: {ping_elapse:.2f} s")
fails.add(t["task_id"])
elapse = time.time() - t["update_t"]
limit = self.subtask_run_timeouts[t["queue"]]
if elapse >= limit:
logger.warning(f"Subtask {fmt_subtask(t)} RUNNING timeout: {elapse:.2f} s")
await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"RUNNING timeout: {elapse:.2f} s")
fails.add(t["task_id"])
def get_avg_worker_infer_cost(self, queue):
if queue not in self.worker_clients:
self.worker_clients[queue] = {}
infer_costs = []
for _, client in self.worker_clients[queue].items():
if client.infer_cost.avg is not None:
infer_costs.append(client.infer_cost.avg)
if len(infer_costs) <= 0:
return self.subtask_run_timeouts[queue]
return sum(infer_costs) / len(infer_costs)
@class_try_catch_async
async def check_user_busy(self, user_id, active_new_task=False):
# check if user visit too frequently
cur_t = time.time()
if user_id in self.user_visits:
elapse = cur_t - self.user_visits[user_id]
if elapse <= self.user_visit_frequency:
return f"User {user_id} visit too frequently, {elapse:.2f} s vs {self.user_visit_frequency:.2f} s"
self.user_visits[user_id] = cur_t
if active_new_task:
# check if user has too many active tasks
active_statuses = [TaskStatus.RUNNING, TaskStatus.PENDING, TaskStatus.CREATED]
active_tasks = await self.task_manager.list_tasks(status=active_statuses, user_id=user_id)
if len(active_tasks) >= self.user_max_active_tasks:
return f"User {user_id} has too many active tasks, {len(active_tasks)} vs {self.user_max_active_tasks}"
# check if user has too many daily tasks
daily_statuses = active_statuses + [TaskStatus.SUCCEED, TaskStatus.CANCEL, TaskStatus.FAILED]
daily_tasks = await self.task_manager.list_tasks(status=daily_statuses, user_id=user_id, start_created_t=cur_t - 86400)
if len(daily_tasks) >= self.user_max_daily_tasks:
return f"User {user_id} has too many daily tasks, {len(daily_tasks)} vs {self.user_max_daily_tasks}"
return True
# check if a task can be published to queues
@class_try_catch_async
async def check_queue_busy(self, keys, queues):
wait_time = 0
for queue in queues:
avg_cost = self.get_avg_worker_infer_cost(queue)
worker_cnt = len(self.worker_clients[queue])
subtask_pending = await self.queue_manager.pending_num(queue)
capacity = self.task_timeout * max(worker_cnt, 1) // avg_cost
capacity = max(self.worker_min_capacity, capacity)
if subtask_pending >= capacity:
ss = f"pending={subtask_pending}, capacity={capacity}"
logger.warning(f"Queue {queue} busy, {ss}, task {keys} cannot be publised!")
return None
wait_time += avg_cost * subtask_pending / max(worker_cnt, 1)
return wait_time
@class_try_catch_async
async def cal_metrics(self):
data = {}
target_high = self.task_timeout * self.schedule_ratio_high
target_low = self.task_timeout * self.schedule_ratio_low
for queue in self.all_queues:
avg_cost = self.get_avg_worker_infer_cost(queue)
worker_cnt = len(self.worker_clients[queue])
subtask_pending = await self.queue_manager.pending_num(queue)
data[queue] = {
"avg_cost": avg_cost,
"worker_cnt": worker_cnt,
"subtask_pending": subtask_pending,
"max_worker": 0,
"min_worker": 0,
"need_add_worker": 0,
"need_del_worker": 0,
"del_worker_identities": [],
}
fix_cnt = subtask_pending // max(self.worker_min_capacity, 1)
min_cnt = min(fix_cnt, subtask_pending * avg_cost // target_high)
max_cnt = min(fix_cnt, subtask_pending * avg_cost // target_low)
data[queue]["min_worker"] = max(self.worker_min_cnt, min_cnt)
data[queue]["max_worker"] = max(self.worker_max_cnt, max_cnt)
if worker_cnt < data[queue]["min_worker"]:
data[queue]["need_add_worker"] = data[queue]["min_worker"] - worker_cnt
if subtask_pending == 0 and worker_cnt > data[queue]["max_worker"]:
data[queue]["need_del_worker"] = worker_cnt - data[queue]["max_worker"]
if data[queue]["need_del_worker"] > 0:
for identity, client in self.worker_clients[queue].items():
if client.status in [WorkerStatus.FETCHING, WorkerStatus.DISCONNECT]:
data[queue]["del_worker_identities"].append(identity)
if len(data[queue]["del_worker_identities"]) >= data[queue]["need_del_worker"]:
break
return data
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>LightX2V 文生视频服务</title>
<script src="https://cdn.tailwindcss.com"></script>
<!-- 主要图标库 -->
<link href="https://cdn.bootcdn.net/ajax/libs/font-awesome/6.4.0/css/all.min.css" rel="stylesheet">
<!-- 备用图标库 -->
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css" rel="stylesheet" media="print" onload="this.media='all'">
<!-- 本地备用图标(如果CDN都失败) -->
<style>
/* 备用图标样式,防止CDN失败时图标不显示 */
.icon-fallback {
display: inline-block;
width: 1em;
height: 1em;
background-color: currentColor;
border-radius: 50%;
}
.icon-fallback.small {
width: 0.75em;
height: 0.75em;
}
.icon-fallback.large {
width: 1.5em;
height: 1.5em;
}
.icon-fallback.xl {
width: 2em;
height: 2em;
}
/* 登录页面样式 */
.login-container {
min-height: 100%;
min-width: 100%;
background: linear-gradient(135deg, #0b0a20 0%, #1b1240 50%, #0f0e22 100%);
position: relative;
overflow: hidden;
display: flex;
align-items: center;
justify-content: center;
}
.login-container::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background:
radial-gradient(circle at 20% 80%, rgba(154, 114, 255, 0.1) 0%, transparent 50%),
radial-gradient(circle at 80% 20%, rgba(183, 139, 255, 0.1) 0%, transparent 50%),
radial-gradient(circle at 40% 40%, rgba(124, 106, 255, 0.05) 0%, transparent 50%);
animation: backgroundShift 20s ease-in-out infinite;
}
/* 登录页面样式 */
.main-container {
min-height: 100%;
min-width: 100%;
background: linear-gradient(135deg, #0b0a20 0%, #1b1240 50%, #0f0e22 100%);
overflow: hidden;
display: flex;
}
@keyframes backgroundShift {
0%, 100% { opacity: 1; }
50% { opacity: 0.8; }
}
.login-card {
background: rgba(27, 18, 64, 0.8);
backdrop-filter: blur(20px);
border: 1px solid rgba(154, 114, 255, 0.2);
border-radius: 24px;
box-shadow:
0 20px 40px rgba(0, 0, 0, 0.3),
0 0 40px rgba(154, 114, 255, 0.1),
inset 0 1px 0 rgba(255, 255, 255, 0.1);
position: relative;
overflow: hidden;
transition: all 0.3s ease;
max-width: 500px;
width: 100%;
}
.login-card::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(154, 114, 255, 0.1), transparent);
transition: left 0.5s ease;
}
.login-card:hover::before {
left: 100%;
}
.login-card:hover {
transform: translateY(-5px);
box-shadow:
0 25px 50px rgba(0, 0, 0, 0.4),
0 0 60px rgba(154, 114, 255, 0.2),
inset 0 1px 0 rgba(255, 255, 255, 0.15);
}
.login-logo {
background: linear-gradient(135deg, #9a72ff, #b78bff, #7c6aff);
-webkit-background-clip: text;
background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3rem;
font-weight: 700;
margin-bottom: 1rem;
animation: logoGlow 3s ease-in-out infinite alternate;
}
@keyframes logoGlow {
0% {
filter: drop-shadow(0 0 10px rgba(154, 114, 255, 0.5));
}
100% {
filter: drop-shadow(0 0 20px rgba(154, 114, 255, 0.8));
}
}
.login-subtitle {
color: rgba(255, 255, 255, 0.7);
font-size: 1.1rem;
margin-bottom: 2rem;
font-weight: 300;
}
.btn-github {
background: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff);
border: 1px solid rgba(154, 114, 255, 0.3);
font-weight: 500;
font-size: 16px;
letter-spacing: 0.2px;
font-family: 'Inter', sans-serif;
padding: 20px 30px;
border-radius: 14px;
position: relative;
overflow: hidden;
text-decoration: none;
box-shadow: 0 10px 30px rgba(140, 110, 255, 0.4);
transition: transform 0.5s ease, box-shadow 0.15s ease;
}
.btn-github::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.1), transparent);
transition: left 0.5s ease;
}
.btn-github:hover::before {
left: 100%;
}
.btn-github:hover {
background: linear-gradient(135deg, #c1a5ff, #8b5cf6);
border-color: rgba(168, 85, 247, 0.5);
transform: translateY(-2px);
box-shadow: 0 8px 20px rgba(168, 85, 247, 0.3);
}
.btn-github:active {
transform: translateY(0);
background: linear-gradient(135deg, #c1a5ff, #8b5cf6);
box-shadow: 0 2px 8px rgba(124, 58, 237, 0.4);
}
.btn-github:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.btn-github:disabled:hover {
transform: none;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
.floating-particles {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
overflow: hidden;
pointer-events: none;
}
.particle {
position: absolute;
width: 4px;
height: 4px;
background: rgba(193, 169, 255, 0.6);
border-radius: 50%;
animation: floatParticle 15s linear infinite;
}
.particle:nth-child(1) { left: 10%; animation-delay: 0s; }
.particle:nth-child(2) { left: 20%; animation-delay: 12s; }
.particle:nth-child(3) { left: 30%; animation-delay: 10s; }
.particle:nth-child(4) { left: 40%; animation-delay: 6s; }
.particle:nth-child(5) { left: 50%; animation-delay: 8s; }
.particle:nth-child(6) { left: 60%; animation-delay: 14s; }
.particle:nth-child(7) { left: 70%; animation-delay: 16s; }
.particle:nth-child(8) { left: 80%; animation-delay: 2s; }
.particle:nth-child(9) { left: 90%; animation-delay: 4s; }
@keyframes floatParticle {
0% {
transform: translateY(100vh) scale(0);
opacity: 0;
}
10% {
opacity: 1;
}
90% {
opacity: 1;
}
100% {
transform: translateY(-100px) scale(1);
opacity: 0;
}
}
.login-features {
margin-top: 2rem;
padding-top: 2rem;
padding-left: 10rem;
border-top: 1px solid rgba(154, 114, 255, 0.2);
}
.feature-item {
display: flex;
align-items: center;
margin-bottom: 1rem;
color: rgba(255, 255, 255, 0.8);
font-size: 0.9rem;
}
.feature-icon {
width: 20px;
height: 20px;
background: linear-gradient(135deg, #9a72ff, #b78bff);
border-radius: 50%;
margin-right: 12px;
display: flex;
align-items: center;
justify-content: center;
font-size: 10px;
color: white;
}
/* 登录页面进入动画 */
.login-container {
animation: fadeInUp 0.8s ease-out;
}
.login-card {
animation: slideInUp 0.6s ease-out 0.2s both;
}
.login-logo {
animation: logoGlow 3s ease-in-out infinite alternate, fadeInScale 0.8s ease-out 0.4s both;
}
.login-subtitle {
animation: fadeIn 0.8s ease-out 0.6s both;
}
.btn-github {
animation: fadeInUp 0.8s ease-out 0.8s both;
}
.login-features {
animation: fadeIn 0.8s ease-out 1s both;
}
.feature-item {
animation: slideInLeft 0.6s ease-out both;
}
.feature-item:nth-child(1) { animation-delay: 1.2s; }
.feature-item:nth-child(2) { animation-delay: 1.4s; }
.feature-item:nth-child(3) { animation-delay: 1.6s; }
.feature-item:nth-child(4) { animation-delay: 1.8s; }
.feature-item:nth-child(5) { animation-delay: 2.0s; }
.feature-item:nth-child(6) { animation-delay: 2.2s; }
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(30px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes slideInUp {
from {
opacity: 0;
transform: translateY(50px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes fadeInScale {
from {
opacity: 0;
transform: scale(0.8);
}
to {
opacity: 1;
transform: scale(1);
}
}
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
@keyframes slideInLeft {
from {
opacity: 0;
transform: translateX(-20px);
}
to {
opacity: 1;
transform: translateX(0);
}
}
/* 响应式设计 */
@media (max-width: 768px) {
.login-logo {
font-size: 2.5rem;
}
.login-subtitle {
font-size: 1rem;
}
.btn-github {
padding: 14px 28px;
font-size: 1rem;
}
.login-card {
margin: 20px;
border-radius: 20px;
}
}
@media (max-width: 480px) {
.login-logo {
font-size: 2rem;
}
.login-subtitle {
font-size: 0.9rem;
}
.btn-github {
padding: 12px 24px;
font-size: 0.95rem;
}
.login-card .card-body {
padding: 2rem !important;
}
}
</style>
<script>
tailwind.config = {
theme: {
extend: {
colors: {
primary: '#9a72ff', // 主紫色
secondary: '#1b1240', // 深紫色背景
accent: '#b78bff', // 亮紫色强调
dark: '#0b0a20', // 深色背景
'dark-light': '#0f0e22', // 稍亮的深色
'laser-purple': '#9a72ff', // 激光紫色
'neon-purple': '#b78bff', // 霓虹紫色
'electric-purple': '#7c6aff', // 电光紫色
},
fontFamily: {
inter: ['Inter', 'sans-serif'],
},
boxShadow: {
'neon': '0 0 10px rgba(154, 114, 255, 0.5), 0 0 20px rgba(154, 114, 255, 0.3)',
'neon-lg': '0 0 15px rgba(154, 114, 255, 0.7), 0 0 30px rgba(154, 114, 255, 0.5)',
'laser': '0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2)',
'laser-intense': '0 0 25px rgba(154, 114, 255, 0.9), 0 0 50px rgba(154, 114, 255, 0.7), 0 0 75px rgba(154, 114, 255, 0.5), 0 0 100px rgba(154, 114, 255, 0.3), 0 0 125px rgba(154, 114, 255, 0.1)',
'electric': '0 0 15px rgba(124, 106, 255, 0.8), 0 0 30px rgba(124, 106, 255, 0.6), 0 0 45px rgba(124, 106, 255, 0.4)'
}
}
}
}
</script>
<style type="text/tailwindcss">
[v-cloak] { display: none; }
/* 确保html和body能够正确填充 */
html {
height: 100%;
}
/* 确保body和app容器能够正确填充 */
body {
margin: 0;
padding: 0;
width: 125vw;
height: 125vh;
overflow-x: hidden;
overflow-y: auto;
/* 整体缩放80% */
transform: scale(0.8);
transform-origin: top left;
}
@layer utilities {
.content-auto {
content-visibility: auto;
}
.text-gradient {
background-clip: text;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
/* 新增渐变图标颜色类 */
.text-gradient-icon {
background: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff);
-webkit-background-clip: text;
background-clip: text;
-webkit-text-fill-color: transparent;
}
.bg-grid {
background-size: 40px 40px;
background-image:
linear-gradient(to right, rgba(210, 193, 255, 0.15) 1px, transparent 1px),
linear-gradient(to bottom, rgba(210, 193, 255, 0.15) 1px, transparent 1px);
}
.scrollbar-thin {
scrollbar-width: thin;
}
.scrollbar-thin::-webkit-scrollbar {
width: 4px;
}
.scrollbar-thin::-webkit-scrollbar-thumb {
background-color: rgba(210, 193, 255, 0.6);
border-radius: 2px;
}
/* 历史任务区域滚动条样式 - 与主内容区域保持一致 */
.history-tasks-scroll::-webkit-scrollbar {
width: 8px !important;
}
.history-tasks-scroll::-webkit-scrollbar-track {
background: rgba(27, 18, 64, 0.3) !important;
border-radius: 4px;
}
.history-tasks-scroll::-webkit-scrollbar-thumb {
background: linear-gradient(135deg, rgba(210, 193, 255, 0.8), rgba(168, 139, 255, 0.8)) !important;
border-radius: 4px;
border: 1px solid rgba(210, 193, 255, 0.3);
}
.history-tasks-scroll::-webkit-scrollbar-thumb:hover {
background: linear-gradient(135deg, rgba(210, 193, 255, 1), rgba(168, 139, 255, 1)) !important;
}
/* 确保历史任务区域可以正常滚动 */
.history-tasks-scroll {
scroll-behavior: smooth;
-webkit-overflow-scrolling: touch;
/* 移除max-height限制,让flex-1占据所有可用空间 */
}
/* 主内容区域滚动条样式 */
.content-area::-webkit-scrollbar {
width: 8px;
}
.content-area::-webkit-scrollbar-track {
background: rgba(27, 18, 64, 0.3);
border-radius: 4px;
}
.content-area::-webkit-scrollbar-thumb {
background: linear-gradient(135deg, rgba(210, 193, 255, 0.8), rgba(168, 139, 255, 0.8));
border-radius: 4px;
border: 1px solid rgba(210, 193, 255, 0.3);
}
.content-area::-webkit-scrollbar-thumb:hover {
background: linear-gradient(135deg, rgba(210, 193, 255, 1), rgba(168, 139, 255, 1));
}
/* 确保内容可以正常滚动 */
.content-area {
scroll-behavior: smooth;
-webkit-overflow-scrolling: touch;
}
.animate-pulse-slow {
animation: pulse 3s cubic-bezier(0.4, 0, 0.6, 0.5) infinite;
}
.animate-float {
animation: float 6s ease-in-out infinite;
}
.animate-laser-glow {
animation: laserGlow 2s ease-in-out infinite alternate;
}
.animate-electric-pulse {
animation: electricPulse 1.5s ease-in-out infinite;
}
.animate-neon-flicker {
animation: neonFlicker 3s ease-in-out infinite;
}
@keyframes float {
0% { transform: translateY(0px); }
50% { transform: translateY(-10px); }
100% { transform: translateY(0px); }
}
@keyframes laserGlow {
0% {
box-shadow: 0 0 10px rgba(210, 193, 255, 0.8), 0 0 40px rgba(210, 193, 255, 0.6), 0 0 60px rgba(210, 193, 255, 0.4);
filter: brightness(0.8) saturate(0.7);
}
100% {
box-shadow: 0 0 20px rgba(210, 193, 255, 1), 0 0 60px rgba(210, 193, 255, 0.8), 0 0 90px rgba(210, 193, 255, 0.6);
filter: brightness(1) saturate(1);
}
}
@keyframes electricPulse {
0%, 100% {
box-shadow: 0 0 15px rgba(142, 136, 255, 0.8), 0 0 30px rgba(142, 136, 255, 0.6);
transform: scale(1);
}
50% {
box-shadow: 0 0 25px rgba(142, 136, 255, 1), 0 0 50px rgba(142, 136, 255, 0.8), 0 0 75px rgba(142, 136, 255, 0.4);
transform: scale(1.02);
}
}
@keyframes neonFlicker {
0%, 100% {
box-shadow: 0 0 20px rgba(183, 139, 255, 0.8), 0 0 40px rgba(183, 139, 255, 0.6);
opacity: 1;
}
25% {
box-shadow: 0 0 15px rgba(183, 139, 255, 0.6), 0 0 30px rgba(183, 139, 255, 0.4);
opacity: 0.8;
}
75% {
box-shadow: 0 0 25px rgba(183, 139, 255, 1), 0 0 50px rgba(183, 139, 255, 0.8);
opacity: 1.1;
}
}
.bg-laser-gradient {
background: linear-gradient(135deg, #d2c1ff 0%, #a88bff 25%, #8e88ff 50%, #d2c1ff 75%, #a88bff 100%);
background-size: 200% 200%;
animation: gradientShift 3s ease-in-out infinite;
}
@keyframes gradientShift {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
.text-laser-glow {
text-shadow: 0 0 10px rgba(154, 114, 255, 0.8), 0 0 20px rgba(154, 114, 255, 0.6), 0 0 30px rgba(154, 114, 255, 0.4);
}
.border-laser {
border-color: #d2c1ff;
box-shadow: 0 0 15px rgba(154, 114, 255, 0.6), inset 0 0 15px rgba(154, 114, 255, 0.1);
}
.btn-primary{
padding: 15px 25px;
border-radius: 14px;
font-weight: 500;
font-size: 14px;
letter-spacing: 0.2px;
font-family: 'Inter', sans-serif;
background: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff);
border: 0;
text-decoration: none;
box-shadow: 0 10px 30px rgba(140, 110, 255, 0.4);
transition: transform 0.15s ease, box-shadow 0.15s ease;
}
.btn-primary:hover{
transform: translateY(-1px);
box-shadow: 0 14px 40px rgba(140, 110, 255, 0.55);
}
/* 修复布局问题 */
.task-type-btn {
padding: 0.75rem 1rem;
font-size: 0.875rem;
font-weight: 500;
transition-property: color, background-color;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 150ms;
}
.task-type-btn:hover {
background-color: rgba(154, 114, 255, 0.1);
}
.model-selection {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
}
.upload-section {
display: grid;
grid-template-columns: repeat(1, minmax(0, 1fr));
gap: 1.5rem;
margin-bottom: 1.5rem;
}
@media (min-width: 768px) {
.upload-section {
grid-template-columns: repeat(2, minmax(0, 1fr));
}
}
.upload-area {
position: relative;
border: 2px dashed rgba(154, 114, 255, 0.4);
border-radius: 0.75rem;
padding: 1.5rem;
text-align: center;
transition-property: all;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 150ms;
cursor: pointer;
background-color: rgba(27, 18, 64, 0.1);
min-height: 250px; /* 确保上传区域有最小高度,防止预览时高度收缩 */
}
.upload-area:hover {
border-color: rgba(154, 114, 255, 0.7);
box-shadow: 0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2);
}
.upload-icon {
margin: 0 auto;
width: 4rem;
height: 4rem;
background-color: rgba(154, 114, 255, 0.2);
border-radius: 9999px;
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 1rem;
transition-property: all;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 150ms;
}
.upload-area:hover .upload-icon {
background-color: rgba(154, 114, 255, 0.3);
}
/* 图片预览占据整个上传区域 */
.image-preview {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
overflow: hidden;
z-index: 10;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
}
.image-preview img {
height: 100%;
width: auto;
max-width: 100%;
display: block;
margin: 0 auto;
object-fit: contain;
transition: all 0.3s ease;
}
/* 音频预览占据整个上传区域 */
.audio-preview {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
border-radius: 0.75rem;
overflow: hidden;
z-index: 10;
display: flex;
align-items: center;
justify-content: center;
background-color: rgba(154, 114, 255, 0.1);
border: 2px solid rgba(154, 114, 255, 0.3);
cursor: pointer;
}
.audio-preview audio {
width: 90%;
height: 60px;
max-height: 80%;
border-radius: 0.5rem;
background-color: rgba(27, 18, 64, 0.3);
display: block;
}
/* 确保音频控件在容器中正确显示 */
.audio-preview audio::-webkit-media-controls {
background-color: rgba(27, 18, 64, 0.5);
border-radius: 0.5rem;
}
/* 上传内容样式 */
.upload-content {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.btn-close {
position: absolute;
top: 0.5rem;
right: 0.5rem;
background-color: #ef4444;
color: white;
border-radius: 9999px;
width: 1.5rem;
height: 1.5rem;
display: flex;
align-items: center;
justify-content: center;
font-size: 0.75rem;
cursor: pointer;
z-index: 20;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
}
/* 确保flexbox布局正确 */
#app {
display: flex;
width: 100%;
height: 100%;
}
.bg-linear-dark {
background-color: linear-gradient(135deg, #0b0a20 0%, #1b1240 50%, #0f0e22 100%);
}
aside {
flex-shrink: 0;
width: 280px; /* 默认展开宽度 */
min-width: 3rem; /* 最小宽度 */
max-width: 500px; /* 最大宽度 */
background-color: linear-gradient(135deg, #0b0a20 0%, #1b1240 50%, #0f0e22 100%);
border-right: 1px solid rgba(154, 114, 255, 0.4);
display: flex;
flex-direction: column;
transition-property: all;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 300ms;
z-index: 10;
position: relative;
}
/* 拖拽调整器 */
.resize-handle {
position: absolute;
top: 0;
right: 0;
width: 4px;
height: 100%;
background: transparent;
cursor: col-resize;
z-index: 20;
transition: background-color 0.2s ease;
}
.resize-handle:hover {
background: rgba(154, 114, 255, 0.5);
}
.resize-handle:active {
background: rgba(154, 114, 255, 0.8);
}
/* 拖拽时的视觉反馈 */
.resizing {
user-select: none;
pointer-events: none;
}
.resizing * {
pointer-events: none;
}
main {
flex: 1;
display: flex;
flex-direction: column;
min-width: 0;
width: calc(100% - 280px); /* 主内容区域占据剩余宽度,适应展开的侧边栏 */
height: 100%;
}
/* 内容区域全屏显示 */
.content-area {
flex: 1;
overflow-y: auto;
background-color: #0b0a20;
padding: 2rem;
width: 100%;
min-height: 0; /* 确保flex子元素可以收缩 */
}
/* 任务创建面板全屏 */
#task-creator {
max-width: none;
width: 80%;
padding: 0 1rem;
}
/* 任务详情面板全屏 */
.task-detail-panel {
max-width: none;
width: 80%;
padding: 0 0rem;
}
/* 上传区域全屏布局 */
.upload-section {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 2rem;
margin-bottom: 2rem;
width: 100%;
}
/* 任务类型选择全屏 */
.task-type-selection {
width: 100%;
margin-bottom: 2rem;
}
.task-type-buttons {
display: flex;
width: 100%;
border-bottom: 1px solid rgba(154, 114, 255, 0.3);
}
.task-type-btn {
flex: 1;
padding: 1rem 1.5rem;
font-size: 1rem;
font-weight: 500;
transition-property: color, background-color;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 150ms;
text-align: center;
}
/* 模型选择全屏 */
.model-selection {
display: flex;
flex-wrap: wrap;
gap: 1rem;
width: 100%;
justify-content: flex-start;
}
/* 提示词输入全屏 */
.prompt-input-section {
width: 100%;
margin-bottom: 2rem;
}
.prompt-textarea {
width: 100%;
min-height: 150px;
resize: vertical;
}
/* 侧边栏折叠样式 */
.sidebar-collapsed {
width: 3rem !important;
}
/* 侧边栏展开样式 */
aside:not(.sidebar-collapsed) {
width: 280px !important;
}
.sidebar-collapsed .sidebar-content {
display: none !important;
}
.sidebar-collapsed .resize-handle {
display: none !important;
}
.sidebar-collapsed .user-info-section {
display: none !important;
}
.sidebar-collapsed .sidebar-header {
justify-content: center;
padding: 1rem 0.5rem;
}
.sidebar-collapsed .sidebar-header h1 {
display: none;
}
/* 展开状态下显示所有内容 */
aside:not(.sidebar-collapsed) .sidebar-content {
display: flex !important;
}
aside:not(.sidebar-collapsed) .resize-handle {
display: block !important;
}
aside:not(.sidebar-collapsed) .sidebar-header {
justify-content: space-between;
padding: 1rem;
}
aside:not(.sidebar-collapsed) .sidebar-header h1 {
display: flex;
}
aside:not(.sidebar-collapsed) .sidebar-header .toggle-btn {
display: flex !important;
align-items: center;
justify-content: center;
}
aside:not(.sidebar-collapsed) .user-info-section {
display: block !important;
}
.sidebar-collapsed .sidebar-header .toggle-btn {
display: flex !important;
align-items: center;
justify-content: center;
width: 2rem;
height: 2rem;
border-radius: 0.375rem;
background-color: rgba(154, 114, 255, 0.1);
border: 1px solid rgba(154, 114, 255, 0.3);
margin: 0 auto;
}
/* 当侧边栏折叠时,主内容区域调整 */
.sidebar-collapsed + main {
width: calc(100% - 3rem);
}
/* 当侧边栏展开时,主内容区域调整 */
aside:not(.sidebar-collapsed) + main {
width: calc(100% - 280px);
}
/* 响应式设计 */
@media (max-width: 1200px) {
aside:not(.sidebar-collapsed) {
width: 250px !important;
}
.sidebar-collapsed + main {
width: calc(100% - 3rem);
}
aside:not(.sidebar-collapsed) + main {
width: calc(100% - 250px);
}
}
@media (max-width: 768px) {
aside:not(.sidebar-collapsed) {
width: 200px !important;
}
.sidebar-collapsed + main {
width: calc(100% - 3rem);
}
aside:not(.sidebar-collapsed) + main {
width: calc(100% - 200px);
}
.upload-section {
grid-template-columns: 1fr;
}
}
/* 修复任务项样式 */
.task-item {
padding: 0.75rem;
border-radius: 0.5rem;
cursor: pointer;
transition-property: all;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 200ms;
}
.task-item:hover {
background-color: rgba(154, 114, 255, 0.15);
box-shadow: 0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2);
}
/* 修复状态指示器 */
.status-indicator {
width: 0.75rem;
height: 0.75rem;
border-radius: 9999px;
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);
}
/* 修复按钮样式 */
.btn-primary {
padding: 12px 22px;
border-radius: 14px;
font-weight: 700;
letter-spacing: 0.2px;
color: #0c0920;
background: linear-gradient(135deg, #d2c1ff, #a88bff, #8e88ff);
border: 0;
text-decoration: none;
box-shadow: 0 10px 30px rgba(140, 110, 255, 0.4);
transition: transform 0.15s ease, box-shadow 0.15s ease;
cursor: pointer;
display: inline-block;
}
.btn-primary:hover {
transform: translateY(-1px);
box-shadow: 0 14px 40px rgba(140, 110, 255, 0.55);
}
/* 修复模型按钮样式 */
.model-btn {
padding: 0.5rem 1rem;
border-radius: 0.5rem;
font-size: 0.875rem;
transition-property: all;
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
transition-duration: 150ms;
cursor: pointer;
border: 1px solid;
}
.model-btn.active {
background-color: rgba(154, 114, 255, 0.2);
border-color: rgba(154, 114, 255, 0.4);
box-shadow: 0 0 20px rgba(154, 114, 255, 0.8), 0 0 40px rgba(154, 114, 255, 0.6), 0 0 60px rgba(154, 114, 255, 0.4), 0 0 80px rgba(154, 114, 255, 0.2);
animation: electricPulse 1.5s ease-in-out infinite;
}
/* 确保内容区域正确滚动 */
.content-scroll {
flex: 1;
overflow-y: auto;
}
/* 任务进行中面板样式 */
.task-running-panel .animate-pulse-slow {
animation: pulse 3s cubic-bezier(0.4, 0, 0.6, 0.5) infinite;
}
/* 任务失败面板样式 */
.task-failed-panel .bg-red-500\/10 {
background-color: rgba(239, 68, 68, 0.1);
}
.task-detail-panel video {
width: 100%;
height: 100%;
object-fit: cover;
}
/* 素材预览样式 */
.material-preview {
display: flex;
flex-wrap: wrap;
gap: 0.75rem;
}
.material-preview img {
border-radius: 0.5rem;
transition: all 0.2s ease;
}
.material-preview img:hover {
transform: scale(1.05);
box-shadow: 0 0 20px rgba(154, 114, 255, 0.6);
}
/* 任务状态指示器增强 */
.status-indicator {
position: relative;
}
.status-indicator::after {
content: '';
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: 0.25rem;
height: 0.25rem;
background-color: currentColor;
border-radius: 50%;
opacity: 0.8;
}
/* 任务面板切换动画 */
.task-panel-enter-active,
.task-panel-leave-active {
transition: all 0.3s ease;
}
.task-panel-enter-from {
opacity: 0;
transform: translateY(20px);
}
.task-panel-leave-to {
opacity: 0;
transform: translateY(-20px);
}
/* 响应式任务面板 */
@media (max-width: 768px) {
.task-detail-panel {
padding: 0 0.5rem;
}
}
/* 提示消息动画 */
.animate-slide-down {
animation: slideDown 0.3s ease-out;
}
@keyframes slideDown {
0% {
opacity: 0;
transform: translate(-50%, -100%);
}
100% {
opacity: 1;
transform: translate(-50%, 0);
}
}
/* 提示消息样式 - 统一浅色透明背景 */
.alert {
backdrop-filter: blur(15px);
background: rgba(255, 255, 255, 0.15);
border-radius: 0.75rem;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
color: #333;
}
}
</style>
</head>
<body
class="bg-dark text-gray-100 font-inter"
>
<div id="app">
<!-- 登录页面 -->
<div v-if="!isLoggedIn" class="login-container">
<!-- 浮动粒子背景 -->
<div class="floating-particles">
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
</div>
<div class="login-card">
<div class="card-body text-center p-5">
<!-- Logo和标题 -->
<div class="mb-4">
<div class="login-logo">
<i class="fas fa-film me-3"></i>
LightX2V
</div>
<p class="login-subtitle">一个强大的视频生成平台</p>
</div>
<!-- 登录按钮 -->
<button @click="loginWithGitHub" class="btn btn-github btn-lg w-100 mb-4" :disabled="loading">
<i class="fab fa-github me-2"></i>
{{ loading ? '登录中...' : '使用GitHub登录' }}
</button>
<!-- 功能特性 -->
<div class="login-features">
<div class="feature-item">
<div class="feature-icon">🎭</div>
<span>电影级数字人视频</span>
</div>
<div class="feature-item">
<div class="feature-icon"></div>
<span>20倍生成提速</span>
</div>
<div class="feature-item">
<div class="feature-icon">💰</div>
<span>超低成本生成</span>
</div>
<div class="feature-item">
<div class="feature-icon">🎯</div>
<span>精准口型对齐</span>
</div>
<div class="feature-item">
<div class="feature-icon">📱</div>
<span>分钟级视频时长</span>
</div>
<div class="feature-item">
<div class="feature-icon">🎨</div>
<span>多场景应用</span>
</div>
</div>
</div>
</div>
</div>
<!-- 主应用页面 -->
<div v-else class="main-container">
<!-- 浮动粒子背景 -->
<div class="floating-particles">
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
<div class="particle"></div>
</div>
<!-- 侧边栏 - 历史任务 -->
<aside class="w-64 bg-linear-dark border-r border-laser-purple/40 flex flex-col transition-all duration-300 ease-in-out z-10" ref="sidebar" :class="{ 'sidebar-collapsed': sidebarCollapsed }">
<!-- 拖拽调整器 -->
<div class="resize-handle" @mousedown="startResize"></div>
<div class="p-4 border-b border-laser-purple/40 sidebar-header">
<div class="flex items-center justify-between">
<h1 class="text-xl font-bold flex items-center">
<i class="fas fa-video text-gradient-icon mr-2"></i>
<span>LightX2V</span>
</h1>
<button
@click="toggleSidebar"
class="text-gray-400 hover:text-gradient-icon transition-colors flex-shrink-0 toggle-btn"
title="展开/折叠侧边栏">
<i class="fas fa-bars"></i>
</button>
</div>
</div>
<div class="sidebar-content flex flex-col flex-1 min-h-0">
<div class="p-3 border-b border-laser-purple/40">
<button
@click="showTaskCreator"
class="w-full btn-primary py-2 rounded-lg flex items-center justify-center transition-all duration-200 font-medium text-sm">
<i class="fas fa-plus mr-2"></i>
新建任务
</button>
<div class="relative mt-3">
<input
v-model="searchQuery"
class="w-full bg-dark-light border border-laser-purple/30 rounded-lg py-2 pl-10 pr-4 text-sm focus:outline-none focus:ring-2 focus:ring-laser-purple/50 transition-all focus:border-laser focus:shadow-laser"
placeholder="搜索"
type="text"
/>
<i class="fas fa-search absolute left-3 top-1/2 transform -translate-y-1/2 text-gray-400"></i>
</div>
<!-- 状态过滤和刷新 -->
<div class="mt-3">
<div class="flex flex-wrap gap-1 justify-between items-center">
<div class="flex flex-wrap gap-1">
<button
@click="statusFilter = 'ALL'"
class="px-2 py-1 text-xs rounded transition-all"
:class="statusFilter === 'ALL' ? 'bg-dark-light bg-laser-purple/40' : 'bg-dark-light text-gray-400 hover:bg-laser-purple/20'"
>
全部
</button>
<button
@click="statusFilter = 'SUCCEED'"
class="px-2 py-1 text-xs rounded transition-all"
:class="statusFilter === 'SUCCEED' ? 'bg-green-500/30 text-green-400' : 'bg-dark-light text-gray-400 hover:bg-green-500/20'"
>
成功
</button>
<button
@click="statusFilter = 'RUNNING'"
class="px-2 py-1 text-xs rounded transition-all"
:class="statusFilter === 'RUNNING' ? 'bg-yellow-500/30 text-yellow-400' : 'bg-dark-light text-gray-400 hover:bg-yellow-500/20'"
>
进行中
</button>
<button
@click="statusFilter = 'FAILED'"
class="px-2 py-1 text-xs rounded transition-all"
:class="statusFilter === 'FAILED' ? 'bg-red-500/30 text-red-400' : 'bg-dark-light text-gray-400 hover:bg-red-500/20'"
>
失败
</button>
</div>
<button
@click="refreshTasks"
class="text-gray-400 hover:text-gradient-icon transition-colors flex-shrink-0"
title="刷新任务列表"
>
<i class="fas fa-sync-alt"></i>
</button>
</div>
</div>
</div>
<div class="flex-1 overflow-y-auto history-tasks-scroll p-2 min-h-0">
<div class="text-xs uppercase text-gray-400 font-semibold mb-2 px-3">
历史任务
</div>
<!-- 历史任务列表 -->
<div class="space-y-1" id="history-tasks">
<div v-if="filteredTasks.length === 0" class="flex-col items-center justify-center py-12 text-center">
<p class="text-gray-400 text-sm">暂无历史任务</p>
<p class="text-gray-500 text-xs mt-1">开始创建你的第一个AI视频吧</p>
</div>
<!-- 任务项 -->
<div
v-for="task in filteredTasks"
:key="task.task_id"
class="task-item p-2 rounded-lg cursor-pointer hover:bg-laser-purple/15 hover:shadow-laser transition-all duration-200"
:class="getTaskItemClass(task.status)"
@click="viewTaskDetail(task)"
>
<div class="flex items-start gap-3 mb-2">
<div class="w-16 h-12 bg-dark-light rounded overflow-hidden flex-shrink-0">
<template v-for="(thumbnailInfo, index) in [getVideoThumbnailInfo(task.task_id, 'output_video')]" :key="index">
<img
v-if="thumbnailInfo.hasThumbnail"
:src="thumbnailInfo.url"
alt="任务预览"
class="w-full h-full object-cover"
@error="handleThumbnailError"
/>
<div v-else class="w-full h-full bg-laser-purple/20 flex items-center justify-center">
<i class="fas fa-video text-gradient-icon text-xl"></i>
</div>
</template>
</div>
<div class="flex-1 min-w-0">
<div class="flex justify-between items-start mb-1 gap-2">
<h3 class="font-medium text-sm truncate max-w-[calc(100%-2rem)]">{{ task.params.prompt || '无标题任务' }}</h3>
<div
:class="getStatusIndicatorClass(task.status)"
:title="getTaskStatusDisplay(task.status)"
class="flex-shrink-0"
></div>
</div>
<p class="text-xs text-gray-400 mb-2 line-clamp-1">
{{ getTaskTypeName(task) }} | {{ getRelativeTime(task.create_t) }}
</p>
<div class="flex items-center justify-between text-xs gap-2">
<span class="text-gray-500 truncate max-w-[calc(100%-4rem)]">{{ task.model_cls }}</span>
<span :class="getTaskStatusColor(task.status)" class="font-medium flex-shrink-0">
{{ getTaskStatusDisplay(task.status) }}
</span>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- 用户信息区域 - 任务栏底部分区 -->
<div class="mt-auto p-3 border-t border-laser-purple/40 user-info-section">
<div class="flex items-center space-x-3">
<!-- 用户头像 -->
<div v-if="currentUser.avatar_url" class="w-10 h-10 rounded-full border border-laser-purple/40 overflow-hidden flex-shrink-0">
<img
:src="currentUser.avatar_url"
alt="用户头像"
class="w-full h-full object-cover"
/>
</div>
<!-- 默认用户图标 -->
<div v-else class="w-10 h-10 rounded-full border border-laser-purple/40 bg-laser-purple/20 flex items-center justify-center flex-shrink-0">
<i class="fas fa-user text-gradient-icon"></i>
</div>
<!-- 用户信息 -->
<div class="flex-1 min-w-0">
<div class="text-sm font-medium text-gray-100 truncate">
{{ currentUser.username }}
</div>
<div class="text-xs text-gray-400 truncate">
{{ currentUser.email }}
</div>
</div>
<!-- 退出按钮 -->
<button @click="logout" class="text-gray-400 hover:text-gradient-icon transition-colors flex-shrink-0" title="退出登录">
<i class="fas fa-sign-out-alt"></i>
</button>
</div>
</div>
</aside>
<!-- 主内容区 -->
<main class="flex-1 flex flex-col overflow-hidden">
<!-- 内容区域 -->
<div class="flex-1 overflow-y-auto bg-dark p-6 content-area">
<!-- 模板选择浮窗 -->
<div v-cloak>
<div v-if="showImageTemplates || showAudioTemplates"
class="fixed inset-0 bg-black/50 z-50 flex items-center justify-center"
@click="showImageTemplates = false; showAudioTemplates = false">
<div class="bg-secondary rounded-xl p-6 max-w-4xl w-full mx-4 max-h-[80vh] overflow-hidden"
@click.stop>
<!-- 浮窗头部 -->
<div class="flex items-center justify-between mb-4">
<h3 class="text-lg font-medium text-white">
<i v-if="showImageTemplates" class="fas fa-image text-gradient-icon mr-2"></i>
<i v-if="showAudioTemplates" class="fas fa-music text-gradient-icon mr-2"></i>
{{ showImageTemplates ? '选择图片模板' : '选择音频模板' }}
</h3>
<button @click="showImageTemplates = false; showAudioTemplates = false"
class="text-gray-400 hover:text-white transition-colors">
<i class="fas fa-times text-xl"></i>
</button>
</div>
<!-- 图片模板网格 -->
<div v-if="showImageTemplates" class="overflow-y-auto max-h-[50vh]">
<div v-if="imageTemplates.length > 0" class="grid grid-cols-4 gap-4">
<div v-for="template in imageTemplates" :key="template.filename"
@click="selectImageTemplate(template)"
class="relative group cursor-pointer rounded-lg overflow-hidden border border-gray-700 hover:border-laser-purple/50 transition-all">
<img :src="template.url" :alt="template.filename"
class="w-full h-32 object-cover">
<div class="absolute inset-0 bg-black/50 opacity-0 group-hover:opacity-100 transition-opacity flex items-center justify-center">
<i class="fas fa-check text-white text-2xl"></i>
</div>
<div class="absolute bottom-0 left-0 right-0 bg-black/80 text-white text-xs p-2">
<div class="truncate">{{ template.filename }}</div>
</div>
</div>
</div>
<div v-else class="flex flex-col items-center justify-center py-12 text-center">
<div class="w-16 h-16 bg-laser-purple/20 rounded-full flex items-center justify-center mb-4">
<i class="fas fa-image text-gradient-icon text-2xl"></i>
</div>
<p class="text-gray-400 text-lg mb-2">目前暂无图片模板</p>
</div>
</div>
<!-- 音频模板列表 -->
<div v-if="showAudioTemplates" class="overflow-y-auto max-h-[50vh]">
<div v-if="audioTemplates.length > 0" class="space-y-3">
<div v-for="template in audioTemplates" :key="template.filename"
@click="selectAudioTemplate(template)"
class="flex items-center gap-4 p-4 rounded-lg border border-gray-700 hover:border-laser-purple/50 transition-all cursor-pointer bg-dark-light/50">
<div class="w-12 h-12 bg-laser-purple/20 rounded-lg flex items-center justify-center">
<i class="fas fa-music text-gradient-icon text-xl"></i>
</div>
<div class="flex-1">
<div class="text-white font-medium">{{ template.filename }}</div>
<div class="text-gray-400 text-sm">音频模板</div>
</div>
<button @click.stop="previewAudioTemplate(template)"
class="px-3 py-2 bg-laser-purple/20 hover:bg-laser-purple/30 text-gradient-icon rounded-lg transition-all">
<i class="fas fa-play mr-2"></i>
试听
</button>
</div>
</div>
<div v-else class="flex flex-col items-center justify-center py-12 text-center">
<div class="w-16 h-16 bg-laser-purple/20 rounded-full flex items-center justify-center mb-4">
<i class="fas fa-music text-gradient-icon text-2xl"></i>
</div>
<p class="text-gray-400 text-lg mb-2">目前暂无音频模板</p>
</div>
</div>
</div>
</div>
</div>
<!-- 任务创建面板 -->
<div v-if="showCreator" class="max-w-4xl mx-auto" id="task-creator">
<!-- 任务类型选择 -->
<div class="mb-8 task-type-selection">
<div class="flex border-b border-laser-purple/30 task-type-buttons">
<button
v-for="taskType in availableTaskTypes"
:key="taskType"
@click="selectTask(taskType)"
class="task-type-btn"
:class="getTaskTypeBtnClass(taskType)"
>
<i :class="getTaskTypeIcon(taskType)" class="mr-2"></i>
{{ getTaskTypeName(taskType) }}
</button>
</div>
</div>
<!-- 模型选择 -->
<div v-if="selectedTaskId" class="mb-6">
<label class="block text-sm text-gray-400 mb-2">选择模型</label>
<div class="model-selection">
<button
v-for="model in availableModelClasses"
:key="model"
@click="selectModel(model)"
class="model-btn px-4 py-2 rounded-lg text-sm transition-all"
:class="getModelBtnClass(model)"
>
<i v-if="model === getCurrentForm().model_cls" class="fas fa-star text-yellow-400 mr-1"></i>
{{ model }}
</button>
</div>
</div>
<!-- 上传区域 -->
<div v-if="selectedTaskId === 'i2v' || selectedTaskId === 'digital_human'" class="upload-section">
<!-- 上传图片 -->
<div v-if="selectedTaskId === 'i2v' || selectedTaskId === 'digital_human'" class="upload-area" @click="triggerImageUpload">
<!-- 默认上传界面 -->
<div v-if="!getCurrentImagePreview()" class="upload-content">
<div class="upload-icon">
<i class="fas fa-image text-gradient-icon text-xl"></i>
</div>
<p class="text-xs text-gray-400 mb-4">支持JPG、PNG格式,大小不超过10MB</p>
<div class="flex gap-2">
<button class="btn-primary px-4 py-1.5 rounded-lg transition-all flex-1">上传图片</button>
<button @click.stop="showImageTemplates = !showImageTemplates"
class="px-4 py-1.5 rounded-lg bg-laser-purple/20 hover:bg-laser-purple/30 text-gradient-icon border border-laser-purple/40 rounded-lg transition-all">
<i class="fas fa-images mr-1"></i>
模板
</button>
</div>
</div>
<!-- 图片预览 -->
<div v-if="getCurrentImagePreview()" class="image-preview group">
<img :src="getCurrentImagePreview()" alt="预览图片" class="w-full h-full object-cover rounded-lg transition-all duration-300 group-hover:brightness-50">
<!-- 悬停时显示的操作按钮,位置在中下方 -->
<div class="absolute inset-x-0 bottom-4 flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity duration-300">
<div class="flex space-x-3">
<button
@click.stop="triggerImageUpload"
class="w-12 h-12 flex items-center justify-center bg-white/15 text-white p-3 rounded-full transition-all duration-200 hover:scale-110 shadow-lg"
title="重新上传">
<i class="fas fa-upload text-lg"></i>
</button>
<button
@click.stop="removeImage"
class="w-12 h-12 flex items-center justify-center bg-white/15 text-white p-3 rounded-full transition-all duration-200 hover:scale-110 shadow-lg"
title="删除图片">
<i class="fas fa-trash text-lg"></i>
</button>
</div>
</div>
</div>
<input
type="file"
ref="imageInput"
@change="handleImageUpload"
accept="image/*"
style="display: none;">
</div>
<!-- 上传音频 -->
<div v-if="selectedTaskId === 'digital_human'" class="upload-area" @click="triggerAudioUpload">
<!-- 默认上传界面 -->
<div v-if="!getCurrentAudioPreview()" class="upload-content">
<div class="upload-icon">
<i class="fas fa-microphone text-gradient-icon text-xl"></i>
</div>
<p class="text-xs text-gray-400 mb-4">支持MP4、WAV格式,最长支持120s</p>
<div class="flex gap-2">
<button class="btn-primary px-4 py-1.5 rounded-lg transition-all flex-1">上传音频</button>
<button @click.stop="showAudioTemplates = !showAudioTemplates"
class="px-4 py-1.5 rounded-lg bg-laser-purple/20 hover:bg-laser-purple/30 text-gradient-icon border border-laser-purple/40 rounded-lg transition-all">
<i class="fas fa-music mr-1"></i>
模板
</button>
</div>
</div>
<!-- 音频预览 -->
<div v-if="getCurrentAudioPreview()" class="audio-preview group" @click.stop>
<audio controls class="w-full h-full">
<source :src="getCurrentAudioPreview()" :type="getAudioMimeType()">
</audio>
<!-- 悬停时显示的操作按钮,位置在中下方 -->
<div class="absolute inset-x-0 bottom-4 flex items-center justify-center opacity-0 group-hover:opacity-100 transition-opacity duration-300 bg-black/20">
<div class="flex space-x-3">
<button
@click.stop="triggerAudioUpload"
class="w-12 h-12 flex items-center justify-center bg-white/15 text-white p-3 rounded-full transition-all duration-200 hover:scale-110 shadow-lg"
title="重新上传">
<i class="fas fa-upload text-lg"></i>
</button>
<button
@click.stop="removeAudio"
class="w-12 h-12 flex items-center justify-center bg-white/15 text-white p-3 rounded-full transition-all duration-200 hover:scale-110 shadow-lg"
title="删除音频">
<i class="fas fa-trash text-lg"></i>
</button>
</div>
</div>
</div>
<input
type="file"
ref="audioInput"
@change="handleAudioUpload"
accept="audio/*"
style="display: none;">
</div>
</div>
<!-- 提示词输入 -->
<div class="mb-6 prompt-input-section">
<div class="flex justify-between items-center mb-2">
<label class="block text-sm text-gray-400">提示词</label>
<div class="flex space-x-2">
<button @click="showPromptTemplates" class="text-xs text-gray-400 hover:text-gradient-icon transition-colors hover:text-gradient-icon" title="提示词模板">
<i class="fas fa-magic"></i>
</button>
<button @click="showPromptHistory" class="text-xs text-gray-400 hover:text-gradient-icon transition-colors hover:text-gradient-icon" title="历史记录">
<i class="fas fa-history"></i>
</button>
</div>
</div>
<!-- 提示词模板选择 -->
<div v-if="showTemplates" class="mb-4 p-4 bg-linear-dark/30 rounded-lg">
<h4 class="text-sm font-medium mb-3 text-gradient-icon">选择提示词模板</h4>
<div class="grid grid-cols-1 md:grid-cols-2 gap-3">
<button
v-for="template in getPromptTemplates(selectedTaskId)"
:key="template.id"
@click="selectPromptTemplate(template)"
class="p-3 text-left bg-dark-light rounded-lg hover:bg-laser-purple/20 transition-all border border-transparent hover:border-laser-purple/40"
>
<div class="font-medium text-sm mb-1">{{ template.title }}</div>
<div class="text-xs text-gray-400 line-clamp-2">{{ template.prompt }}</div>
</button>
</div>
<button @click="showTemplates = false" class="mt-3 text-xs text-gray-400 hover:text-gradient-icon">
<i class="fas fa-times mr-1"></i>关闭模板
</button>
</div>
<!-- 提示词历史记录 -->
<div v-if="showHistory" class="mb-4 p-4 bg-secondary/30 rounded-lg">
<div class="flex justify-between items-center mb-3">
<h4 class="text-sm font-medium text-gradient-icon">提示词历史记录</h4>
<button @click="clearPromptHistory" class="text-xs text-red-400 hover:text-red-300 transition-colors" title="清空历史记录">
<i class="fas fa-trash"></i>
</button>
</div>
<div v-if="getPromptHistory().length === 0" class="text-center py-4 text-gray-400 text-sm">
暂无历史记录
</div>
<div v-else class="space-y-2 max-h-40 overflow-y-auto">
<button
v-for="(history, index) in getPromptHistory()"
:key="index"
@click="selectPromptHistory(history)"
class="w-full p-3 text-left bg-dark-light rounded-lg hover:bg-laser-purple/20 transition-all border border-transparent hover:border-laser-purple/40"
>
<div class="text-xs text-gray-300 line-clamp-2">{{ history }}</div>
</button>
</div>
<button @click="showHistory = false" class="mt-3 text-xs text-gray-400 hover:text-gradient-icon">
<i class="fas fa-times mr-1"></i>关闭历史
</button>
</div>
<div class="relative">
<textarea
v-model="getCurrentForm().prompt"
class="w-full bg-dark-light border border-laser-purple/40 rounded-lg p-4 pr-16 text-sm min-h-[120px] focus:outline-none focus:ring-2 focus:ring-laser-purple/60 transition-all resize-none scrollbar-thin focus:border-laser focus:shadow-laser prompt-textarea"
:placeholder="getPromptPlaceholder()"
rows="3"
required
></textarea>
</div>
<!-- 高级配置选项 -->
<!-- <div class="mt-4 grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<label class="block text-sm text-gray-400 mb-2">种子值</label>
<input
v-model="getCurrentForm().seed"
type="number"
class="w-full bg-dark-light border border-laser-purple/40 rounded-lg px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-laser-purple/60 transition-all"
placeholder="随机种子值"
/>
</div> -->
<!-- <div>
<label class="block text-sm text-gray-400 mb-2">推理阶段</label>
<select
v-model="getCurrentForm().stage"
class="w-full bg-dark-light border border-laser-purple/40 rounded-lg px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-laser-purple/60 transition-all"
>
<option value="single_stage">单阶段</option>
<option value="multi_stage">多阶段</option>
<option value="original">原始阶段</option>
</select>
</div> -->
</div>
<div class="flex justify-between items-center mt-4">
<p class="text-xs text-gray-500">最多支持500个字符</p>
<div class="flex space-x-2">
<button @click="clearPrompt" class="text-xs px-3 py-1 rounded transition-all">
<i class="fas fa-sync-alt mr-1"></i>
清空
</button>
<button @click="submitTask" :disabled="submitting" class="btn-primary px-4 py-1 rounded transition-all">
<i class="fas fa-play mr-1"></i>
{{ submitting ? '生成中...' : '生成视频' }}
</button>
</div>
</div>
</div>
<!-- 任务详情显示面板 -->
<div v-if="selectedTask && !showCreator" class="max-w-4xl mx-auto task-detail-panel">
<div class="mb-6">
<!-- 输出视频 -->
<div v-if="selectedTask.status === 'SUCCEED' && selectedTask.outputs && Object.keys(selectedTask.outputs).length > 0" class="bg-secondary/30 rounded-xl p-4 mb-6 task-completed-panel">
<h4 class="text-sm font-medium mb-3 flex items-center">
<i class="fas fa-video text-gradient-icon mr-2"></i>
输出结果
</h4>
<div class="space-y-3">
<div v-for="(output, key) in selectedTask.outputs" :key="key" class="flex items-center justify-between bg-dark-light rounded-lg p-3">
<div class="flex items-center">
<i class="fas fa-file-video text-gradient-icon mr-2"></i>
<span class="text-sm">{{ output }}</span>
<span v-if="selectedTaskFiles.outputs[key] && selectedTaskFiles.outputs[key].error"
class="ml-2 text-xs text-red-400">
<i class="fas fa-exclamation-triangle"></i> 加载失败
</span>
</div>
<div class="flex space-x-2">
<button v-if="selectedTaskFiles.outputs[key] && selectedTaskFiles.outputs[key].url"
@click="downloadFile(selectedTaskFiles.outputs[key])"
class="text-xs btn-primary px-2 py-1 rounded">
<i class="fas fa-download mr-1"></i>下载
</button>
<div v-else-if="!selectedTaskFiles.outputs[key]"
class="text-xs text-gray-400 px-2 py-1">
<i class="fas fa-spinner fa-spin mr-1"></i>加载中
</div>
</div>
</div>
</div>
<!-- 视频预览 -->
<div class="w-full">
<div class="relative w-full h-[400px] flex items-center justify-center rounded-xl bg-black border border-laser-purple/50 overflow-hidden">
<div data-test-id="dragable-content" data-is-dragging="false" draggable="true" class="w-full h-full flex items-center justify-center">
<div class="w-full flexitems-center justify-center">
<video
v-if="selectedTaskFiles.outputs.output_video && selectedTaskFiles.outputs.output_video.url"
class="object-contain bg-black"
style="object-fit: contain; height: 400px; visibility: visible;"
controls
preload="metadata"
:src="selectedTaskFiles.outputs.output_video.url">
您的浏览器不支持视频播放。
</video>
<div v-else-if="selectedTaskFiles.outputs.output_video && selectedTaskFiles.outputs.output_video.error"
class="w-full h-full flex items-center justify-center bg-red-900/20">
<div class="text-center">
<i class="fas fa-exclamation-triangle text-red-400 text-2xl mb-2"></i>
<p class="text-red-400 text-sm">视频加载失败</p>
</div>
</div>
<div v-else
class="w-full h-full flex items-center justify-center bg-gray-700">
<div class="text-center">
<i class="fas fa-spinner fa-spin text-gray-400 text-2xl mb-2"></i>
<p class="text-gray-400 text-sm">加载视频中...</p>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- 任务进行中信息 -->
<div v-if="['CREATED', 'PENDING', 'RUNNING'].includes(selectedTask.status) && !showCreator" class="max-w-4xl mx-auto task-running-panel">
<div class="text-center py-8">
<div class="w-24 h-24 mx-auto bg-laser-purple/20 rounded-full flex items-center justify-center mb-6 animate-pulse-slow">
<i class="fas fa-spinner fa-spin text-gradient-icon text-3xl"></i>
</div>
<h3 class="text-xl font-medium mb-2">视频生成中</h3>
<p class="text-gray-400 mb-6">AI正在努力生成您的视频,请稍候...</p>
</div>
</div>
<!-- 任务失败信息 -->
<div v-if="selectedTask && selectedTask.status === 'FAILED' && !showCreator" class="max-w-4xl mx-auto task-failed-panel">
<div class="text-center py-8">
<div class="w-24 h-24 mx-auto bg-red-500/10 rounded-full flex items-center justify-center mb-6">
<i class="fas fa-exclamation-triangle text-red-500 text-3xl"></i>
</div>
<h3 class="text-xl font-medium mb-2">视频生成失败</h3>
<p class="text-gray-400 mb-4 max-w-md mx-auto">
很抱歉,您的视频生成任务未能完成。这可能是由于资源限制或参数设置问题导致的。
</p>
</div>
</div>
<!-- 任务取消信息 -->
<div v-if="selectedTask && selectedTask.status === 'CANCEL' && !showCreator" class="max-w-4xl mx-auto task-cancelled-panel">
<div class="text-center py-8">
<div class="w-24 h-24 mx-auto bg-yellow-500/10 rounded-full flex items-center justify-center mb-6">
<i class="fas fa-ban text-yellow-500 text-3xl"></i>
</div>
<h3 class="text-xl font-medium mb-2">任务已取消</h3>
<p class="text-gray-400 mb-4 max-w-md mx-auto">
此任务已被取消,您可以重新生成或查看之前上传的素材。
</p>
</div>
</div>
<!-- 任务操作 -->
<div class="flex justify-center space-x-3 p-4">
<button v-if="['CREATED', 'PENDING', 'RUNNING'].includes(selectedTask.status)"
@click="cancelTask(selectedTask.task_id)"
class="px-4 py-2 btn-primary rounded-lg text-sm transition-all">
<i class="fas fa-times mr-2"></i>
取消任务
</button>
<button v-if="['SUCCEED', 'FAILED', 'CANCEL'].includes(selectedTask.status)"
@click="resumeTask(selectedTask.task_id)"
class="px-4 py-2 btn-primary rounded-lg text-sm transition-all">
<i class="fas fa-redo mr-2"></i>
重新生成
</button>
</div>
<!-- 任务状态显示 -->
<div class="bg-secondary/30 rounded-xl p-6 mb-6">
<h4 class="text-sm font-medium mb-3 flex items-center">
<i class="fas fa-info-circle text-gradient-icon mr-2"></i>
任务信息
</h4>
<ul class="space-y-2 text-sm">
<li class="flex justify-between">
<span class="text-gray-400">任务ID</span>
<span>{{ selectedTask.task_id }}</span>
</li>
<li class="flex justify-between">
<span class="text-gray-400">任务类型</span>
<span>{{ selectedTask.task_type }}</span>
</li>
<li class="flex justify-between">
<span class="text-gray-400">模型名称</span>
<span class="text-gradient-icon">{{ selectedTask.model_cls }}</span>
</li>
<li class="flex justify-between">
<span class="text-gray-400">创建时间</span>
<span>{{ formatTime(selectedTask.create_t) }}</span>
</li>
<li class="flex justify-between">
<span class="text-gray-400">状态</span>
<span :class="getStatusTextClass(selectedTask.status)">{{ selectedTask.status }}</span>
</li>
</ul>
</div>
<!-- 提示词 -->
<div class="bg-secondary/30 rounded-xl p-4 mb-6">
<h4 class="text-sm font-medium mb-3 flex items-center">
<i class="fas fa-file-alt text-gradient-icon mr-2"></i>
提示词
</h4>
<div class="bg-dark-light rounded-lg p-4 text-sm text-gray-300">
<p>{{ selectedTask.params.prompt || '无提示词' }}</p>
</div>
</div>
<div v-if="selectedTask.inputs && Object.keys(selectedTask.inputs).length" class="bg-secondary/30 rounded-xl p-4 mb-6">
<h4 class="text-sm font-medium mb-3 flex items-center">
<i class="fas fa-upload text-gradient-icon mr-2"></i>
上传素材
<span v-if="loadingTaskFiles" class="ml-2 text-xs text-gray-400">(加载中...)</span>
</h4>
<div class="space-y-3">
<template v-for="(input, key) in selectedTask.inputs" :key="key">
<div class="flex items-center gap-3">
<template v-if="key.includes('image')">
<i class="fas fa-image text-gradient-icon text-xl"></i>
<div class="flex items-center gap-2">
<span v-if="!selectedTaskFiles.inputs[key] || !selectedTaskFiles.inputs[key].url"
class="text-gray-400 text-sm">{{ typeof input === 'string' ? input : '图片' }}</span>
<div class="flex items-center gap-2 relative group">
<img v-if="selectedTaskFiles.inputs[key] && selectedTaskFiles.inputs[key].url"
:src="selectedTaskFiles.inputs[key].url"
:alt="input"
class="w-20 h-20 object-cover rounded bg-dark-light border border-gray-700">
<div v-else-if="selectedTaskFiles.inputs[key] && selectedTaskFiles.inputs[key].error"
class="w-20 h-20 rounded bg-red-900/20 border border-red-500/30 flex items-center justify-center">
<i class="fas fa-exclamation-triangle text-red-400"></i>
</div>
<div v-else
class="w-20 h-20 rounded bg-gray-700 border border-gray-600 flex items-center justify-center">
<i class="fas fa-spinner fa-spin text-gray-400"></i>
</div>
<button v-if="selectedTaskFiles.inputs[key] && selectedTaskFiles.inputs[key].url"
@click="downloadFile(selectedTaskFiles.inputs[key])"
class="text-xs px-2 py-1 rounded">
<i class="fas fa-download mr-1 text-white opacity-30 hover:opacity-100 transition-opacity"></i>
</button>
</div>
</div>
</template>
<template v-else-if="key.includes('audio')">
<i class="fas fa-microphone text-gradient-icon text-xl"></i>
<div class="flex items-center gap-2">
<span v-if="!selectedTaskFiles.inputs[key] || !selectedTaskFiles.inputs[key].url"
class="text-gray-400 text-sm">{{ typeof input === 'string' ? input : '音频文件' }}</span>
<div class="flex items-center gap-2">
<audio v-if="selectedTaskFiles.inputs[key] && selectedTaskFiles.inputs[key].url"
:src="selectedTaskFiles.inputs[key].url"
controls
class="h-8 text-gradient-icon">
您的浏览器不支持音频播放
</audio>
<div v-else-if="selectedTaskFiles.inputs[key] && selectedTaskFiles.inputs[key].error"
class="h-8 px-3 rounded bg-red-900/20 border border-red-500/30 flex items-center">
<i class="fas fa-exclamation-triangle text-red-400 text-xs"></i>
</div>
<div v-else
class="h-8 px-3 rounded bg-gray-700 border border-gray-600 flex items-center">
<i class="fas fa-spinner fa-spin text-gray-400 text-xs"></i>
</div>
<button v-if="selectedTaskFiles.inputs[key] && selectedTaskFiles.inputs[key].url"
@click="downloadFile(selectedTaskFiles.inputs[key])"
class="text-xs px-2 py-1 rounded">
<i class="fas fa-download mr-1 text-white opacity-30 hover:opacity-100 transition-opacity"></i>
</button>
</div>
</div>
</div>
</template>
</div>
</template>
</div>
</div>
</div>
</div>
</div>
<!-- 加载指示器 -->
<div v-if="loading" class="loading position-fixed top-50 start-50 translate-middle show">
<div class="spinner-border text-gradient-icon" role="status">
<span class="visually-hidden">加载中...</span>
</div>
</div>
<!-- 增强的提示消息系统 -->
<div v-cloak>
<div v-if="alert.show"
class="fixed top-3 left-1/2 transform -translate-x-1/2 z-50 max-w-[16rem] w-full px-1"
:class="getAlertClass(alert.type)">
<div class="alert flex items-center p-2 rounded-md shadow-md border-l-4 transition-all duration-300 ease-out"
:class="getAlertBorderClass(alert.type)">
<div class="flex-shrink-0 mr-1">
<i :class="getAlertIcon(alert.type)" class="text-base"></i>
</div>
<div class="flex-1">
<p class="text-xs font-medium" :class="getAlertTextClass(alert.type)">
{{ alert.message }}
</p>
</div>
<div class="flex-shrink-0 ml-1">
<button @click="alert.show = false"
class="text-gray-400 hover:text-gray-600 transition-colors">
<i class="fas fa-times text-xs"></i>
</button>
</div>
</div>
</div>
</div>
</div>
<script src="https://unpkg.com/vue@3/dist/vue.global.js"></script>
<script>
// 检测Font Awesome图标是否加载成功
function checkFontAwesome() {
const testIcon = document.createElement('i');
testIcon.className = 'fas fa-check';
testIcon.style.position = 'absolute';
testIcon.style.left = '-9999px';
document.body.appendChild(testIcon);
const computedStyle = window.getComputedStyle(testIcon, ':before');
const content = computedStyle.getPropertyValue('content');
document.body.removeChild(testIcon);
// 如果图标没有正确显示,使用备用方案
if (!content || content === 'none' || content === 'normal') {
console.warn('Font Awesome 图标加载失败,使用备用方案');
replaceIconsWithFallback();
}
}
// 使用备用图标替换失败的图标
function replaceIconsWithFallback() {
const iconMap = {
'fas fa-video': '🎥',
'fas fa-plus': '',
'fas fa-search': '🔍',
'fas fa-clock': '',
'fas fa-bars': '',
'fas fa-sign-out-alt': '🚪',
'fas fa-star': '',
'fas fa-cloud-upload-alt': '☁️',
'fas fa-microphone': '🎤',
'fas fa-magic': '',
'fas fa-history': '📚',
'fas fa-times': '✖️',
'fas fa-trash': '🗑️',
'fas fa-sync-alt': '🔄',
'fas fa-play': '▶️',
'fas fa-share-alt': '📤',
'fas fa-download': '⬇️',
'fas fa-info-circle': 'ℹ️',
'fas fa-file-alt': '📄',
'fas fa-file-video': '🎬',
'fas fa-eye': '👁️',
'fas fa-exclamation-triangle': '⚠️',
'fas fa-lightbulb': '💡',
'fas fa-check': '',
'fas fa-user': '👤',
'fas fa-image': '🖼️',
'fas fa-font': '🔤',
'fas fa-spinner': '🌀',
'fas fa-check-circle': '',
'fas fa-hourglass-half': '',
'fas fa-ban': '🚫',
'fas fa-question-circle': '',
'fas fa-times-circle': '',
'fas fa-music': '🎵',
'fas fa-tags': '🏷️',
'fas fa-chart-bar': '📊',
'fas fa-redo': '🔄',
'fas fa-pause': '⏸️'
};
// 替换所有图标
Object.keys(iconMap).forEach(iconClass => {
const icons = document.querySelectorAll(`.${iconClass.replace(/\s+/g, '.')}`);
icons.forEach(icon => {
icon.innerHTML = iconMap[iconClass];
icon.className = icon.className.replace(/fas fa-[a-z-]+/g, '');
});
});
}
// 页面加载完成后检查图标
document.addEventListener('DOMContentLoaded', function() {
setTimeout(checkFontAwesome, 1000); // 延迟1秒检查,确保CDN加载完成
});
</script>
<script>
const { createApp, ref, computed, onMounted, watch } = Vue;
createApp({
setup() {
// 响应式数据
const loading = ref(false);
const alert = ref({ show: false, message: '', type: 'info' });
const submitting = ref(false);
const showCreator = ref(true);
const searchQuery = ref('');
const generatingThumbnails = ref(false);
const sidebarCollapsed = ref(false);
const thumbnailCache = ref(new Map());
const thumbnailCacheLoaded = ref(false);
const imageTemplates = ref([]);
const audioTemplates = ref([]);
const showImageTemplates = ref(false);
const showAudioTemplates = ref(false);
const currentUser = ref({});
const models = ref([]);
const tasks = ref([]);
const isLoggedIn = ref(false);
const selectedTaskId = ref(null);
const selectedTask = ref(null);
const selectedTaskFiles = ref({ inputs: {}, outputs: {} }); // 存储任务的输入输出文件
const loadingTaskFiles = ref(false); // 加载任务文件的状态
const statusFilter = ref('ALL');
const pagination = ref(null);
const currentPage = ref(1);
const pageSize = ref(10);
// 为三个任务类型分别创建独立的表单
const t2vForm = ref({
task: 't2v',
model_cls: '',
stage: 'single_stage',
prompt: '',
seed: 42
});
const i2vForm = ref({
task: 'i2v',
model_cls: '',
stage: 'multi_stage',
imageFile: null,
prompt: '',
seed: 42
});
const digitalHumanForm = ref({
task: 'digital_human',
model_cls: '',
stage: 'single_stage',
imageFile: null,
audioFile: null,
prompt: '',
seed: 42
});
// 根据当前选择的任务类型获取对应的表单
const getCurrentForm = () => {
switch (selectedTaskId.value) {
case 't2v':
return t2vForm.value;
case 'i2v':
return i2vForm.value;
case 'digital_human':
return digitalHumanForm.value;
default:
return t2vForm.value;
}
};
// 为每个任务类型创建独立的预览变量
const i2vImagePreview = ref(null);
const digitalHumanImagePreview = ref(null);
const digitalHumanAudioPreview = ref(null);
// 根据当前任务类型获取对应的预览变量
const getCurrentImagePreview = () => {
switch (selectedTaskId.value) {
case 't2v':
return null;
case 'i2v':
return i2vImagePreview.value;
case 'digital_human':
return digitalHumanImagePreview.value;
default:
return null;
}
};
const getCurrentAudioPreview = () => {
switch (selectedTaskId.value) {
case 't2v':
return null
case 'i2v':
return null
case 'digital_human':
return digitalHumanAudioPreview.value;
default:
return null;
}
};
const setCurrentImagePreview = (value) => {
switch (selectedTaskId.value) {
case 't2v':
break;
case 'i2v':
i2vImagePreview.value = value;
break;
case 'digital_human':
digitalHumanImagePreview.value = value;
break;
}
};
const setCurrentAudioPreview = (value) => {
switch (selectedTaskId.value) {
case 't2v':
break;
case 'i2v':
break;
case 'digital_human':
digitalHumanAudioPreview.value = value;
break;
}
};
// 提示词模板相关
const showTemplates = ref(false);
const showHistory = ref(false);
// 计算属性
const availableTaskTypes = computed(() => {
const types = [...new Set(models.value.map(m => m.task))];
// 重新排序,确保数字人在最左边
const orderedTypes = [];
// 检查是否有包含audio或seko的i2v模型,如果有则添加digital_human类型
const hasDigitalHumanModels = models.value.some(m =>
m.task === 'i2v' && (m.model_cls.toLowerCase().includes('audio') || m.model_cls.toLowerCase().includes('seko'))
);
// 优先添加数字人(如果存在相关模型)
if (hasDigitalHumanModels) {
orderedTypes.push('digital_human');
}
// 然后添加其他类型
types.forEach(type => {
if (type !== 'digital_human') {
orderedTypes.push(type);
}
});
return orderedTypes;
});
const availableModelClasses = computed(() => {
if (!selectedTaskId.value) return [];
// 如果是数字人任务类型,显示包含audio或seko的i2v模型
if (selectedTaskId.value === 'digital_human') {
return [...new Set(models.value
.filter(m => m.task === 'i2v' && (m.model_cls.toLowerCase().includes('audio') || m.model_cls.toLowerCase().includes('seko')))
.map(m => m.model_cls))];
}
// 如果是i2v任务类型,剔除包含audio或seko的模型
if (selectedTaskId.value === 'i2v') {
return [...new Set(models.value
.filter(m => m.task === 'i2v' && !m.model_cls.toLowerCase().includes('audio') && !m.model_cls.toLowerCase().includes('seko'))
.map(m => m.model_cls))];
}
// 其他任务类型正常处理
return [...new Set(models.value
.filter(m => m.task === selectedTaskId.value)
.map(m => m.model_cls))];
});
const filteredTasks = computed(() => {
let filtered = tasks.value;
// 状态过滤
if (statusFilter.value !== 'ALL') {
filtered = filtered.filter(task => task.status === statusFilter.value);
}
// 搜索过滤
if (searchQuery.value) {
filtered = filtered.filter(task =>
task.params.prompt?.toLowerCase().includes(searchQuery.value.toLowerCase()) ||
task.task_id.toLowerCase().includes(searchQuery.value.toLowerCase()) ||
task.task_type.toLowerCase().includes(searchQuery.value.toLowerCase())
);
}
// 按时间排序,最新的任务在前面
filtered = filtered.sort((a, b) => {
const timeA = parseInt(a.create_t) || 0;
const timeB = parseInt(b.create_t) || 0;
return timeB - timeA; // 降序排列,最新的在前
});
return filtered;
});
// 方法
const showAlert = (message, type = 'info') => {
alert.value = { show: true, message, type };
setTimeout(() => {
alert.value.show = false;
}, 5000);
};
const setLoading = (value) => {
loading.value = value;
};
const apiCall = async (endpoint, options = {}) => {
const url = `${endpoint}`;
const headers = {
'Content-Type': 'application/json',
...options.headers
};
if (localStorage.getItem('accessToken')) {
headers['Authorization'] = `Bearer ${localStorage.getItem('accessToken')}`;
}
const response = await fetch(url, {
...options,
headers
});
if (response.status === 401) {
logout();
throw new Error('认证失败,请重新登录'); }
if (response.status === 400) {
const error = await response.json();
showAlert(error.message, 'danger');
throw new Error(error.message);
}
// 添加50ms延迟,防止触发服务端频率限制
await new Promise(resolve => setTimeout(resolve, 50));
return response;
};
const loginWithGitHub = async () => {
try {
setLoading(true);
const response = await fetch('./auth/login/github');
const data = await response.json();
window.location.href = data.auth_url;
} catch (error) {
showAlert('获取GitHub认证URL失败', 'danger');
} finally {
setLoading(false);
}
};
const handleGitHubCallback = async (code) => {
try {
setLoading(true);
const response = await fetch(`./auth/callback/github?code=${code}`);
if (response.ok) {
const data = await response.json();
console.log(data);
localStorage.setItem('accessToken', data.access_token);
localStorage.setItem('currentUser', JSON.stringify(data.user_info));
currentUser.value = data.user_info;
isLoggedIn.value = true;
} else {
const error = await response.json();
showAlert(`登录失败: ${error.detail}`, 'danger');
}
window.location.href = '/';
} catch (error) {
showAlert('登录过程中发生错误', 'danger');
console.error(error);
} finally {
setLoading(false);
}
};
const logout = () => {
localStorage.removeItem('accessToken');
localStorage.removeItem('currentUser');
currentUser.value = {};
isLoggedIn.value = false;
models.value = [];
tasks.value = [];
showAlert('已退出登录', 'info');
};
const loadModels = async () => {
try {
console.log('开始加载模型列表...');
const response = await apiRequest('./api/v1/model/list');
if (response && response.ok) {
const data = await response.json();
console.log('模型列表数据:', data);
models.value = data.models || [];
console.log('设置后的models.value:', models.value);
} else if (response) {
console.error('模型列表API响应失败:', response);
showAlert('加载模型列表失败', 'danger');
}
// 如果response为null,说明是认证错误,apiRequest已经处理了
} catch (error) {
console.error('加载模型失败:', error);
showAlert(`加载模型失败: ${error.message}`, 'danger');
}
};
// 加载模板文件
const loadTemplates = async () => {
try {
const response = await apiCall('./api/v1/template/list');
if (response.ok) {
const data = await response.json();
imageTemplates.value = data.templates.images || [];
audioTemplates.value = data.templates.audios || [];
} else {
console.warn('加载模板失败');
}
} catch (error) {
console.warn('加载模板失败:', error);
}
};
// 选择图片模板
const selectImageTemplate = async (template) => {
try {
const response = await fetch(template.url);
if (response.ok) {
const blob = await response.blob();
const file = new File([blob], template.filename, { type: blob.type });
if (selectedTaskId.value === 'i2v') {
i2vForm.value.imageFile = file;
} else if (selectedTaskId.value === 'digital_human') {
digitalHumanForm.value.imageFile = file;
}
// 创建预览
const reader = new FileReader();
reader.onload = (e) => {
setCurrentImagePreview(e.target.result);
};
reader.readAsDataURL(file);
showImageTemplates.value = false;
showAlert('图片模板已选择', 'success');
} else {
showAlert('加载图片模板失败', 'danger');
}
} catch (error) {
showAlert(`加载图片模板失败: ${error.message}`, 'danger');
}
};
// 选择音频模板
const selectAudioTemplate = async (template) => {
try {
const response = await fetch(template.url);
if (response.ok) {
const blob = await response.blob();
const file = new File([blob], template.filename, { type: blob.type });
digitalHumanForm.value.audioFile = file;
// 创建预览
const reader = new FileReader();
reader.onload = (e) => {
setCurrentAudioPreview(e.target.result);
};
reader.readAsDataURL(file);
showAudioTemplates.value = false;
showAlert('音频模板已选择', 'success');
} else {
showAlert('加载音频模板失败', 'danger');
}
} catch (error) {
showAlert(`加载音频模板失败: ${error.message}`, 'danger');
}
};
// 预览音频模板
const previewAudioTemplate = (template) => {
const audio = new Audio(template.url);
audio.play().catch(error => {
console.error('音频播放失败:', error);
showAlert('音频播放失败', 'danger');
});
};
const handleImageUpload = (event) => {
const file = event.target.files[0];
if (file) {
if (selectedTaskId.value === 'i2v') {
i2vForm.value.imageFile = file;
} else if (selectedTaskId.value === 'digital_human') {
digitalHumanForm.value.imageFile = file;
}
const reader = new FileReader();
reader.onload = (e) => {
setCurrentImagePreview(e.target.result);
};
reader.readAsDataURL(file);
} else {
// 用户取消了选择,保持原有图片不变
// 不做任何操作
}
};
const selectTask = (taskType) => {
selectedTaskId.value = taskType;
// 根据任务类型恢复对应的预览
if (taskType === 'i2v' && i2vForm.value.imageFile) {
// 恢复图片预览
const reader = new FileReader();
reader.onload = (e) => {
setCurrentImagePreview(e.target.result);
};
reader.readAsDataURL(i2vForm.value.imageFile);
} else if (taskType === 'digital_human') {
// 恢复数字人任务的图片和音频预览
if (digitalHumanForm.value.imageFile) {
const reader = new FileReader();
reader.onload = (e) => {
setCurrentImagePreview(e.target.result);
};
reader.readAsDataURL(digitalHumanForm.value.imageFile);
}
if (digitalHumanForm.value.audioFile) {
const reader = new FileReader();
reader.onload = (e) => {
setCurrentAudioPreview(e.target.result);
};
reader.readAsDataURL(digitalHumanForm.value.audioFile);
}
}
// 如果当前表单没有选择模型,自动选择第一个可用的模型
const currentForm = getCurrentForm();
if (!currentForm.model_cls) {
const availableModels = models.value.filter(m => m.task === taskType);
if (availableModels.length > 0) {
const firstModel = availableModels[0];
currentForm.model_cls = firstModel.model_cls;
currentForm.stage = firstModel.stage;
}
}
};
const selectModel = (model) => {
getCurrentForm().model_cls = model;
};
const triggerImageUpload = () => {
document.querySelector('input[type="file"][accept="image/*"]').click();
};
const triggerAudioUpload = () => {
document.querySelector('input[type="file"][accept="audio/*"]').click();
};
const removeImage = () => {
setCurrentImagePreview(null);
if (selectedTaskId.value === 'i2v') {
i2vForm.value.imageFile = null;
} else if (selectedTaskId.value === 'digital_human') {
digitalHumanForm.value.imageFile = null;
}
// 重置文件输入框,确保可以重新选择相同文件
const imageInput = document.querySelector('input[type="file"][accept="image/*"]');
if (imageInput) {
imageInput.value = '';
}
};
const removeAudio = () => {
setCurrentAudioPreview(null);
digitalHumanForm.value.audioFile = null;
console.log('音频已移除');
// 重置音频文件输入框,确保可以重新选择相同文件
const audioInput = document.querySelector('input[type="file"][accept="audio/*"]');
if (audioInput) {
audioInput.value = '';
}
};
const getAudioMimeType = () => {
if (digitalHumanForm.value.audioFile) {
console.log('音频文件类型:', digitalHumanForm.value.audioFile.type);
return digitalHumanForm.value.audioFile.type;
}
console.log('使用默认音频类型: audio/mpeg');
return 'audio/mpeg'; // 默认类型
};
const handleAudioUpload = (event) => {
const file = event.target.files[0];
if (file) {
digitalHumanForm.value.audioFile = file;
const reader = new FileReader();
reader.onload = (e) => {
setCurrentAudioPreview(e.target.result);
console.log('音频预览已设置:', e.target.result);
};
reader.readAsDataURL(file);
} else {
setCurrentAudioPreview(null);
}
};
const submitTask = async () => {
try {
const currentForm = getCurrentForm();
// 表单验证
if (!selectedTaskId.value) {
showAlert('请选择任务类型', 'warning');
return;
}
if (!currentForm.model_cls) {
showAlert('请选择模型', 'warning');
return;
}
if (!currentForm.prompt || currentForm.prompt.trim().length === 0) {
showAlert('请输入提示词', 'warning');
return;
}
if (currentForm.prompt.length > 500) {
showAlert('提示词长度不能超过500个字符', 'warning');
return;
}
if (selectedTaskId.value === 'i2v' && !currentForm.imageFile) {
showAlert('图生视频任务需要上传参考图片', 'warning');
return;
}
if (selectedTaskId.value === 'digital_human' && !currentForm.imageFile) {
showAlert('数字人任务需要上传参考图片', 'warning');
return;
}
if (selectedTaskId.value === 'digital_human' && !currentForm.audioFile) {
showAlert('数字人任务需要上传音频文件', 'warning');
return;
}
setLoading(true);
submitting.value = true;
// 确定实际提交的任务类型
let actualTaskType = selectedTaskId.value;
if (selectedTaskId.value === 'digital_human') {
actualTaskType = 'i2v'; // 数字人任务实际提交为i2v
}
var formData = {
task: actualTaskType,
model_cls: currentForm.model_cls,
stage: currentForm.stage,
prompt: currentForm.prompt.trim(),
seed: currentForm.seed || Math.floor(Math.random() * 1000000)
};
if (currentForm.model_cls.startsWith('wan2.1')) {
formData.negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
}
if (selectedTaskId.value === 'i2v' && currentForm.imageFile) {
const base64 = await fileToBase64(currentForm.imageFile);
formData.input_image = {
type: 'base64',
data: base64
};
}
if (selectedTaskId.value === 'digital_human') {
if (currentForm.imageFile) {
const base64 = await fileToBase64(currentForm.imageFile);
formData.input_image = {
type: 'base64',
data: base64
};
}
if (currentForm.audioFile) {
const base64 = await fileToBase64(currentForm.audioFile);
formData.input_audio = {
type: 'base64',
data: base64
};
formData.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
}
}
const response = await apiRequest('./api/v1/task/submit', {
method: 'POST',
body: JSON.stringify(formData)
});
if (response && response.ok) {
const result = await response.json();
showAlert(`任务提交成功!任务ID: ${result.task_id}`, 'success');
// 保存提示词到历史记录
addPromptToHistory(currentForm.prompt);
await refreshTasks();
showCreator.value = true;
// 重新选择数字人任务(如果可用)
if (availableTaskTypes.value.includes('digital_human')) {
selectTask('digital_human');
}
// 重置所有表单
t2vForm.value = {
task: 't2v',
model_cls: '',
stage: 'single_stage',
prompt: '',
seed: Math.floor(Math.random() * 1000000)
};
i2vForm.value = {
task: 'i2v',
model_cls: '',
stage: 'multi_stage',
imageFile: null,
prompt: '',
seed: 42
};
digitalHumanForm.value = {
task: 'digital_human',
model_cls: '',
stage: 'single_stage',
imageFile: null,
audioFile: null,
prompt: '',
seed: Math.floor(Math.random() * 1000000)
};
// 重置所有预览
i2vImagePreview.value = null;
digitalHumanImagePreview.value = null;
digitalHumanAudioPreview.value = null;
} else {
const error = await response.json();
showAlert(`任务提交失败: ${error.message},${error.detail}`, 'danger');
}
} catch (error) {
showAlert(`提交任务失败: ${error.message}`, 'danger');
} finally {
submitting.value = false;
setLoading(false);
}
};
const fileToBase64 = (file) => {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.readAsDataURL(file);
reader.onload = () => {
const base64 = reader.result.split(',')[1];
resolve(base64);
};
reader.onerror = error => reject(error);
});
};
const formatTime = (timestamp) => {
if (!timestamp) return '';
const date = new Date(timestamp * 1000);
return date.toLocaleString('zh-CN');
};
const preloadInputImages = async (tasks) => {
// 为所有任务预加载输入图片
for (const task of tasks) {
if (task.inputs) {
// 查找输入中的图片文件
const imageInputs = Object.keys(task.inputs).filter(key =>
key.includes('image') ||
task.inputs[key].toString().toLowerCase().match(/\.(jpg|jpeg|png|gif|bmp|webp)$/)
);
// 预加载第一个输入图片
if (imageInputs.length > 0) {
const firstImageKey = imageInputs[0];
try {
const imageUrl = getTaskInputUrl(task.task_id, firstImageKey);
// 创建Image对象预加载输入图片
const img = new Image();
img.src = imageUrl;
// 监听加载完成事件
img.onload = () => {
console.log(`Input image preloaded for task ${task.task_id}: ${firstImageKey}`);
};
img.onerror = () => {
console.warn(`Failed to preload input image for task ${task.task_id}: ${firstImageKey}`);
};
} catch (error) {
console.warn(`Failed to preload input image for task ${task.task_id}:`, error);
}
}
}
}
};
const preloadThumbnailCache = async (tasks) => {
setTimeout(async () => {
try {
const response = await apiRequest('./api/v1/task/thumbnails');
if (response && response.ok) {
const data = await response.json();
const thumbnails = data.thumbnails || {};
for (const [taskId, thumbnailUrl] of Object.entries(thumbnails)) {
thumbnailCache.value.set(taskId, thumbnailUrl);
}
thumbnailCacheLoaded.value = true;
console.log(`缩略图一次性加载完成,共缓存${Object.keys(thumbnails).length}个任务`);
} else {
// 如果API调用失败,回退到原来的逐个加载方式
await preloadThumbnailsIndividually(tasks.slice(0, 30));
}
} catch (error) {
// 如果API调用异常,回退到原来的逐个加载方式
await preloadThumbnailsIndividually(tasks.slice(0, 30));
}
}, 100); // 延迟100ms开始,让页面先渲染
};
const preloadThumbnailsIndividually = async (tasksToCache) => {
// 使用串行加载避免过快访问
for (let i = 0; i < tasksToCache.length; i++) {
const task = tasksToCache[i];
if (task.inputs) {
// 查找输入中的图片文件
const imageInputs = Object.keys(task.inputs).filter(key =>
key.includes('image') ||
task.inputs[key].toString().toLowerCase().match(/\.(jpg|jpeg|png|gif|bmp|webp)$/)
);
if (imageInputs.length > 0) {
const firstImageKey = imageInputs[0];
try {
const imageUrl = getTaskInputUrl(task.task_id, firstImageKey);
// 使用重试机制加载图片
const success = await loadImageWithRetry(task.task_id, imageUrl, 3);
if (success) {
thumbnailCache.value.set(task.task_id, imageUrl);
}
} catch (error) {
console.warn(`缩略图缓存错误 ${task.task_id}:`, error);
}
}
}
// 添加延迟避免过快访问
if (i < tasksToCache.length - 1) {
await new Promise(resolve => setTimeout(resolve, 200)); // 200ms延迟
}
}
thumbnailCacheLoaded.value = true;
};
const loadImageWithRetry = async (taskId, imageUrl, maxRetries = 3) => {
for (let attempt = 1; attempt <= maxRetries; attempt++) {
try {
const success = await new Promise((resolve) => {
const img = new Image();
img.onload = () => resolve(true);
img.onerror = () => resolve(false);
// 设置超时
const timeout = setTimeout(() => {
resolve(false);
}, 10000); // 10秒超时
img.onload = () => {
clearTimeout(timeout);
resolve(true);
};
img.onerror = () => {
clearTimeout(timeout);
resolve(false);
};
img.src = imageUrl;
});
if (success) {
return true;
} else if (attempt < maxRetries) {
await new Promise(resolve => setTimeout(resolve, 3000 * attempt));
}
} catch (error) {
if (attempt < maxRetries) {
await new Promise(resolve => setTimeout(resolve, 3000 * attempt));
}
}
}
return false;
};
const refreshTasks = async () => {
try {
const params = new URLSearchParams({
page: currentPage.value.toString(),
page_size: pageSize.value.toString()
});
if (statusFilter.value !== 'ALL') {
params.append('status', statusFilter.value);
}
const response = await apiRequest(`./api/v1/task/list?${params.toString()}`);
if (response && response.ok) {
const data = await response.json();
tasks.value = data.tasks || [];
pagination.value = data.pagination || null;
if (!thumbnailCacheLoaded.value) {
await preloadThumbnailCache(tasks.value);
} else {
await preloadInputImages(tasks.value);
}
} else if (response) {
showAlert('刷新任务列表失败', 'danger');
}
// 如果response为null,说明是认证错误,apiRequest已经处理了
} catch (error) {
showAlert(`刷新任务列表失败: ${error.message}`, 'danger');
}
};
const getStatusBadgeClass = (status) => {
const statusMap = {
'SUCCEED': 'bg-success',
'FAILED': 'bg-danger',
'RUNNING': 'bg-warning',
'PENDING': 'bg-secondary',
'CREATED': 'bg-secondary'
};
return statusMap[status] || 'bg-secondary';
};
const downloadSingleResult = async (taskId, key, outputPath) => {
try {
setLoading(true);
const response = await apiCall(`./api/v1/task/result?task_id=${taskId}&name=${key}`);
if (response.ok) {
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
// 使用原始文件名,如果没有则使用outputPath
const filename = key || outputPath || `result_${taskId}`;
a.download = filename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
showAlert('文件下载成功', 'success');
} else {
showAlert('获取结果失败', 'danger');
}
} catch (error) {
showAlert(`下载结果失败: ${error.message}`, 'danger');
} finally {
setLoading(false);
}
};
const viewSingleResult = async (taskId, key) => {
try {
setLoading(true);
const response = await apiCall(`./api/v1/task/result?task_id=${taskId}&name=${key}`);
if (response.ok) {
const blob = await response.blob();
const videoBlob = new Blob([blob], { type: 'video/mp4' });
const url = window.URL.createObjectURL(videoBlob);
window.open(url, '_blank');
} else {
showAlert('获取结果失败', 'danger');
}
} catch (error) {
showAlert(`查看结果失败: ${error.message}`, 'danger');
} finally {
setLoading(false);
}
};
const getVideoUrl = (taskId, key) => {
const token = localStorage.getItem('accessToken');
if (token) {
return `./api/v1/task/result?task_id=${taskId}&name=${key}&token=${encodeURIComponent(token)}`;
}
return `./api/v1/task/result?task_id=${taskId}&name=${key}`;
};
const cancelTask = async (taskId) => {
try {
const response = await apiRequest(`./api/v1/task/cancel?task_id=${taskId}`);
if (response && response.ok) {
showAlert('任务取消成功', 'success');
await refreshTasks();
} else if (response) {
const error = await response.json();
showAlert(`取消任务失败: ${error.message}`, 'danger');
}
// 如果response为null,说明是认证错误,apiRequest已经处理了
} catch (error) {
showAlert(`取消任务失败: ${error.message}`, 'danger');
}
};
const resumeTask = async (taskId) => {
try {
const response = await apiRequest(`./api/v1/task/resume?task_id=${taskId}`);
if (response && response.ok) {
showAlert('任务重试成功', 'success');
await refreshTasks();
} else if (response) {
const error = await response.json();
showAlert(`重试任务失败: ${error.message}`, 'danger');
}
// 如果response为null,说明是认证错误,apiRequest已经处理了
} catch (error) {
showAlert(`重试任务失败: ${error.message}`, 'danger');
}
};
const loadTaskFiles = async (taskId) => {
try {
loadingTaskFiles.value = true;
// 通过API获取任务详情
const response = await apiRequest(`./api/v1/task/query?task_id=${taskId}`);
if (!response || !response.ok) {
showAlert('获取任务详情失败', 'danger');
return;
}
const task = await response.json();
if (!task) {
showAlert('任务不存在', 'danger');
return;
}
const files = { inputs: {}, outputs: {} };
// 获取输入文件(所有状态的任务都需要)
if (task.inputs) {
for (const [key, inputPath] of Object.entries(task.inputs)) {
try {
const response = await apiRequest(`./api/v1/task/input?task_id=${taskId}&name=${key}`);
if (response && response.ok) {
const blob = await response.blob();
files.inputs[key] = {
name: inputPath, // 使用原始文件名而不是key
path: inputPath,
blob: blob,
url: URL.createObjectURL(blob)
};
}
} catch (error) {
console.error(`Failed to load input ${key}:`, error);
files.inputs[key] = {
name: inputPath, // 使用原始文件名而不是key
path: inputPath,
error: true
};
}
}
}
// 只对成功完成的任务获取输出文件
if (task.status === 'SUCCEED' && task.outputs) {
for (const [key, outputPath] of Object.entries(task.outputs)) {
try {
const response = await apiRequest(`./api/v1/task/result?task_id=${taskId}&name=${key}`);
if (response && response.ok) {
const blob = await response.blob();
files.outputs[key] = {
name: outputPath, // 使用原始文件名而不是key
path: outputPath,
blob: blob,
url: URL.createObjectURL(blob)
};
}
} catch (error) {
console.error(`Failed to load output ${key}:`, error);
files.outputs[key] = {
name: outputPath, // 使用原始文件名而不是key
path: outputPath,
error: true
};
}
}
}
selectedTaskFiles.value = files;
} catch (error) {
console.error('Failed to load task files:', error);
showAlert('加载任务文件失败', 'danger');
} finally {
loadingTaskFiles.value = false;
}
};
const viewTaskDetail = async (task) => {
// 清理之前的文件缓存
clearTaskFiles();
selectedTask.value = task;
selectedTaskId.value = task.task_type;
showCreator.value = false;
// 一次性加载所有任务文件
await loadTaskFiles(task.task_id);
};
const downloadFile = (fileInfo) => {
if (!fileInfo || !fileInfo.blob) {
showAlert('文件不可用', 'danger');
return;
}
try {
const url = URL.createObjectURL(fileInfo.blob);
const a = document.createElement('a');
a.href = url;
a.download = fileInfo.name || 'download';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
} catch (error) {
console.error('Download failed:', error);
showAlert('下载失败', 'danger');
}
};
const viewFile = (fileInfo) => {
if (!fileInfo || !fileInfo.url) {
showAlert('文件不可用', 'danger');
return;
}
// 在新窗口中打开文件
window.open(fileInfo.url, '_blank');
};
const clearTaskFiles = () => {
// 清理 URL 对象,释放内存
Object.values(selectedTaskFiles.value.inputs).forEach(file => {
if (file.url) {
URL.revokeObjectURL(file.url);
}
});
Object.values(selectedTaskFiles.value.outputs).forEach(file => {
if (file.url) {
URL.revokeObjectURL(file.url);
}
});
selectedTaskFiles.value = { inputs: {}, outputs: {} };
};
const showTaskCreator = () => {
showCreator.value = true;
selectedTask.value = null;
// clearTaskFiles(); // 清空文件缓存
selectedTaskId.value = 'digital_human'; // 默认选择数字人任务
};
const toggleSidebar = () => {
sidebarCollapsed.value = !sidebarCollapsed.value;
};
const clearPrompt = () => {
getCurrentForm().prompt = '';
};
const getTaskItemClass = (status) => {
if (status === 'SUCCEED') return 'bg-laser-purple/15 border border-laser-purple/30';
if (status === 'RUNNING') return 'bg-laser-purple/15 border border-laser-purple/30';
if (status === 'FAILED') return 'bg-red-500/15 border border-red-500/30';
return 'bg-dark-light border border-gray-700';
};
const getStatusIndicatorClass = (status) => {
const base = 'inline-block w-2 aspect-square rounded-full shrink-0 align-middle';
if (status === 'SUCCEED')
return `${base} bg-gradient-to-r from-emerald-200 to-green-300 shadow-md shadow-emerald-300/30`;
if (status === 'RUNNING')
return `${base} bg-gradient-to-r from-amber-200 to-yellow-300 shadow-md shadow-amber-300/30 animate-pulse`;
if (status === 'FAILED')
return `${base} bg-gradient-to-r from-red-200 to-pink-300 shadow-md shadow-red-300/30`;
return `${base} bg-gradient-to-r from-gray-200 to-gray-300 shadow-md shadow-gray-300/30`;
};
const getTaskTypeBtnClass = (taskType) => {
if (selectedTaskId.value === taskType) {
return 'text-gradient-icon border-b-2 border-laser-purple';
}
return 'text-gray-400 hover:text-gradient-icon';
};
const getModelBtnClass = (model) => {
if (getCurrentForm().model_cls === model) {
return 'bg-laser-purple/20 border border-laser-purple/40 active shadow-laser animate-electric-pulse';
}
return 'bg-dark-light border border-gray-700 hover:bg-laser-purple/15 hover:border-laser-purple/40 transition-all hover:shadow-laser';
};
const getTaskTypeIcon = (taskType) => {
const iconMap = {
't2v': 'fas fa-font',
'i2v': 'fas fa-image',
'digital_human': 'fas fa-user'
};
return iconMap[taskType] || 'fas fa-video';
};
const getTaskTypeName = (task) => {
// 如果传入的是字符串,直接返回映射
if (typeof task === 'string') {
const nameMap = {
't2v': '文生视频',
'i2v': '图生视频',
'digital_human': '数字人'
};
return nameMap[task] || task;
}
// 如果传入的是任务对象,根据模型类型判断
if (task && task.model_cls) {
const modelCls = task.model_cls.toLowerCase();
// 检查是否是数字人模型(包含audio或seko)
if (modelCls.includes('audio') || modelCls.includes('seko')) {
return '数字人';
}
// 根据task_type判断
const nameMap = {
't2v': '文生视频',
'i2v': '图生视频',
'digital_human': '数字人'
};
return nameMap[task.task_type] || task.task_type;
}
// 默认返回task_type
return task.task_type || '未知';
};
const getPromptPlaceholder = () => {
if (selectedTaskId.value === 't2v') {
return '请输入视频生成提示词,描述视频内容、风格、场景等...';
} else if (selectedTaskId.value === 'i2v') {
return '请输入视频生成提示词,描述基于图片的视频内容、动作要求等...';
} else if (selectedTaskId.value === 'digital_human') {
return '请输入视频生成提示词,描述数字人形象、背景风格、动作要求等...';
}
return '请输入视频生成提示词...';
};
const getStatusTextClass = (status) => {
if (status === 'SUCCEED') return 'text-emerald-400';
if (status === 'RUNNING') return 'text-amber-400';
if (status === 'FAILED') return 'text-red-400';
return 'text-gray-400';
};
const getImagePreview = (base64Data) => {
if (!base64Data) return '';
return `data:image/jpeg;base64,${base64Data}`;
};
const getTaskInputUrl = (taskId, key) => {
const token = localStorage.getItem('accessToken');
if (token) {
return `./api/v1/task/input?task_id=${taskId}&name=${key}&token=${encodeURIComponent(token)}`;
}
return `./api/v1/task/input?task_id=${taskId}&name=${key}`;
};
const getTaskInputImage = (task) => {
if (!task || !task.inputs) return null;
const imageInputs = Object.keys(task.inputs).filter(key =>
key.includes('image') ||
task.inputs[key].toString().toLowerCase().match(/\.(jpg|jpeg|png|gif|bmp|webp)$/)
);
if (imageInputs.length > 0) {
const firstImageKey = imageInputs[0];
return getTaskInputUrl(task.task_id, firstImageKey);
}
return null;
};
const getVideoThumbnail = (taskId, name) => {
if (thumbnailCache.value.has(taskId)) {
return thumbnailCache.value.get(taskId);
}
const task = tasks.value.find(t => t.task_id === taskId);
if (task) {
const inputImageUrl = getTaskInputImage(task);
if (inputImageUrl) {
return inputImageUrl;
} else {
console.log(`任务 ${taskId} 没有输入图片`);
}
} else {
console.log(`未找到任务: ${taskId}`);
}
// 如果没有输入图片,返回空字符串,让handleThumbnailError处理
return '';
};
const getVideoThumbnailInfo = (taskId, name) => {
// 首先检查缓存
if (thumbnailCache.value.has(taskId)) {
return {
url: thumbnailCache.value.get(taskId),
hasThumbnail: true
};
}
// 如果缓存中没有,异步加载并更新缓存
loadThumbnailAsync(taskId, name);
// 返回空,让模板显示默认图标
return {
url: '',
hasThumbnail: false
};
};
const loadThumbnailAsync = async (taskId, name) => {
const task = tasks.value.find(t => t.task_id === taskId);
if (!task || !task.inputs) return;
// 查找输入中的图片文件
const imageInputs = Object.keys(task.inputs).filter(key =>
key.includes('image') ||
task.inputs[key].toString().toLowerCase().match(/\.(jpg|jpeg|png|gif|bmp|webp)$/)
);
if (imageInputs.length > 0) {
const firstImageKey = imageInputs[0];
try {
const imageUrl = getTaskInputUrl(taskId, firstImageKey);
// 使用重试机制加载图片
const success = await loadImageWithRetry(taskId, imageUrl, 3);
if (success) {
thumbnailCache.value.set(taskId, imageUrl);
}
} catch (error) {
console.warn(`缩略图异步加载错误 ${taskId}:`, error);
}
}
};
const handleThumbnailError = (event) => {
// 当输入图片加载失败时,显示默认图标
const img = event.target;
const parent = img.parentElement;
parent.innerHTML = '<div class="w-full h-full bg-laser-purple/20 flex items-center justify-center"><i class="fas fa-video text-gradient-icon text-xl"></i></div>';
};
const handleImageError = (event) => {
// 当图片加载失败时,隐藏图片,显示文件名
const img = event.target;
img.style.display = 'none';
// 文件名已经显示,不需要额外处理
};
const handleImageLoad = (event) => {
// 当图片加载成功时,显示图片和下载按钮,隐藏文件名
const img = event.target;
img.style.display = 'block';
// 显示下载按钮
const downloadBtn = img.parentElement.querySelector('button');
if (downloadBtn) {
downloadBtn.style.display = 'block';
}
// 隐藏文件名span
const span = img.parentElement.parentElement.querySelector('span');
if (span) {
span.style.display = 'none';
}
};
const handleAudioError = (event) => {
// 当音频加载失败时,隐藏音频控件和下载按钮,显示文件名
const audio = event.target;
audio.style.display = 'none';
// 隐藏下载按钮
const downloadBtn = audio.parentElement.querySelector('button');
if (downloadBtn) {
downloadBtn.style.display = 'none';
}
// 文件名已经显示,不需要额外处理
};
const handleAudioLoad = (event) => {
// 当音频加载成功时,显示音频控件和下载按钮,隐藏文件名
const audio = event.target;
audio.style.display = 'block';
// 显示下载按钮
const downloadBtn = audio.parentElement.querySelector('button');
if (downloadBtn) {
downloadBtn.style.display = 'block';
}
// 隐藏文件名span
const span = audio.parentElement.parentElement.querySelector('span');
if (span) {
span.style.display = 'none';
}
};
const downloadTaskInput = async (taskId, inputName, fileName) => {
try {
const url = getTaskInputUrl(taskId, inputName);
const response = await apiRequest(url);
if (!response || !response.ok) {
throw new Error(`下载失败: ${response ? response.status : '认证失败'}`);
}
const blob = await response.blob();
const downloadUrl = window.URL.createObjectURL(blob);
// 创建下载链接
const link = document.createElement('a');
link.href = downloadUrl;
// 使用原始文件名,如果没有则使用inputName
const filename = inputName || fileName || `input_${taskId}`;
link.download = filename;
document.body.appendChild(link);
link.click();
// 清理
document.body.removeChild(link);
window.URL.revokeObjectURL(downloadUrl);
showAlert('文件下载成功', 'success');
} catch (error) {
console.error('下载失败:', error);
showAlert(`下载失败: ${error.message}`, 'danger');
}
};
const initModelAndTasks = async () => {
await loadModels();
await refreshTasks();
};
// 任务状态管理
const getTaskStatusDisplay = (status) => {
const statusMap = {
'CREATED': '创建',
'PENDING': '等待',
'RUNNING': '进行',
'SUCCEED': '完成',
'FAILED': '失败',
'CANCEL': '取消'
};
return statusMap[status] || status;
};
const getTaskStatusColor = (status) => {
const colorMap = {
'CREATED': 'text-blue-400',
'PENDING': 'text-yellow-400',
'RUNNING': 'text-amber-400',
'SUCCEED': 'text-emerald-400',
'FAILED': 'text-red-400',
'CANCEL': 'text-gray-400'
};
return colorMap[status] || 'text-gray-400';
};
const getTaskStatusIcon = (status) => {
const iconMap = {
'CREATED': 'fas fa-clock',
'PENDING': 'fas fa-hourglass-half',
'RUNNING': 'fas fa-spinner fa-spin',
'SUCCEED': 'fas fa-check-circle',
'FAILED': 'fas fa-exclamation-triangle',
'CANCEL': 'fas fa-ban'
};
return iconMap[status] || 'fas fa-question-circle';
};
// 任务时间格式化
const getTaskDuration = (startTime, endTime) => {
if (!startTime || !endTime) return '未知';
const start = new Date(startTime * 1000);
const end = new Date(endTime * 1000);
const diff = end - start;
const minutes = Math.floor(diff / 60000);
const seconds = Math.floor((diff % 60000) / 1000);
return `${minutes}${seconds}秒`;
};
// 相对时间格式化
const getRelativeTime = (timestamp) => {
if (!timestamp) return '未知';
const now = new Date();
const time = new Date(timestamp * 1000);
const diff = now - time;
const minutes = Math.floor(diff / 60000);
const hours = Math.floor(diff / 3600000);
const days = Math.floor(diff / 86400000);
const months = Math.floor(diff / 2592000000); // 30天
const years = Math.floor(diff / 31536000000);
if (years > 0) {
return years === 1 ? '一年前' : `${years}年前`;
} else if (months > 0) {
return months === 1 ? '一个月前' : `${months}个月前`;
} else if (days > 0) {
return days === 1 ? '一天前' : `${days}天前`;
} else if (hours > 0) {
return hours === 1 ? '一小时前' : `${hours}小时前`;
} else if (minutes > 0) {
return minutes === 1 ? '一分钟前' : `${minutes}分钟前`;
} else {
return '刚刚';
}
};
// 任务历史记录管理
const getTaskHistory = () => {
return tasks.value.filter(task =>
['SUCCEED', 'FAILED', 'CANCEL'].includes(task.status)
);
};
const getActiveTasks = () => {
return tasks.value.filter(task =>
['CREATED', 'PENDING', 'RUNNING'].includes(task.status)
);
};
// 任务搜索和过滤增强
const searchTasks = (query) => {
if (!query) return tasks.value;
return tasks.value.filter(task => {
const searchText = [
task.task_id,
task.task_type,
task.model_cls,
task.params?.prompt || '',
getTaskStatusDisplay(task.status)
].join(' ').toLowerCase();
return searchText.includes(query.toLowerCase());
});
};
const filterTasksByStatus = (status) => {
if (status === 'ALL') return tasks.value;
return tasks.value.filter(task => task.status === status);
};
const filterTasksByType = (type) => {
if (!type) return tasks.value;
return tasks.value.filter(task => task.task_type === type);
};
// 提示消息样式管理
const getAlertClass = (type) => {
const classMap = {
'success': 'animate-slide-down',
'warning': 'animate-slide-down',
'danger': 'animate-slide-down',
'info': 'animate-slide-down'
};
return classMap[type] || 'animate-slide-down';
};
const getAlertBorderClass = (type) => {
const borderMap = {
'success': 'border-green-500',
'warning': 'border-yellow-500',
'danger': 'border-red-500',
'info': 'border-blue-500'
};
return borderMap[type] || 'border-gray-500';
};
const getAlertTextClass = (type) => {
// 字体为灰色偏白色
const textMap = {
'success': 'text-gray-100 bg-white-500',
'warning': 'text-gray-100 bg-white-500',
'danger': 'text-gray-100 bg-white-500',
'info': 'text-gray-100 bg-white-500'
};
return textMap[type] || 'text-gray-100 bg-white-500';
};
const getAlertIcon = (type) => {
const iconMap = {
'success': 'fas fa-check-circle text-green-400',
'warning': 'fas fa-exclamation-triangle text-yellow-400',
'danger': 'fas fa-times-circle text-red-400',
'info': 'fas fa-info-circle text-blue-400'
};
return iconMap[type] || 'fas fa-info-circle text-gray-400';
};
// 监听器 - 监听任务类型变化
watch(() => selectedTaskId.value, () => {
const currentForm = getCurrentForm();
// 只有当当前表单没有选择模型时,才自动选择第一个可用的模型
if (!currentForm.model_cls) {
let availableModels;
// 如果是数字人任务,从i2v模型中筛选包含audio或seko的模型
if (selectedTaskId.value === 'digital_human') {
availableModels = models.value.filter(m =>
m.task === 'i2v' && (m.model_cls.toLowerCase().includes('audio') || m.model_cls.toLowerCase().includes('seko'))
);
} else if (selectedTaskId.value === 'i2v') {
// 如果是i2v任务,排除包含audio或seko的模型
availableModels = models.value.filter(m =>
m.task === 'i2v' && !m.model_cls.toLowerCase().includes('audio') && !m.model_cls.toLowerCase().includes('seko')
);
} else {
availableModels = models.value.filter(m => m.task === selectedTaskId.value);
}
if (availableModels.length > 0) {
const firstModel = availableModels[0];
currentForm.model_cls = firstModel.model_cls;
currentForm.stage = firstModel.stage;
}
}
// 注意:这里不需要重置预览,因为我们要保持每个任务的独立性
// 预览会在 selectTask 函数中根据文件状态恢复
});
watch(() => getCurrentForm().model_cls, () => {
const currentForm = getCurrentForm();
// 只有当当前表单没有选择阶段时,才自动选择第一个可用的阶段
if (!currentForm.stage) {
let availableStages;
// 如果是数字人任务,从i2v模型中筛选
if (selectedTaskId.value === 'digital_human') {
availableStages = models.value
.filter(m => m.task === 'i2v' && m.model_cls === currentForm.model_cls)
.map(m => m.stage);
} else if (selectedTaskId.value === 'i2v') {
// 如果是i2v任务,排除包含audio或seko的模型
availableStages = models.value
.filter(m => m.task === 'i2v' && m.model_cls === currentForm.model_cls && !m.model_cls.toLowerCase().includes('audio') && !m.model_cls.toLowerCase().includes('seko'))
.map(m => m.stage);
} else {
availableStages = models.value
.filter(m => m.task === selectedTaskId.value && m.model_cls === currentForm.model_cls)
.map(m => m.stage);
}
if (availableStages.length > 0) {
currentForm.stage = availableStages[0];
}
}
});
// 生命周期
onMounted(async () => {
// 检查是否已登录
const savedToken = localStorage.getItem('accessToken');
const savedUser = localStorage.getItem('currentUser');
if (savedToken && savedUser) {
// 验证token是否仍然有效
const isValid = await validateToken(savedToken);
if (isValid) {
currentUser.value = JSON.parse(savedUser);
isLoggedIn.value = true;
} else {
// Token无效,清除本地存储
logout();
showAlert('登录已过期,请重新登录', 'warning');
}
} else {
// 检查是否是GitHub回调
const urlParams = new URLSearchParams(window.location.search);
const code = urlParams.get('code');
if (code) {
handleGitHubCallback(code);
}
}
// 无论是否登录都要加载模型数据
await initModelAndTasks();
loadPromptHistory();
loadTemplates();
// 等待模型数据加载完成后再检查任务类型
if (availableTaskTypes.value.includes('digital_human')) {
selectTask('digital_human');
}
console.log('当前用户:', currentUser.value);
console.log('可用模型:', models.value);
console.log('任务列表:', tasks.value);
});
// 提示词模板管理
const promptTemplates = {
'digital_human': [
{
id: 'dh_1',
title: '商务演讲',
prompt: '数字人进行商务演讲,表情自然,手势得体,背景为现代化的会议室,整体风格专业商务。'
},
{
id: 'dh_2',
title: '产品介绍',
prompt: '数字人介绍产品特点,语气亲切,动作自然,背景为产品展示区,突出产品的科技感和实用性。'
}
],
't2v': [
{
id: 't2v_1',
title: '自然风景',
prompt: '一个宁静的山谷,阳光透过云层洒在绿色的草地上,远处有雪山,近处有清澈的溪流,画面温暖自然,充满生机。'
},
{
id: 't2v_2',
title: '城市夜景',
prompt: '繁华的城市夜景,霓虹灯闪烁,高楼大厦林立,车流如织,天空中有星星点缀,营造出都市的繁华氛围。'
},
{
id: 't2v_3',
title: '科技未来',
prompt: '未来科技城市,飞行汽车穿梭,全息投影随处可见,建筑具有流线型设计,充满科技感和未来感。'
}
],
'i2v': [
{
id: 'i2v_1',
title: '人物动作',
prompt: '基于参考图片,让角色做出自然的行走动作,保持原有的服装和风格,背景可以适当变化。'
},
{
id: 'i2v_2',
title: '场景转换',
prompt: '保持参考图片中的人物形象,将背景转换为不同的季节或环境,如从室内到户外,从白天到夜晚。'
}
]
};
const getPromptTemplates = (taskType) => {
return promptTemplates[taskType] || [];
};
const showPromptTemplates = () => {
if (!selectedTaskId.value) {
showAlert('请先选择任务类型', 'warning');
return;
}
showTemplates.value = !showTemplates.value;
};
const showPromptHistory = () => {
showHistory.value = !showHistory.value;
};
const selectPromptTemplate = (template) => {
getCurrentForm().prompt = template.prompt;
showTemplates.value = false;
showAlert(`已应用模板: ${template.title}`, 'success');
};
// 提示词历史记录管理
const promptHistory = ref([]);
const getPromptHistory = () => {
return promptHistory.value.slice(-10); // 只显示最近10条
};
const addPromptToHistory = (prompt) => {
if (!prompt || prompt.trim().length === 0) return;
// 避免重复添加
const trimmedPrompt = prompt.trim();
if (promptHistory.value.includes(trimmedPrompt)) {
// 将已存在的提示词移到最前面
promptHistory.value = promptHistory.value.filter(p => p !== trimmedPrompt);
}
promptHistory.value.push(trimmedPrompt);
// 限制历史记录数量
if (promptHistory.value.length > 50) {
promptHistory.value = promptHistory.value.slice(-50);
}
// 保存到本地存储
localStorage.setItem('promptHistory', JSON.stringify(promptHistory.value));
};
const selectPromptHistory = (prompt) => {
getCurrentForm().prompt = prompt;
showHistory.value = false;
showAlert('已应用历史提示词', 'success');
};
const clearPromptHistory = () => {
promptHistory.value = [];
localStorage.removeItem('promptHistory');
showAlert('提示词历史已清空', 'info');
};
// 加载提示词历史记录
const loadPromptHistory = () => {
try {
const saved = localStorage.getItem('promptHistory');
if (saved) {
promptHistory.value = JSON.parse(saved);
}
} catch (error) {
console.warn('加载提示词历史记录失败:', error);
}
};
const getAuthHeaders = () => {
const headers = {
'Content-Type': 'application/json'
};
const token = localStorage.getItem('accessToken');
if (token) {
headers['Authorization'] = `Bearer ${token}`;
console.log('使用Token进行认证:', token.substring(0, 20) + '...');
} else {
console.warn('没有找到accessToken');
}
return headers;
};
// 验证token是否有效
const validateToken = async (token) => {
try {
const response = await fetch('./api/v1/model/list', {
method: 'GET',
headers: {
'Authorization': `Bearer ${token}`,
'Content-Type': 'application/json'
}
});
return response.ok;
} catch (error) {
console.error('Token validation failed:', error);
return false;
}
};
// 增强的API请求函数,自动处理认证错误
const apiRequest = async (url, options = {}) => {
const headers = getAuthHeaders();
try {
const response = await fetch(url, {
...options,
headers: {
...headers,
...options.headers
}
});
// 检查是否是认证错误
if (response.status === 401 || response.status === 403) {
// Token无效,清除本地存储并跳转到登录页
logout();
showAlert('登录已过期,请重新登录', 'warning');
return null;
}
return response;
} catch (error) {
console.error('API request failed:', error);
showAlert('网络请求失败', 'danger');
return null;
}
};
// 侧边栏拖拽调整功能
const sidebar = ref(null);
let isResizing = false;
let startX = 0;
let startWidth = 0;
const startResize = (e) => {
e.preventDefault();
// 在小屏幕时禁用拖拽调整
const windowWidth = window.innerWidth;
if (windowWidth <= 1200) {
return;
}
isResizing = true;
startX = e.clientX;
startWidth = sidebar.value.offsetWidth;
document.body.classList.add('resizing');
document.addEventListener('mousemove', handleResize);
document.addEventListener('mouseup', stopResize);
};
const handleResize = (e) => {
if (!isResizing) return;
// 在小屏幕时停止拖拽调整
const windowWidth = window.innerWidth;
if (windowWidth <= 1200) {
stopResize();
return;
}
const deltaX = e.clientX - startX;
const newWidth = startWidth + deltaX;
const minWidth = 200;
const maxWidth = 500;
if (newWidth >= minWidth && newWidth <= maxWidth) {
sidebar.value.style.width = newWidth + 'px';
// 同时调整主内容区域宽度
const mainContent = document.querySelector('main');
if (mainContent) {
mainContent.style.width = `calc(100% - ${newWidth}px)`;
}
}
};
const stopResize = () => {
isResizing = false;
document.body.classList.remove('resizing');
document.removeEventListener('mousemove', handleResize);
document.removeEventListener('mouseup', stopResize);
// 保存当前宽度到localStorage
if (sidebar.value) {
localStorage.setItem('sidebarWidth', sidebar.value.offsetWidth);
}
};
// 应用响应式侧边栏宽度
const applyResponsiveWidth = () => {
if (!sidebar.value) return;
const windowWidth = window.innerWidth;
let sidebarWidth;
if (windowWidth <= 768) {
sidebarWidth = '200px';
} else if (windowWidth <= 1200) {
sidebarWidth = '250px';
} else {
// 大屏幕时使用保存的宽度或默认宽度
const savedWidth = localStorage.getItem('sidebarWidth');
if (savedWidth) {
const width = parseInt(savedWidth);
if (width >= 200 && width <= 500) {
sidebarWidth = width + 'px';
} else {
sidebarWidth = '256px'; // 默认 w-64
}
} else {
sidebarWidth = '256px'; // 默认 w-64
}
}
sidebar.value.style.width = sidebarWidth;
const mainContent = document.querySelector('main');
if (mainContent) {
mainContent.style.width = `calc(100% - ${sidebarWidth})`;
}
};
// 恢复保存的侧边栏宽度
onMounted(() => {
applyResponsiveWidth();
// 监听窗口大小变化
window.addEventListener('resize', applyResponsiveWidth);
});
return {
isLoggedIn,
loading,
loginWithGitHub,
submitting,
showCreator,
searchQuery,
currentUser,
models,
tasks,
alert,
t2vForm,
i2vForm,
digitalHumanForm,
getCurrentForm,
i2vImagePreview,
digitalHumanImagePreview,
digitalHumanAudioPreview,
getCurrentImagePreview,
getCurrentAudioPreview,
availableTaskTypes,
availableModelClasses,
filteredTasks,
selectedTaskId,
selectedTask,
selectedTaskFiles,
loadingTaskFiles,
statusFilter,
pagination,
currentPage,
pageSize,
showAlert,
setLoading,
apiCall,
logout,
loadModels,
generatingThumbnails,
sidebarCollapsed,
thumbnailCache,
thumbnailCacheLoaded,
loadTaskFiles,
downloadFile,
viewFile,
handleImageUpload,
selectTask,
selectModel,
triggerImageUpload,
triggerAudioUpload,
removeImage,
removeAudio,
handleAudioUpload,
loadTemplates,
selectImageTemplate,
selectAudioTemplate,
previewAudioTemplate,
imageTemplates,
audioTemplates,
showImageTemplates,
showAudioTemplates,
showTemplates,
showHistory,
submitTask,
fileToBase64,
formatTime,
refreshTasks,
preloadThumbnailCache,
preloadThumbnailsIndividually,
loadImageWithRetry,
getStatusBadgeClass,
downloadSingleResult,
viewSingleResult,
cancelTask,
resumeTask,
viewTaskDetail,
showTaskCreator,
toggleSidebar,
clearPrompt,
getTaskItemClass,
getStatusIndicatorClass,
getTaskTypeBtnClass,
getModelBtnClass,
getTaskTypeIcon,
getTaskTypeName,
getPromptPlaceholder,
getStatusTextClass,
getImagePreview,
getTaskInputUrl,
getVideoThumbnail,
getVideoThumbnailInfo,
loadThumbnailAsync,
handleThumbnailError,
handleImageError,
handleImageLoad,
handleAudioError,
handleAudioLoad,
downloadTaskInput,
getVideoUrl,
getTaskStatusDisplay,
getTaskStatusColor,
getTaskStatusIcon,
getTaskDuration,
getRelativeTime,
getTaskHistory,
getActiveTasks,
searchTasks,
filterTasksByStatus,
filterTasksByType,
getAlertClass,
getAlertBorderClass,
getAlertTextClass,
getAlertIcon,
getPromptTemplates,
showPromptTemplates,
showPromptHistory,
selectPromptTemplate,
promptHistory,
getPromptHistory,
addPromptToHistory,
selectPromptHistory,
clearPromptHistory,
loadPromptHistory,
getAudioMimeType,
getAuthHeaders,
sidebar,
startResize
};
}
}).mount('#app');
</script>
</body>
</html>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment