Commit ab1b2790 authored by LiangLiu, committed by GitHub

Deploy server and worker (#284)



* Init deploy: not ok

* Test data_manager & task_manager

* pipeline is not needed for worker

* Update worker text_encoder

* deploy: submit task

* add apis

* Test pipelineRunner

* Fix pipeline

* Tidy worker & test PipelineWorker ok

* Tidy code

* Fix multi_stage for wan2.1 t2v & i2v

* API: query task, get result & report subtasks as failed when workers stop

* Add model list functionality to Pipeline and API

* Add task cancel and task resume to API

* Add RabbitMQ queue manager

* update local task manager atomic

* support postgreSQL task manager, add lifespan async init

* worker print -> logger

* Add S3 data manager, delete temp objects after finished.

* fix worker

* release fetch queue msg when closed, run stuck worker in another thread, stop worker when process down.

* DiTWorker run with thread & tidy logger print

* Init monitor without test

* fix monitor

* GitHub OAuth and JWT token access & static demo HTML page

* Add user to task, ok for local task manager & update demo ui

* sql task manager support users

* task list with pages

* merge main fix

* Add proxy for auth request

* support wan audio

* worker ping subtask and ping life, fix rabbitmq async get

* s3 data manager with async api & tidy monitor config

* fix merge main & update req.txt & fix html view video error

* Fix distributed worker

* Limit user visit freq

* Tidy

* Fix only rank save

* Fix audio input

* Fix worker fetch None

* index.html abs path to rel path

* Fix dist worker stuck

* support publishing output video to rtmp & gracefully stop running dit step or segment step

* Add VAReader

* Enhance VAReader with torch dist

* Fix audio stream input

* fix merge of refactored main, support stream input_audio and output_video

* fix audio read with prev frames & fix end take frames & tidy worker end

* split audio model to 4 workers & fix audio end frame

* fix ping subtask with queue

* Fix audio worker put block & add whep, whip without test ok

* Tidy va_recorder & va_reader logs, thread cancel within 30s

* Fix dist worker stuck: broadcast stop signal

* Tidy

* record task active_elapse & subtask status_elapse

* Design prometheus metrics

* Tidy prometheus metrics

* Fix merge main

* send sigint to ffmpeg process

* Fix gstreamer pull audio by whep & Dockerfile for gstreamer & check params when submitting

* Fix merge main

* Query task with more info & va_reader buffer size = 1

* Fix va_recorder

* Add config for prev_frames

* update frontend

* update frontend

* update frontend

* update frontend & merge

* update frontend & partial backend

* Different rank for va_recorder and va_reader

* Fix mem leak: only one rank publish video, other rank should pop gen vids

* fix task category

* va_reader pre-alloc tensor & va_recorder send frames all & fix dist cancel infer

* Fix prev_frame_length

* Tidy

* Tidy

* update frontend & backend

* Fix lint error

* recover some files

* Tidy

* lint code

---------
Co-authored-by: liuliang1 <liuliang1@sensetime.com>
Co-authored-by: unknown <qinxinyi@sensetime.com>
parent acacd26f
# For RTC WHEP, build GStreamer with the whepsrc plugin
FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90 AS gstreamer-base
RUN apt update -y \
&& apt install -y libssl-dev flex bison \
libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
libnice-dev libopus-dev libvpx-dev libx264-dev \
libsrtp2-dev libglib2.0-dev libdrm-dev
RUN cd /opt \
&& wget https://mirrors.tuna.tsinghua.edu.cn/gnu//libiconv/libiconv-1.15.tar.gz \
&& tar zxvf libiconv-1.15.tar.gz \
&& cd libiconv-1.15 \
&& ./configure \
&& make \
&& make install
RUN pip install meson
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
ENV PATH=/root/.cargo/bin:$PATH
RUN cd /opt \
&& git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \
&& cd gstreamer \
&& meson setup builddir \
&& meson compile -C builddir \
&& meson install -C builddir \
&& ldconfig
RUN cd /opt \
&& git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \
&& cd gst-plugins-rs \
&& cargo build --package gst-plugin-webrtchttp --release \
&& install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/
# Lightx2v deploy image
FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90
RUN mkdir /workspace/lightx2v
WORKDIR /workspace/lightx2v
ENV PYTHONPATH=/workspace/lightx2v
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
RUN conda install conda-forge::ffmpeg=8.0.0 -y
RUN rm /usr/bin/ffmpeg && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg
RUN apt update -y \
&& apt install -y libssl-dev \
libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
libnice-dev libopus-dev libvpx-dev libx264-dev \
libsrtp2-dev libglib2.0-dev libdrm-dev
ENV LBDIR=/usr/local/lib/x86_64-linux-gnu
COPY --from=gstreamer-base /usr/local/bin/gst-* /usr/local/bin/
COPY --from=gstreamer-base $LBDIR $LBDIR
RUN ldconfig
ENV LD_LIBRARY_PATH=$LBDIR:$LD_LIBRARY_PATH
RUN gst-launch-1.0 --version
RUN gst-inspect-1.0 whepsrc
COPY assets assets
COPY configs configs
COPY lightx2v lightx2v
COPY lightx2v_kernel lightx2v_kernel
{
"data":
{
"t2v": {
"wan2.1-1.3B": {
"single_stage": {
"pipeline": {
"inputs": [],
"outputs": ["output_video"]
}
},
"multi_stage": {
"text_encoder": {
"inputs": [],
"outputs": ["text_encoder_output"]
},
"dit": {
"inputs": ["text_encoder_output"],
"outputs": ["latents"]
},
"vae_decoder": {
"inputs": ["latents"],
"outputs": ["output_video"]
}
}
}
},
"i2v": {
"wan2.1-14B-480P": {
"single_stage": {
"pipeline": {
"inputs": ["input_image"],
"outputs": ["output_video"]
}
},
"multi_stage": {
"text_encoder": {
"inputs": ["input_image"],
"outputs": ["text_encoder_output"]
},
"image_encoder": {
"inputs": ["input_image"],
"outputs": ["clip_encoder_output"]
},
"vae_encoder": {
"inputs": ["input_image"],
"outputs": ["vae_encoder_output"]
},
"dit": {
"inputs": [
"clip_encoder_output",
"vae_encoder_output",
"text_encoder_output"
],
"outputs": ["latents"]
},
"vae_decoder": {
"inputs": ["latents"],
"outputs": ["output_video"]
}
}
},
"SekoTalk-Distill": {
"single_stage": {
"pipeline": {
"inputs": ["input_image", "input_audio"],
"outputs": ["output_video"]
}
},
"multi_stage": {
"text_encoder": {
"inputs": ["input_image"],
"outputs": ["text_encoder_output"]
},
"image_encoder": {
"inputs": ["input_image"],
"outputs": ["clip_encoder_output"]
},
"vae_encoder": {
"inputs": ["input_image"],
"outputs": ["vae_encoder_output"]
},
"segment_dit": {
"inputs": [
"input_audio",
"clip_encoder_output",
"vae_encoder_output",
"text_encoder_output"
],
"outputs": ["output_video"]
}
}
}
}
},
"meta": {
"special_types": {
"input_image": "IMAGE",
"input_audio": "AUDIO",
"latents": "TENSOR",
"output_video": "VIDEO"
},
"monitor": {
"subtask_created_timeout": 1800,
"subtask_pending_timeout": 1800,
"subtask_running_timeouts": {
"t2v-wan2.1-1.3B-multi_stage-dit": 300,
"t2v-wan2.1-1.3B-single_stage-pipeline": 300,
"i2v-wan2.1-14B-480P-multi_stage-dit": 600,
"i2v-wan2.1-14B-480P-single_stage-pipeline": 600,
"i2v-SekoTalk-Distill-single_stage-pipeline": 3600,
"i2v-SekoTalk-Distill-multi_stage-segment_dit": 3600
},
"worker_avg_window": 20,
"worker_offline_timeout": 5,
"worker_min_capacity": 20,
"worker_min_cnt": 1,
"worker_max_cnt": 10,
"task_timeout": 3600,
"schedule_ratio_high": 0.25,
"schedule_ratio_low": 0.02,
"ping_timeout": 30,
"user_max_active_tasks": 3,
"user_max_daily_tasks": 100,
"user_visit_frequency": 0.05
}
}
}
import json
import sys
from loguru import logger
class Pipeline:
def __init__(self, pipeline_json_file):
self.pipeline_json_file = pipeline_json_file
with open(pipeline_json_file) as fin:
    x = json.load(fin)
self.data = x["data"]
self.meta = x["meta"]
self.inputs = {}
self.outputs = {}
self.temps = {}
self.model_lists = []
self.types = {}
self.queues = set()
self.tidy_pipeline()
def init_dict(self, base, task, model_cls):
if task not in base:
base[task] = {}
if model_cls not in base[task]:
base[task][model_cls] = {}
# tidy each task item eg, ['t2v', 'wan2.1', 'multi_stage']
def tidy_task(self, task, model_cls, stage, v3):
out2worker = {}
out2num = {}
cur_inps = set()
cur_temps = set()
cur_types = {}
for worker_name, worker_item in v3.items():
prevs = []
for inp in worker_item["inputs"]:
cur_types[inp] = self.get_type(inp)
if inp in out2worker:
prevs.append(out2worker[inp])
out2num[inp] -= 1
if out2num[inp] <= 0:
cur_temps.add(inp)
else:
cur_inps.add(inp)
worker_item["previous"] = prevs
for out in worker_item["outputs"]:
cur_types[out] = self.get_type(out)
out2worker[out] = worker_name
if out not in out2num:
out2num[out] = 0
out2num[out] += 1
if "queue" not in worker_item:
worker_item["queue"] = "-".join([task, model_cls, stage, worker_name])
self.queues.add(worker_item["queue"])
cur_outs = [out for out, num in out2num.items() if num > 0]
self.inputs[task][model_cls][stage] = list(cur_inps)
self.outputs[task][model_cls][stage] = cur_outs
self.temps[task][model_cls][stage] = list(cur_temps)
self.types[task][model_cls][stage] = cur_types
# tidy previous dependence workers and queue name
def tidy_pipeline(self):
for task, v1 in self.data.items():
for model_cls, v2 in v1.items():
for stage, v3 in v2.items():
self.init_dict(self.inputs, task, model_cls)
self.init_dict(self.outputs, task, model_cls)
self.init_dict(self.temps, task, model_cls)
self.init_dict(self.types, task, model_cls)
self.tidy_task(task, model_cls, stage, v3)
self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage})
logger.info(f"pipelines: {json.dumps(self.data, indent=4)}")
logger.info(f"inputs: {self.inputs}")
logger.info(f"outputs: {self.outputs}")
logger.info(f"temps: {self.temps}")
logger.info(f"types: {self.types}")
logger.info(f"model_lists: {self.model_lists}")
logger.info(f"queues: {self.queues}")
def get_item_by_keys(self, keys):
item = self.data
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in {self.pipeline_json_file}!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder']
def get_worker(self, keys):
return self.get_item_by_keys(keys)
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_workers(self, keys):
return self.get_item_by_keys(keys)
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_inputs(self, keys):
item = self.inputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in inputs!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_outputs(self, keys):
item = self.outputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in outputs!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_temps(self, keys):
item = self.temps
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in temps!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_types(self, keys):
item = self.types
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in types!")
item = item[k]
return item
def get_model_lists(self):
return self.model_lists
def get_type(self, name):
return self.meta["special_types"].get(name, "OBJECT")
def get_monitor_config(self):
return self.meta["monitor"]
def get_queues(self):
return self.queues
if __name__ == "__main__":
pipeline = Pipeline(sys.argv[1])
print(pipeline.get_workers(["t2v", "wan2.1-1.3B", "multi_stage"]))
print(pipeline.get_worker(["i2v", "wan2.1-14B-480P", "multi_stage", "dit"]))
import base64
import io
import os
import time
import traceback
from datetime import datetime
import httpx
import torchaudio
from PIL import Image
from loguru import logger
FMT = "%Y-%m-%d %H:%M:%S"
def current_time():
return datetime.now().timestamp()
def time2str(t):
d = datetime.fromtimestamp(t)
return d.strftime(FMT)
def str2time(s):
d = datetime.strptime(s, FMT)
return d.timestamp()
def try_catch(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception:
logger.error(f"Error in {func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch(func):
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch_async(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
def data_name(x, task_id):
if x == "input_image":
x = x + ".png"
elif x == "output_video":
x = x + ".mp4"
return f"{task_id}-{x}"
async def fetch_resource(url, timeout):
logger.info(f"Begin to download resource from url: {url}")
t0 = time.time()
async with httpx.AsyncClient() as client:
async with client.stream("GET", url, timeout=timeout) as response:
response.raise_for_status()
ans_bytes = []
async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
ans_bytes.append(chunk)
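# chunks are up to 1 MiB each, so this caps the download at roughly 128 MiB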
if len(ans_bytes) > 128:
raise Exception(f"url {url} recv data is too big")
content = b"".join(ans_bytes)
logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds")
return content
async def preload_data(inp, inp_type, typ, val):
try:
if typ == "url":
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
data = await fetch_resource(val, timeout=timeout)
elif typ == "base64":
data = base64.b64decode(val)
elif typ == "stream":
# no bytes data need to be saved by data_manager
data = None
else:
raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!")
# check if valid image bytes
if inp_type == "IMAGE":
image = Image.open(io.BytesIO(data))
logger.info(f"load image: {image.size}")
assert image.size[0] > 0 and image.size[1] > 0, "image is empty"
elif inp_type == "AUDIO":
if typ != "stream":
try:
waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
logger.info(f"load audio: {waveform.size()}, {sample_rate}")
assert waveform.size(0) > 0, "audio is empty"
assert sample_rate > 0, "audio sample rate is not valid"
except Exception as e:
logger.warning(f"torchaudio failed to load audio, trying alternative method: {e}")
# try an alternative method to validate the audio file
# check whether the file header matches a valid audio format
if len(data) < 4:
    raise ValueError("Audio file too short")
# check common audio file headers
audio_headers = [b"RIFF", b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2", b"OggS"]
if not any(data.startswith(header) for header in audio_headers):
logger.warning("Audio file doesn't have recognized header, but continuing...")
logger.info(f"Audio validation passed (alternative method), size: {len(data)} bytes")
else:
raise Exception(f"cannot parse inp_type={inp_type} data")
return data
except Exception as e:
raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!")
async def load_inputs(params, raw_inputs, types):
inputs_data = {}
for inp in raw_inputs:
item = params.pop(inp)
bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
if bytes_data is not None:
inputs_data[inp] = bytes_data
else:
params[inp] = item
return inputs_data
def check_params(params, raw_inputs, raw_outputs, types):
stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
for x in raw_inputs + raw_outputs:
if x in params and "type" in params[x] and params[x]["type"] == "stream":
if types[x] == "AUDIO":
assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
elif types[x] == "VIDEO":
assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
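# Illustrative sketch (hypothetical values, not part of this diff): how a
# submit request is validated and preloaded. Stream entries stay in params
# (and require STREAM_AUDIO=1 / STREAM_VIDEO=1 in the environment); byte
# inputs are downloaded or decoded into inputs_data.
async def _example_submit_check():
    types = {"input_image": "IMAGE", "input_audio": "AUDIO", "output_video": "VIDEO"}
    params = {
        "input_image": {"type": "url", "data": "https://example.com/a.png"},
        "input_audio": {"type": "stream", "data": ""},
        "output_video": {"type": "stream", "data": ""},
    }
    check_params(params, ["input_image", "input_audio"], ["output_video"], types)
    return await load_inputs(params, ["input_image", "input_audio"], types)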
import os
import queue
import signal
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
class VAReader:
def __init__(
self,
rank: int,
world_size: int,
stream_url: str,
segment_duration: float = 5.0,
sample_rate: int = 16000,
audio_channels: int = 1,
buffer_size: int = 1,
prev_duration: float = 0.3125,
target_rank: int = 0,
):
self.rank = rank
self.world_size = world_size
self.stream_url = stream_url
self.segment_duration = segment_duration
self.sample_rate = sample_rate
self.audio_channels = audio_channels
self.prev_duration = prev_duration
# int16 = 2 bytes
self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.prev_size = int(self.prev_duration * self.sample_rate) * 2
self.prev_chunk = None
self.buffer_size = buffer_size
self.audio_queue = queue.Queue(maxsize=self.buffer_size)
self.audio_thread = None
self.ffmpeg_process = None
self.bytes_buffer = bytearray()
self.target_rank = target_rank % self.world_size
self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")
logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
def start(self):
if self.rank == self.target_rank:
if self.stream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.stream_url.startswith("http"):
self.start_ffmpeg_process_whep()
else:
raise Exception(f"Unsupported stream URL: {self.stream_url}")
self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
self.audio_thread.start()
logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
else:
logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
if self.world_size > 1:
logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
dist.barrier()
logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process read audio from stream"""
ffmpeg_cmd = [
"/opt/conda/bin/ffmpeg",
"-i",
self.stream_url,
"-vn",
# "-acodec",
# "pcm_s16le",
"-ar",
str(self.sample_rate),
"-ac",
str(self.audio_channels),
"-f",
"s16le",
"-",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def start_ffmpeg_process_whep(self):
"""Start gstream process read audio from stream"""
ffmpeg_cmd = [
"gst-launch-1.0",
"-q",
"whepsrc",
f"whep-endpoint={self.stream_url}",
"video-caps=none",
"!rtpopusdepay",
"!opusdec",
"plc=false",
"!audioconvert",
"!audioresample",
f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
"!fdsink",
"fd=1",
]
try:
self.ffmpeg_process = subprocess.Popen(
ffmpeg_cmd,
stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
bufsize=0,
)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def audio_worker(self):
logger.info("Audio pull worker thread started")
try:
while True:
if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
logger.warning("FFmpeg process exited, audio worker thread stopped")
break
self.fetch_audio_data()
time.sleep(0.01)
except: # noqa
logger.error(f"Audio pull worker error: {traceback.format_exc()}")
finally:
logger.warning("Audio pull worker thread stopped")
def fetch_audio_data(self):
"""Fetch audio data from ffmpeg process"""
try:
audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
if not audio_bytes:
return
self.bytes_buffer.extend(audio_bytes)
# logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes")
if len(self.bytes_buffer) >= self.chunk_size:
audio_data = self.bytes_buffer[: self.chunk_size]
self.bytes_buffer = self.bytes_buffer[self.chunk_size :]
# first chunk, read original 81 frames
# for other chunks, read 81 - 5 = 76 frames, concat with previous 5 frames
if self.prev_chunk is None:
logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
self.chunk_size -= self.prev_size
else:
audio_data = self.prev_chunk + audio_data
self.prev_chunk = audio_data[-self.prev_size :]
try:
self.audio_queue.put_nowait(audio_data)
except queue.Full:
logger.warning(f"Audio queue full:{self.audio_queue.qsize()}, discarded oldest chunk")
self.audio_queue.get_nowait()
self.audio_queue.put_nowait(audio_data)
logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size:{self.chunk_size}")
except: # noqa
logger.error(f"Fetch audio data error: {traceback.format_exc()}")
def broadcast_audio_data(self, audio_data):
if self.rank == self.target_rank:
if audio_data is None:
self.flag_tensor.fill_(0)
else:
self.flag_tensor.fill_(1)
self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
dist.broadcast(self.flag_tensor, src=self.target_rank)
if self.flag_tensor.item() == 0:
return None
dist.broadcast(self.audio_tensor, src=self.target_rank)
if self.rank != self.target_rank:
logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
audio_data = self.audio_tensor.cpu().numpy().tobytes()
return audio_data
def bytes_to_ndarray(self, audio_data):
if audio_data is None:
return None
audio_data = np.frombuffer(audio_data, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
return audio_data
def get_audio_segment(self, timeout: float = 1.0):
audio_data = None
if self.rank == self.target_rank:
try:
audio_data = self.audio_queue.get(timeout=timeout)
except: # noqa
logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
if self.world_size > 1:
audio_data = self.broadcast_audio_data(audio_data)
audio_data = self.bytes_to_ndarray(audio_data)
return audio_data
def stop(self):
# Stop ffmpeg process
if self.ffmpeg_process:
self.ffmpeg_process.send_signal(signal.SIGINT)
try:
self.ffmpeg_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.ffmpeg_process.kill()
logger.warning("FFmpeg reader process stopped")
# Wait for threads to finish
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=5)
if self.audio_thread.is_alive():
logger.error("Audio pull thread did not stop gracefully")
while self.audio_queue and self.audio_queue.qsize() > 0:
self.audio_queue.get_nowait()
self.audio_queue = None
logger.warning("Audio pull queue cleaned")
def __del__(self):
self.stop()
if __name__ == "__main__":
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
reader = VAReader(
RANK,
WORLD_SIZE,
# "rtmp://localhost/live/test_audio",
"https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000",
segment_duration=1.0,
sample_rate=16000,
audio_channels=1,
prev_duration=1 / 16,
)
reader.start()
fail_count = 0
max_fail_count = 2
try:
while True:
audio_data = reader.get_audio_segment(timeout=2)
if audio_data is not None:
# logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
fail_count = 0
else:
fail_count += 1
if fail_count > max_fail_count:
logger.warning("Failed to get audio chunk, stop reader")
reader.stop()
break
time.sleep(0.95)
finally:
reader.stop()
import queue
import signal
import socket
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torchaudio as ta
from loguru import logger
class VARecorder:
def __init__(
self,
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
audio_port: int = 30200,
video_port: int = 30201,
):
self.livestream_url = livestream_url
self.fps = fps
self.sample_rate = sample_rate
self.audio_port = audio_port
self.video_port = video_port
self.width = None
self.height = None
self.stoppable_t = None
# ffmpeg process for mix video and audio data and push to livestream
self.ffmpeg_process = None
# TCP connection objects
self.audio_socket = None
self.video_socket = None
self.audio_conn = None
self.video_conn = None
self.audio_thread = None
self.video_thread = None
# queue for send data to ffmpeg process
self.audio_queue = queue.Queue()
self.video_queue = queue.Queue()
def init_sockets(self):
# TCP socket for send and recv video and audio data
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.video_socket.bind(("127.0.0.1", self.video_port))
self.video_socket.listen(1)
self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.audio_socket.bind(("127.0.0.1", self.audio_port))
self.audio_socket.listen(1)
def audio_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to audio socket...")
self.audio_conn, _ = self.audio_socket.accept()
logger.info(f"Audio connection established from {self.audio_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
while True:
try:
if self.audio_queue is None:
break
data = self.audio_queue.get()
if data is None:
logger.info("Audio thread received stop signal")
break
# Convert audio data to 16-bit integer format
audios = np.clip(np.round(data * 32767), -32768, 32767).astype(np.int16)
self.audio_conn.send(audios.tobytes())
fail_time = 0
except: # noqa
logger.error(f"Send audio data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
break
except: # noqa
logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Audio push worker thread stopped")
def video_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to video socket...")
self.video_conn, _ = self.video_socket.accept()
logger.info(f"Video connection established from {self.video_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
while True:
try:
if self.video_queue is None:
break
data = self.video_queue.get()
if data is None:
logger.info("Video thread received stop signal")
break
# Convert to numpy uint8 scaled to [0, 255]; frames are sent to ffmpeg as raw rgb24
frames = (data * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
self.video_conn.send(frames.tobytes())
fail_time = 0
except: # noqa
logger.error(f"Send video data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
break
except: # noqa
logger.error(f"Video push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Video push worker thread stopped")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"/opt/conda/bin/ffmpeg",
"-re",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"44100",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"flv",
self.livestream_url,
"-y",
"-loglevel",
"info",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_whip(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"/opt/conda/bin/ffmpeg",
"-re",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"48000",
"-c:a",
"libopus",
"-ac",
"2",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-threads",
"1",
"-bf",
"0",
"-f",
"whip",
self.livestream_url,
"-y",
"-loglevel",
"info",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def set_video_size(self, width: int, height: int):
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, "Video size already set"
return
self.width = width
self.height = height
self.init_sockets()
if self.livestream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.livestream_url.startswith("http"):
self.start_ffmpeg_process_whip()
else:
raise Exception(f"Unsupported livestream URL: {self.livestream_url}")
self.audio_thread = threading.Thread(target=self.audio_worker)
self.video_thread = threading.Thread(target=self.video_worker)
self.audio_thread.start()
self.video_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: np.ndarray):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
self.audio_queue.put(audios)
self.video_queue.put(images)
logger.info(f"Published {N} frames and {M} audio samples")
self.stoppable_t = time.time() + M / self.sample_rate + 3
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
# Send stop signals to queues
if self.audio_queue:
self.audio_queue.put(None)
if self.video_queue:
self.video_queue.put(None)
# Wait for threads to finish
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=5)
if self.audio_thread.is_alive():
logger.warning("Audio push thread did not stop gracefully")
if self.video_thread and self.video_thread.is_alive():
self.video_thread.join(timeout=5)
if self.video_thread.is_alive():
logger.warning("Video push thread did not stop gracefully")
# Close TCP connections, sockets
if self.audio_conn:
self.audio_conn.close()
if self.video_conn:
self.video_conn.close()
if self.audio_socket:
self.audio_socket.close()
if self.video_socket:
self.video_socket.close()
while self.audio_queue and self.audio_queue.qsize() > 0:
self.audio_queue.get_nowait()
while self.video_queue and self.video_queue.qsize() > 0:
self.video_queue.get_nowait()
self.audio_queue = None
self.video_queue = None
logger.warning("Cleaned audio and video queues")
# Stop ffmpeg process
if self.ffmpeg_process:
self.ffmpeg_process.send_signal(signal.SIGINT)
try:
self.ffmpeg_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.ffmpeg_process.kill()
logger.warning("FFmpeg recorder process stopped")
def __del__(self):
self.stop(wait=False)
def create_simple_video(frames=10, height=480, width=640):
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
    [1.0, 0.0, 0.0],  # red
    [0.0, 1.0, 0.0],  # green
    [0.0, 0.0, 1.0],  # blue
    [1.0, 1.0, 0.0],  # yellow
    [1.0, 0.0, 1.0],  # magenta
    [0.0, 1.0, 1.0],  # cyan
    [1.0, 1.0, 1.0],  # white
    [0.5, 0.5, 0.5],  # gray
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
sample_rate = 16000
fps = 16
width = 640
height = 480
recorder = VARecorder(
# livestream_url="rtmp://localhost/live/test",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
fps=fps,
sample_rate=sample_rate,
)
audio_path = "/mtc/liuliang1/lightx2v/test_deploy/media_test/test_b_2min.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
audio_array = audio_array.numpy().reshape(-1)
secs = audio_array.shape[0] // sample_rate
interval = 1
for i in range(0, secs, interval):
logger.info(f"{i} / {secs} s")
start = i * sample_rate
end = (i + interval) * sample_rate
cur_audio_array = audio_array[start:end]
logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}")
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
recorder.pub_livestream(images, cur_audio_array)
time.sleep(interval)
recorder.stop()
import io
import json
import torch
from PIL import Image
from lightx2v.deploy.common.utils import class_try_catch_async
class BaseDataManager:
def __init__(self):
pass
async def init(self):
pass
async def close(self):
pass
def to_device(self, data, device):
if isinstance(data, dict):
return {key: self.to_device(value, device) for key, value in data.items()}
elif isinstance(data, list):
return [self.to_device(item, device) for item in data]
elif isinstance(data, torch.Tensor):
return data.to(device)
else:
return data
async def save_bytes(self, bytes_data, filename):
raise NotImplementedError
async def load_bytes(self, filename):
raise NotImplementedError
async def delete_bytes(self, filename):
raise NotImplementedError
async def recurrent_save(self, data, prefix):
if isinstance(data, dict):
return {k: await self.recurrent_save(v, f"{prefix}-{k}") for k, v in data.items()}
elif isinstance(data, list):
return [await self.recurrent_save(v, f"{prefix}-{idx}") for idx, v in enumerate(data)]
elif isinstance(data, torch.Tensor):
save_path = prefix + ".pt"
await self.save_tensor(data, save_path)
return save_path
elif isinstance(data, Image.Image):
save_path = prefix + ".png"
await self.save_image(data, save_path)
return save_path
else:
return data
async def recurrent_load(self, data, device, prefix):
if isinstance(data, dict):
return {k: await self.recurrent_load(v, device, f"{prefix}-{k}") for k, v in data.items()}
elif isinstance(data, list):
return [await self.recurrent_load(v, device, f"{prefix}-{idx}") for idx, v in enumerate(data)]
elif isinstance(data, str) and data == prefix + ".pt":
return await self.load_tensor(data, device)
elif isinstance(data, str) and data == prefix + ".png":
return await self.load_image(data)
else:
return data
async def recurrent_delete(self, data, prefix):
if isinstance(data, dict):
return {k: await self.recurrent_delete(v, f"{prefix}-{k}") for k, v in data.items()}
elif isinstance(data, list):
return [await self.recurrent_delete(v, f"{prefix}-{idx}") for idx, v in enumerate(data)]
elif isinstance(data, str) and data == prefix + ".pt":
await self.delete_bytes(data)
elif isinstance(data, str) and data == prefix + ".png":
await self.delete_bytes(data)
@class_try_catch_async
async def save_object(self, data, filename):
data = await self.recurrent_save(data, filename)
bytes_data = json.dumps(data, ensure_ascii=False).encode("utf-8")
await self.save_bytes(bytes_data, filename)
@class_try_catch_async
async def load_object(self, filename, device):
bytes_data = await self.load_bytes(filename)
data = json.loads(bytes_data.decode("utf-8"))
data = await self.recurrent_load(data, device, filename)
return data
@class_try_catch_async
async def delete_object(self, filename):
bytes_data = await self.load_bytes(filename)
data = json.loads(bytes_data.decode("utf-8"))
await self.recurrent_delete(data, filename)
await self.delete_bytes(filename)
@class_try_catch_async
async def save_tensor(self, data: torch.Tensor, filename):
buffer = io.BytesIO()
torch.save(data.to("cpu"), buffer)
await self.save_bytes(buffer.getvalue(), filename)
@class_try_catch_async
async def load_tensor(self, filename, device):
bytes_data = await self.load_bytes(filename)
buffer = io.BytesIO(bytes_data)
t = torch.load(buffer)
t = t.to(device)
return t
@class_try_catch_async
async def save_image(self, data: Image.Image, filename):
buffer = io.BytesIO()
data.save(buffer, format="PNG")
await self.save_bytes(buffer.getvalue(), filename)
@class_try_catch_async
async def load_image(self, filename):
bytes_data = await self.load_bytes(filename)
buffer = io.BytesIO(bytes_data)
img = Image.open(buffer).convert("RGB")
return img
def get_delete_func(self, data_type):
    maps = {
        "TENSOR": self.delete_bytes,
        "IMAGE": self.delete_bytes,
        "OBJECT": self.delete_object,
        "VIDEO": self.delete_bytes,
    }
    return maps[data_type]
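# Illustrative sketch (not part of this diff): a new storage backend only needs
# the three byte-level primitives; save_object/save_tensor/save_image above are
# built on top of them. InMemoryDataManager is a hypothetical example.
class InMemoryDataManager(BaseDataManager):
    def __init__(self):
        self.blobs = {}

    async def save_bytes(self, bytes_data, filename):
        self.blobs[filename] = bytes(bytes_data)
        return True

    async def load_bytes(self, filename):
        return self.blobs[filename]

    async def delete_bytes(self, filename):
        del self.blobs[filename]
        return True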
# Import data manager implementations
from .local_data_manager import LocalDataManager # noqa
from .s3_data_manager import S3DataManager # noqa
__all__ = ["BaseDataManager", "LocalDataManager", "S3DataManager"]
import asyncio
import os
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.data_manager import BaseDataManager
class LocalDataManager(BaseDataManager):
def __init__(self, local_dir):
self.local_dir = local_dir
self.name = "local"
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
@class_try_catch_async
async def save_bytes(self, bytes_data, filename):
out_path = os.path.join(self.local_dir, filename)
with open(out_path, "wb") as fout:
fout.write(bytes_data)
return True
@class_try_catch_async
async def load_bytes(self, filename):
inp_path = os.path.join(self.local_dir, filename)
with open(inp_path, "rb") as fin:
return fin.read()
@class_try_catch_async
async def delete_bytes(self, filename):
inp_path = os.path.join(self.local_dir, filename)
os.remove(inp_path)
logger.info(f"deleted local file {filename}")
return True
async def test():
import torch
from PIL import Image
m = LocalDataManager("/data/nvme1/liuliang1/lightx2v/local_data")
await m.init()
img = Image.open("/data/nvme1/liuliang1/lightx2v/assets/img_lightx2v.png")
tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0")
await m.save_image(img, "test_img.png")
print(await m.load_image("test_img.png"))
await m.save_tensor(tensor, "test_tensor.pt")
print(await m.load_tensor("test_tensor.pt", "cuda:0"))
await m.save_object(
{
"images": [img, img],
"tensor": tensor,
"list": [
[2, 0, 5, 5],
{
"1": "hello world",
"2": "world",
"3": img,
"t": tensor,
},
"0609",
],
},
"test_object.json",
)
print(await m.load_object("test_object.json", "cuda:0"))
await m.get_delete_func("OBJECT")("test_object.json")
await m.get_delete_func("TENSOR")("test_tensor.pt")
await m.get_delete_func("IMAGE")("test_img.png")
if __name__ == "__main__":
asyncio.run(test())
import asyncio
import hashlib
import json
import os
import aioboto3
from botocore.client import Config
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.data_manager import BaseDataManager
class S3DataManager(BaseDataManager):
def __init__(self, config_string, max_retries=3):
self.name = "s3"
self.config = json.loads(config_string)
self.max_retries = max_retries
self.bucket_name = self.config["bucket_name"]
self.aws_access_key_id = self.config["aws_access_key_id"]
self.aws_secret_access_key = self.config["aws_secret_access_key"]
self.endpoint_url = self.config["endpoint_url"]
self.base_path = self.config["base_path"]
self.connect_timeout = self.config.get("connect_timeout", 60)
self.read_timeout = self.config.get("read_timeout", 10)
self.write_timeout = self.config.get("write_timeout", 10)
self.session = None
self.s3_client = None
async def init(self):
for i in range(self.max_retries):
try:
logger.info(f"S3DataManager init with config: {self.config} (attempt {i + 1}/{self.max_retries}) ...")
self.session = aioboto3.Session()
self.s3_client = await self.session.client(
"s3",
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
endpoint_url=self.endpoint_url,
config=Config(
signature_version="s3v4",
s3={"payload_signing_enabled": True},
connect_timeout=self.connect_timeout,
read_timeout=self.read_timeout,
parameter_validation=False,
max_pool_connections=50,
),
).__aenter__()
try:
await self.s3_client.head_bucket(Bucket=self.bucket_name)
logger.info(f"check bucket {self.bucket_name} success")
except Exception as e:
logger.info(f"check bucket {self.bucket_name} error: {e}, try to create it...")
await self.s3_client.create_bucket(Bucket=self.bucket_name)
logger.info(f"Successfully init S3 bucket: {self.bucket_name} with timeouts - connect: {self.connect_timeout}s, read: {self.read_timeout}s, write: {self.write_timeout}s")
return
except Exception as e:
    logger.warning(f"Failed to connect to S3: {e}")
    if i == self.max_retries - 1:
        raise
    await asyncio.sleep(1)
async def close(self):
if self.s3_client:
await self.s3_client.__aexit__(None, None, None)
if self.session:
self.session = None
@class_try_catch_async
async def save_bytes(self, bytes_data, filename):
filename = os.path.join(self.base_path, filename)
content_sha256 = hashlib.sha256(bytes_data).hexdigest()
await self.s3_client.put_object(
Bucket=self.bucket_name,
Key=filename,
Body=bytes_data,
ChecksumSHA256=content_sha256,
ContentType="application/octet-stream",
)
return True
@class_try_catch_async
async def load_bytes(self, filename):
filename = os.path.join(self.base_path, filename)
response = await self.s3_client.get_object(Bucket=self.bucket_name, Key=filename)
return await response["Body"].read()
@class_try_catch_async
async def delete_bytes(self, filename):
filename = os.path.join(self.base_path, filename)
await self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
logger.info(f"deleted s3 file {filename}")
return True
async def file_exists(self, filename):
filename = os.path.join(self.base_path, filename)
try:
await self.s3_client.head_object(Bucket=self.bucket_name, Key=filename)
return True
except Exception:
return False
async def list_files(self, prefix=""):
prefix = os.path.join(self.base_path, prefix)
response = await self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
files = []
if "Contents" in response:
for obj in response["Contents"]:
files.append(obj["Key"])
return files
async def test():
import torch
from PIL import Image
s3_config = {
"aws_access_key_id": "xxx",
"aws_secret_access_key": "xx",
"endpoint_url": "xxx",
"bucket_name": "xxx",
"base_path": "xxx",
"connect_timeout": 10,
"read_timeout": 10,
"write_timeout": 10,
}
m = S3DataManager(json.dumps(s3_config))
await m.init()
img = Image.open("../../../assets/img_lightx2v.png")
tensor = torch.Tensor([233, 456, 789]).to(dtype=torch.bfloat16, device="cuda:0")
await m.save_image(img, "test_img.png")
print(await m.load_image("test_img.png"))
await m.save_tensor(tensor, "test_tensor.pt")
print(await m.load_tensor("test_tensor.pt", "cuda:0"))
await m.save_object(
{
"images": [img, img],
"tensor": tensor,
"list": [
[2, 0, 5, 5],
{
"1": "hello world",
"2": "world",
"3": img,
"t": tensor,
},
"0609",
],
},
"test_object.json",
)
print(await m.load_object("test_object.json", "cuda:0"))
print("all files:", await m.list_files())
await m.get_delete_func("OBJECT")("test_object.json")
await m.get_delete_func("TENSOR")("test_tensor.pt")
await m.get_delete_func("IMAGE")("test_img.png")
print("after delete all files", await m.list_files())
await m.close()
if __name__ == "__main__":
asyncio.run(test())
class BaseQueueManager:
def __init__(self):
pass
async def init(self):
pass
async def close(self):
pass
async def put_subtask(self, subtask):
raise NotImplementedError
async def get_subtasks(self, queue, max_batch, timeout):
raise NotImplementedError
async def pending_num(self, queue):
raise NotImplementedError
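# Illustrative consumer sketch (not part of this diff): a worker drains one
# queue with a bounded batch size and a poll timeout; queue_manager can be any
# of the implementations imported below.
async def _example_consume(queue_manager, queue_name):
    subtasks = await queue_manager.get_subtasks(queue_name, max_batch=1, timeout=30)
    for subtask in subtasks or []:
        print("would run", subtask["task_id"], subtask["worker_name"])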
# Import queue manager implementations
from .local_queue_manager import LocalQueueManager # noqa
from .rabbitmq_queue_manager import RabbitMQQueueManager # noqa
__all__ = ["BaseQueueManager", "LocalQueueManager", "RabbitMQQueueManager"]
import asyncio
import json
import os
import time
import traceback
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.queue_manager import BaseQueueManager
class LocalQueueManager(BaseQueueManager):
def __init__(self, local_dir):
self.local_dir = local_dir
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
async def get_conn(self):
pass
async def del_conn(self):
pass
async def declare_queue(self, queue):
pass
@class_try_catch_async
async def put_subtask(self, subtask):
out_name = self.get_filename(subtask["queue"])
keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"]
msg = json.dumps({k: subtask[k] for k in keys}) + "\n"
logger.info(f"Local published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {subtask['queue']}")
with open(out_name, "a") as fout:
fout.write(msg)
return True
def read_first_line(self, queue):
    out_name = self.get_filename(queue)
    if not os.path.exists(out_name):
        return None
    with open(out_name) as fin:
        lines = fin.readlines()
    if len(lines) <= 0:
        return None
    subtask = json.loads(lines[0])
    with open(out_name, "w") as fout:
        fout.write("".join(lines[1:]))
    return subtask
@class_try_catch_async
async def get_subtasks(self, queue, max_batch, timeout):
try:
t0 = time.time()
subtasks = []
while True:
subtask = self.read_first_line(queue)
if subtask:
subtasks.append(subtask)
if len(subtasks) >= max_batch:
return subtasks
else:
continue
else:
if len(subtasks) > 0:
return subtasks
if time.time() - t0 > timeout:
return None
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.warning(f"local queue get_subtasks for {queue} cancelled")
return None
except: # noqa
logger.warning(f"local queue get_subtasks for {queue} failed: {traceback.format_exc()}")
return None
def get_filename(self, queue):
return os.path.join(self.local_dir, f"{queue}.jsonl")
@class_try_catch_async
async def pending_num(self, queue):
out_name = self.get_filename(queue)
if not os.path.exists(out_name):
return 0
lines = []
with open(out_name) as fin:
lines = fin.readlines()
return len(lines)
async def test():
q = LocalQueueManager("/data/nvme1/liuliang1/lightx2v/local_queue")
await q.init()
subtask = {
"task_id": "test-subtask-id",
"queue": "test_queue",
"worker_name": "test_worker",
"inputs": {},
"outputs": {},
"params": {},
}
await q.put_subtask(subtask)
await asyncio.sleep(5)
for i in range(2):
subtask = await q.get_subtasks("test_queue", 3, 5)
print("get subtask:", subtask)
if __name__ == "__main__":
asyncio.run(test())
import asyncio
import json
import traceback
import aio_pika
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.queue_manager import BaseQueueManager
class RabbitMQQueueManager(BaseQueueManager):
def __init__(self, conn_url, max_retries=3):
self.conn_url = conn_url
self.max_retries = max_retries
self.conn = None
self.chan = None
self.queues = set()
async def init(self):
await self.get_conn()
async def close(self):
await self.del_conn()
async def get_conn(self):
if self.chan and self.conn:
return
for i in range(self.max_retries):
try:
logger.info(f"Connect to RabbitMQ (attempt {i + 1}/{self.max_retries}..)")
self.conn = await aio_pika.connect_robust(self.conn_url)
self.chan = await self.conn.channel()
self.queues = set()
await self.chan.set_qos(prefetch_count=10)
logger.info("Successfully connected to RabbitMQ")
return
except Exception as e:
logger.warning(f"Failed to connect to RabbitMQ: {e}")
if i < self.max_retries - 1:
await asyncio.sleep(1)
else:
raise
async def declare_queue(self, queue):
if queue not in self.queues:
await self.get_conn()
await self.chan.declare_queue(queue, durable=True)
self.queues.add(queue)
return await self.chan.get_queue(queue)
@class_try_catch_async
async def put_subtask(self, subtask):
queue = subtask["queue"]
await self.declare_queue(queue)
keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"]
msg = json.dumps({k: subtask[k] for k in keys}).encode("utf-8")
message = aio_pika.Message(body=msg, delivery_mode=aio_pika.DeliveryMode.PERSISTENT, content_type="application/json")
await self.chan.default_exchange.publish(message, routing_key=queue)
logger.info(f"Rabbitmq published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {queue}")
return True
async def get_subtasks(self, queue, max_batch, timeout):
    try:
        q = await self.declare_queue(queue)
        subtasks = []
        # block until the first message arrives (the timeout argument is not
        # enforced here; the blocking iterator is cancelled externally), then
        # drain up to max_batch messages without waiting for more
        async with q.iterator() as qiter:
            async for message in qiter:
                await message.ack()
                subtasks.append(json.loads(message.body.decode("utf-8")))
                break
        while len(subtasks) < max_batch:
            message = await q.get(no_ack=False, fail=False)
            if message is None:
                break
            await message.ack()
            subtasks.append(json.loads(message.body.decode("utf-8")))
        return subtasks
    except asyncio.CancelledError:
        logger.warning(f"rabbitmq get_subtasks for {queue} cancelled")
        return None
    except Exception:
        logger.warning(f"rabbitmq get_subtasks for {queue} failed: {traceback.format_exc()}")
        return None
@class_try_catch_async
async def pending_num(self, queue):
q = await self.declare_queue(queue)
return q.declaration_result.message_count
async def del_conn(self):
if self.chan:
await self.chan.close()
if self.conn:
await self.conn.close()
async def test():
conn_url = "amqp://username:password@127.0.0.1:5672"
q = RabbitMQQueueManager(conn_url)
await q.init()
subtask = {
"task_id": "test-subtask-id",
"queue": "test_queue",
"worker_name": "test_worker",
"inputs": {},
"outputs": {},
"params": {},
}
await q.put_subtask(subtask)
await asyncio.sleep(5)
for i in range(2):
subtask = await q.get_subtasks("test_queue", 3, 5)
print("get subtask:", subtask)
await q.close()
if __name__ == "__main__":
asyncio.run(test())
import os
import time
import aiohttp
import jwt
from fastapi import HTTPException
from loguru import logger
class AuthManager:
def __init__(self):
# Worker access token
self.worker_secret_key = os.getenv("WORKER_SECRET_KEY", "worker-secret-key-change-in-production")
# GitHub OAuth
self.github_client_id = os.getenv("GITHUB_CLIENT_ID", "")
self.github_client_secret = os.getenv("GITHUB_CLIENT_SECRET", "")
self.jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256")
self.jwt_expiration_hours = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
self.jwt_secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production")
logger.info(f"AuthManager: GITHUB_CLIENT_ID: {self.github_client_id}")
logger.info(f"AuthManager: GITHUB_CLIENT_SECRET: {self.github_client_secret}")
logger.info(f"AuthManager: JWT_SECRET_KEY: {self.jwt_secret_key}")
logger.info(f"AuthManager: WORKER_SECRET_KEY: {self.worker_secret_key}")
def create_jwt_token(self, data):
data2 = {
"user_id": data["user_id"],
"username": data["username"],
"email": data["email"],
"homepage": data["homepage"],
}
expire = time.time() + (self.jwt_expiration_hours * 3600)
data2.update({"exp": expire})
return jwt.encode(data2, self.jwt_secret_key, algorithm=self.jwt_algorithm)
async def auth_github(self, code):
try:
logger.info(f"GitHub OAuth code: {code}")
token_url = "https://github.com/login/oauth/access_token"
token_data = {"client_id": self.github_client_id, "client_secret": self.github_client_secret, "code": code}
headers = {"Accept": "application/json"}
proxy = os.getenv("auth_https_proxy", None)
if proxy:
logger.info(f"auth_github use proxy: {proxy}")
async with aiohttp.ClientSession() as session:
async with session.post(token_url, data=token_data, headers=headers, proxy=proxy) as response:
response.raise_for_status()
token_info = await response.json()
if "error" in token_info:
raise HTTPException(status_code=400, detail=f"GitHub OAuth error: {token_info['error']}")
access_token = token_info.get("access_token")
if not access_token:
raise HTTPException(status_code=400, detail="Failed to get access token")
user_url = "https://api.github.com/user"
user_headers = {"Authorization": f"token {access_token}", "Accept": "application/vnd.github.v3+json"}
async with aiohttp.ClientSession() as session:
async with session.get(user_url, headers=user_headers, proxy=proxy) as response:
response.raise_for_status()
user_info = await response.json()
return {
"source": "github",
"id": str(user_info["id"]),
"username": user_info["login"],
"email": user_info.get("email", ""),
"homepage": user_info.get("html_url", ""),
"avatar_url": user_info.get("avatar_url", ""),
}
except aiohttp.ClientError as e:
logger.error(f"GitHub API request failed: {e}")
raise HTTPException(status_code=500, detail="Failed to authenticate with GitHub")
except Exception as e:
logger.error(f"Authentication error: {e}")
raise HTTPException(status_code=500, detail="Authentication failed")
def verify_jwt_token(self, token):
try:
payload = jwt.decode(token, self.jwt_secret_key, algorithms=[self.jwt_algorithm])
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired")
except Exception as e:
logger.error(f"verify_jwt_token error: {e}")
raise HTTPException(status_code=401, detail="Could not validate credentials")
def verify_worker_token(self, token):
return token == self.worker_secret_key
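# Illustrative sketch (not part of this diff): wiring verify_jwt_token into a
# FastAPI dependency. HTTPBearer usage and the names below are assumptions.
from fastapi import Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

_auth_manager = AuthManager()
_bearer = HTTPBearer()

async def get_current_user(cred: HTTPAuthorizationCredentials = Depends(_bearer)):
    # raises HTTPException(401) on expired or invalid tokens
    return _auth_manager.verify_jwt_token(cred.credentials)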
from loguru import logger
from prometheus_client import Counter, Gauge, Summary, generate_latest
from prometheus_client.core import CollectorRegistry
from lightx2v.deploy.task_manager import ActiveStatus, FinishedStatus, TaskStatus
REGISTRY = CollectorRegistry()
class MetricMonitor:
def __init__(self):
self.task_all = Counter("task_all_total", "Total count of all tasks", ["task_type", "model_cls", "stage"], registry=REGISTRY)
self.task_end = Counter("task_end_total", "Total count of ended tasks", ["task_type", "model_cls", "stage", "status"], registry=REGISTRY)
self.task_active = Gauge("task_active_size", "Current count of active tasks", ["task_type", "model_cls", "stage"], registry=REGISTRY)
self.task_elapse = Summary("task_elapse_seconds", "Elapsed time of tasks", ["task_type", "model_cls", "stage", "end_status"], registry=REGISTRY)
self.subtask_all = Counter("subtask_all_total", "Total count of all subtasks", ["queue"], registry=REGISTRY)
self.subtask_end = Counter("subtask_end_total", "Total count of ended subtasks", ["queue", "status"], registry=REGISTRY)
self.subtask_active = Gauge("subtask_active_size", "Current count of active subtasks", ["queue", "status"], registry=REGISTRY)
self.subtask_elapse = Summary("subtask_elapse_seconds", "Elapsed time of subtasks", ["queue", "elapse_key"], registry=REGISTRY)
def record_task_start(self, task):
self.task_all.labels(task["task_type"], task["model_cls"], task["stage"]).inc()
self.task_active.labels(task["task_type"], task["model_cls"], task["stage"]).inc()
logger.info(f"Metrics task_all + 1, task_active +1")
def record_task_end(self, task, status, elapse):
self.task_end.labels(task["task_type"], task["model_cls"], task["stage"], status.name).inc()
self.task_active.labels(task["task_type"], task["model_cls"], task["stage"]).dec()
self.task_elapse.labels(task["task_type"], task["model_cls"], task["stage"], status.name).observe(elapse)
logger.info(f"Metrics task_end + 1, task_active -1, task_elapse observe {elapse}")
def record_subtask_change(self, subtask, old_status, new_status, elapse_key, elapse):
if old_status in ActiveStatus and new_status in FinishedStatus:
self.subtask_end.labels(subtask["queue"], elapse_key).inc()
logger.info(f"Metrics subtask_end + 1")
if old_status in ActiveStatus:
self.subtask_active.labels(subtask["queue"], old_status.name).dec()
logger.info(f"Metrics subtask_active {old_status.name} -1")
if new_status in ActiveStatus:
self.subtask_active.labels(subtask["queue"], new_status.name).inc()
logger.info(f"Metrics subtask_active {new_status.name} + 1")
if new_status == TaskStatus.CREATED:
self.subtask_all.labels(subtask["queue"]).inc()
logger.info(f"Metrics subtask_all + 1")
if elapse and elapse_key:
self.subtask_elapse.labels(subtask["queue"], elapse_key).observe(elapse)
logger.info(f"Metrics subtask_elapse observe {elapse}")
    # on server restart, recover metrics for tasks that are still active
def record_task_recover(self, tasks):
for task in tasks:
if task["status"] in ActiveStatus:
self.record_task_start(task)
    # on server restart, recover metrics for subtasks that are still active
def record_subtask_recover(self, subtasks):
for subtask in subtasks:
if subtask["status"] in ActiveStatus:
self.subtask_all.labels(subtask["queue"]).inc()
self.subtask_active.labels(subtask["queue"], subtask["status"].name).inc()
logger.info(f"Metrics subtask_active {subtask['status'].name} + 1")
logger.info(f"Metrics subtask_all + 1")
def get_metrics(self):
return generate_latest(REGISTRY)
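# Usage sketch (hypothetical wiring; the real route lives in the API server):
# expose the registry in Prometheus text format, e.g. with FastAPI:
#   from fastapi import FastAPI, Response
#   app = FastAPI()
#   monitor = MetricMonitor()
#   @app.get("/metrics")
#   async def metrics():
#       return Response(content=monitor.get_metrics(), media_type="text/plain")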
import asyncio
import time
from enum import Enum
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.task_manager import TaskStatus
class WorkerStatus(Enum):
FETCHING = 1
FETCHED = 2
DISCONNECT = 3
REPORT = 4
PING = 5
class CostWindow:
def __init__(self, window):
self.window = window
self.costs = []
self.avg = None
def append(self, cost):
self.costs.append(cost)
if len(self.costs) > self.window:
self.costs.pop(0)
self.avg = sum(self.costs) / len(self.costs)
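# Example: with window=3, appending 2, 4, 6, 8 keeps [4, 6, 8], so avg == 6.0:
#   w = CostWindow(3)
#   for c in (2, 4, 6, 8):
#       w.append(c)
#   assert w.avg == 6.0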
class WorkerClient:
def __init__(self, queue, identity, infer_timeout, offline_timeout, avg_window, ping_timeout):
self.queue = queue
self.identity = identity
self.status = None
self.update_t = time.time()
self.fetched_t = None
self.infer_cost = CostWindow(avg_window)
self.offline_cost = CostWindow(avg_window)
self.infer_timeout = infer_timeout
self.offline_timeout = offline_timeout
self.ping_timeout = ping_timeout
    # normal lifecycle: FETCHING -> FETCHED -> PING * n -> REPORT -> FETCHING
    # worker idle or disconnected: FETCHING -> DISCONNECT -> FETCHING
def update(self, status: WorkerStatus):
pre_status = self.status
pre_t = self.update_t
self.status = status
self.update_t = time.time()
if status == WorkerStatus.FETCHING:
if pre_status in [WorkerStatus.DISCONNECT, WorkerStatus.REPORT] and pre_t is not None:
cur_cost = self.update_t - pre_t
if cur_cost < self.offline_timeout:
self.offline_cost.append(max(cur_cost, 1))
elif status == WorkerStatus.REPORT:
if self.fetched_t is not None:
cur_cost = self.update_t - self.fetched_t
self.fetched_t = None
if cur_cost < self.infer_timeout:
self.infer_cost.append(max(cur_cost, 1))
elif status == WorkerStatus.FETCHED:
self.fetched_t = time.time()
def check(self):
# infer too long
if self.fetched_t is not None:
elapse = time.time() - self.fetched_t
if self.infer_cost.avg is not None and elapse > self.infer_cost.avg * 5:
logger.warning(f"Worker {self.identity} {self.queue} infer timeout: {elapse:.2f} s")
return False
if elapse > self.infer_timeout:
logger.warning(f"Worker {self.identity} {self.queue} infer timeout2: {elapse:.2f} s")
return False
elapse = time.time() - self.update_t
# no ping too long
if self.status in [WorkerStatus.FETCHED, WorkerStatus.PING]:
if elapse > self.ping_timeout:
logger.warning(f"Worker {self.identity} {self.queue} ping timeout: {elapse:.2f} s")
return False
# offline too long
elif self.status in [WorkerStatus.DISCONNECT, WorkerStatus.REPORT]:
if self.offline_cost.avg is not None and elapse > self.offline_cost.avg * 5:
logger.warning(f"Worker {self.identity} {self.queue} offline timeout: {elapse:.2f} s")
return False
if elapse > self.offline_timeout:
logger.warning(f"Worker {self.identity} {self.queue} offline timeout2: {elapse:.2f} s")
return False
return True
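# Eviction example: with infer_timeout=60 and a rolling infer_cost.avg of 10 s,
# a worker that fetched 51 s ago fails check() via the 5x-average rule before
# the hard 60 s limit; check() returning False removes it in clean_workers().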
class ServerMonitor:
def __init__(self, model_pipelines, task_manager, queue_manager, interval=1):
self.model_pipelines = model_pipelines
self.task_manager = task_manager
self.queue_manager = queue_manager
self.interval = interval
self.stop = False
self.worker_clients = {}
self.identity_to_queue = {}
self.subtask_run_timeouts = {}
self.all_queues = self.model_pipelines.get_queues()
self.config = self.model_pipelines.get_monitor_config()
for queue in self.all_queues:
self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 60)
self.subtask_created_timeout = self.config["subtask_created_timeout"]
self.subtask_pending_timeout = self.config["subtask_pending_timeout"]
self.worker_avg_window = self.config["worker_avg_window"]
self.worker_offline_timeout = self.config["worker_offline_timeout"]
self.worker_min_capacity = self.config["worker_min_capacity"]
self.worker_min_cnt = self.config["worker_min_cnt"]
self.worker_max_cnt = self.config["worker_max_cnt"]
self.task_timeout = self.config["task_timeout"]
self.schedule_ratio_high = self.config["schedule_ratio_high"]
self.schedule_ratio_low = self.config["schedule_ratio_low"]
self.ping_timeout = self.config["ping_timeout"]
self.user_visits = {} # user_id -> last_visit_t
self.user_max_active_tasks = self.config["user_max_active_tasks"]
self.user_max_daily_tasks = self.config["user_max_daily_tasks"]
self.user_visit_frequency = self.config["user_visit_frequency"]
assert self.worker_avg_window > 0
assert self.worker_offline_timeout > 0
assert self.worker_min_capacity > 0
assert self.worker_min_cnt > 0
assert self.worker_max_cnt > 0
assert self.worker_min_cnt <= self.worker_max_cnt
assert self.task_timeout > 0
assert self.schedule_ratio_high > 0 and self.schedule_ratio_high < 1
assert self.schedule_ratio_low > 0 and self.schedule_ratio_low < 1
assert self.schedule_ratio_high >= self.schedule_ratio_low
assert self.ping_timeout > 0
assert self.user_max_active_tasks > 0
assert self.user_max_daily_tasks > 0
assert self.user_visit_frequency > 0
async def init(self):
        while not self.stop:
            await self.clean_workers()
            await self.clean_subtasks()
            await asyncio.sleep(self.interval)
logger.info("ServerMonitor stopped")
async def close(self):
self.stop = True
self.model_pipelines = None
self.task_manager = None
self.queue_manager = None
self.worker_clients = None
def init_worker(self, queue, identity):
if queue not in self.worker_clients:
self.worker_clients[queue] = {}
if identity not in self.worker_clients[queue]:
infer_timeout = self.subtask_run_timeouts[queue]
self.worker_clients[queue][identity] = WorkerClient(queue, identity, infer_timeout, self.worker_offline_timeout, self.worker_avg_window, self.ping_timeout)
self.identity_to_queue[identity] = queue
return self.worker_clients[queue][identity]
@class_try_catch_async
async def worker_update(self, queue, identity, status):
if queue is None:
queue = self.identity_to_queue[identity]
worker = self.init_worker(queue, identity)
worker.update(status)
logger.info(f"Worker {identity} {queue} update [{status}]")
async def clean_workers(self):
qs = list(self.worker_clients.keys())
for queue in qs:
idens = list(self.worker_clients[queue].keys())
for identity in idens:
if not self.worker_clients[queue][identity].check():
self.worker_clients[queue].pop(identity)
self.identity_to_queue.pop(identity)
logger.warning(f"Worker {queue} {identity} out of contact removed, remain {self.worker_clients[queue]}")
async def clean_subtasks(self):
created_end_t = time.time() - self.subtask_created_timeout
pending_end_t = time.time() - self.subtask_pending_timeout
ping_end_t = time.time() - self.ping_timeout
fails = set()
created_tasks = await self.task_manager.list_tasks(status=TaskStatus.CREATED, subtasks=True, end_updated_t=created_end_t)
pending_tasks = await self.task_manager.list_tasks(status=TaskStatus.PENDING, subtasks=True, end_updated_t=pending_end_t)
def fmt_subtask(t):
return f"({t['task_id']}, {t['worker_name']}, {t['queue']}, {t['worker_identity']})"
for t in created_tasks + pending_tasks:
if t["task_id"] in fails:
continue
            elapse = time.time() - t["update_t"]
            logger.warning(f"Subtask {fmt_subtask(t)} {t['status'].name} timeout: {elapse:.2f} s")
            await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"{t['status'].name} timeout: {elapse:.2f} s")
fails.add(t["task_id"])
running_tasks = await self.task_manager.list_tasks(status=TaskStatus.RUNNING, subtasks=True)
for t in running_tasks:
if t["task_id"] in fails:
continue
if t["ping_t"] > 0:
ping_elapse = time.time() - t["ping_t"]
if ping_elapse >= self.ping_timeout:
logger.warning(f"Subtask {fmt_subtask(t)} PING timeout: {ping_elapse:.2f} s")
await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"PING timeout: {ping_elapse:.2f} s")
fails.add(t["task_id"])
elapse = time.time() - t["update_t"]
limit = self.subtask_run_timeouts[t["queue"]]
if elapse >= limit:
logger.warning(f"Subtask {fmt_subtask(t)} RUNNING timeout: {elapse:.2f} s")
await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"RUNNING timeout: {elapse:.2f} s")
fails.add(t["task_id"])
def get_avg_worker_infer_cost(self, queue):
if queue not in self.worker_clients:
self.worker_clients[queue] = {}
infer_costs = []
for _, client in self.worker_clients[queue].items():
if client.infer_cost.avg is not None:
infer_costs.append(client.infer_cost.avg)
if len(infer_costs) <= 0:
return self.subtask_run_timeouts[queue]
return sum(infer_costs) / len(infer_costs)
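    # e.g. two workers with rolling averages of 30 s and 60 s give (30 + 60) / 2
    # = 45 s here; with no cost samples yet, the queue's configured run timeout
    # is used as a conservative estimate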
@class_try_catch_async
async def check_user_busy(self, user_id, active_new_task=False):
# check if user visit too frequently
cur_t = time.time()
if user_id in self.user_visits:
elapse = cur_t - self.user_visits[user_id]
if elapse <= self.user_visit_frequency:
return f"User {user_id} visit too frequently, {elapse:.2f} s vs {self.user_visit_frequency:.2f} s"
self.user_visits[user_id] = cur_t
if active_new_task:
# check if user has too many active tasks
active_statuses = [TaskStatus.RUNNING, TaskStatus.PENDING, TaskStatus.CREATED]
active_tasks = await self.task_manager.list_tasks(status=active_statuses, user_id=user_id)
if len(active_tasks) >= self.user_max_active_tasks:
return f"User {user_id} has too many active tasks, {len(active_tasks)} vs {self.user_max_active_tasks}"
# check if user has too many daily tasks
daily_statuses = active_statuses + [TaskStatus.SUCCEED, TaskStatus.CANCEL, TaskStatus.FAILED]
daily_tasks = await self.task_manager.list_tasks(status=daily_statuses, user_id=user_id, start_created_t=cur_t - 86400)
if len(daily_tasks) >= self.user_max_daily_tasks:
return f"User {user_id} has too many daily tasks, {len(daily_tasks)} vs {self.user_max_daily_tasks}"
return True
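    # Caller-side sketch (hypothetical): True means the user may proceed,
    # anything else is a human-readable throttle reason:
    #   busy = await monitor.check_user_busy(user_id, active_new_task=True)
    #   if busy is not True:
    #       raise HTTPException(status_code=429, detail=busy)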
# check if a task can be published to queues
@class_try_catch_async
async def check_queue_busy(self, keys, queues):
wait_time = 0
for queue in queues:
avg_cost = self.get_avg_worker_infer_cost(queue)
worker_cnt = len(self.worker_clients[queue])
subtask_pending = await self.queue_manager.pending_num(queue)
capacity = self.task_timeout * max(worker_cnt, 1) // avg_cost
capacity = max(self.worker_min_capacity, capacity)
if subtask_pending >= capacity:
ss = f"pending={subtask_pending}, capacity={capacity}"
logger.warning(f"Queue {queue} busy, {ss}, task {keys} cannot be publised!")
return None
wait_time += avg_cost * subtask_pending / max(worker_cnt, 1)
return wait_time
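    # Capacity example: task_timeout=600 s, 2 workers, avg_cost=60 s gives
    # capacity = 600 * 2 // 60 = 20; 25 pending subtasks would reject the
    # publish, while 10 pending adds 60 * 10 / 2 = 300 s to the wait estimate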
@class_try_catch_async
async def cal_metrics(self):
data = {}
target_high = self.task_timeout * self.schedule_ratio_high
target_low = self.task_timeout * self.schedule_ratio_low
for queue in self.all_queues:
avg_cost = self.get_avg_worker_infer_cost(queue)
worker_cnt = len(self.worker_clients[queue])
subtask_pending = await self.queue_manager.pending_num(queue)
data[queue] = {
"avg_cost": avg_cost,
"worker_cnt": worker_cnt,
"subtask_pending": subtask_pending,
"max_worker": 0,
"min_worker": 0,
"need_add_worker": 0,
"need_del_worker": 0,
"del_worker_identities": [],
}
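            # bound the load-based estimates: min_worker is floored by the
            # configured worker_min_cnt, and the scale-down threshold max_worker
            # never drops below worker_max_cnt, so idle workers are only flagged
            # for deletion above both bounds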
fix_cnt = subtask_pending // max(self.worker_min_capacity, 1)
min_cnt = min(fix_cnt, subtask_pending * avg_cost // target_high)
max_cnt = min(fix_cnt, subtask_pending * avg_cost // target_low)
data[queue]["min_worker"] = max(self.worker_min_cnt, min_cnt)
data[queue]["max_worker"] = max(self.worker_max_cnt, max_cnt)
if worker_cnt < data[queue]["min_worker"]:
data[queue]["need_add_worker"] = data[queue]["min_worker"] - worker_cnt
if subtask_pending == 0 and worker_cnt > data[queue]["max_worker"]:
data[queue]["need_del_worker"] = worker_cnt - data[queue]["max_worker"]
if data[queue]["need_del_worker"] > 0:
for identity, client in self.worker_clients[queue].items():
if client.status in [WorkerStatus.FETCHING, WorkerStatus.DISCONNECT]:
data[queue]["del_worker_identities"].append(identity)
if len(data[queue]["del_worker_identities"]) >= data[queue]["need_del_worker"]:
break
return data
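# Scaling example: with task_timeout=600, schedule_ratio_high=0.8 and
# schedule_ratio_low=0.2, the targets are 480 s and 120 s; a queue with
# avg_cost=60 s and 40 pending subtasks yields a load-based minimum of
# 40 * 60 // 480 = 5 workers, capped by pending // worker_min_capacity
# and floored by worker_min_cnt.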