Commit 32fd1c52, authored by LiangLiu and committed by GitHub
Browse files

deploy update (#309)



deploy update

---------
Co-authored-by: liuliang1 <liuliang1@sensetime.com>
Co-authored-by: qinxinyi <qinxinyi@sensetime.com>
parent f6e214bb
......@@ -59,8 +59,10 @@
"outputs": ["output_video"]
}
}
},
"SekoTalk-Distill": {
}
},
"s2v": {
"seko_talk": {
"single_stage": {
"pipeline": {
"inputs": ["input_image", "input_audio"],
......
import asyncio
import json
import os
import sys
from alibabacloud_dypnsapi20170525 import models as dypnsapi_models
from alibabacloud_dypnsapi20170525.client import Client
from alibabacloud_tea_openapi import models as openapi_models
from alibabacloud_tea_util import models as util_models
from loguru import logger
class AlibabaCloudClient:
    """Async wrapper around the Alibaba Cloud Dypnsapi SMS verify-code API.

    Credentials are read from the ``ALIBABA_CLOUD_ACCESS_KEY_ID`` and
    ``ALIBABA_CLOUD_ACCESS_KEY_SECRET`` environment variables; an optional
    HTTPS proxy is read from ``auth_https_proxy``.
    """

    def __init__(self):
        config = openapi_models.Config(
            access_key_id=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_ID"),
            access_key_secret=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
            https_proxy=os.getenv("auth_https_proxy", None),
        )
        self.client = Client(config)
        self.runtime = util_models.RuntimeOptions()

    def check_ok(self, res, prefix):
        """Return True only when ``res`` is a successful response map.

        Args:
            res: Response converted to a plain mapping (callers pass ``res.to_map()``).
            prefix: Text prepended to every log line emitted here.

        Returns:
            True when ``statusCode`` is 200 and the body carries ``Code == "OK"``
            and ``Success is True``; False for every other shape, including a
            missing, null or non-dict body.
        """
        logger.info(f"{prefix}: {res}")
        if not isinstance(res, dict) or "statusCode" not in res or res["statusCode"] != 200:
            logger.warning(f"{prefix}: error response: {res}")
            return False
        body = res.get("body")
        # A null or non-dict body used to raise TypeError on the membership
        # test / indexing below; treat it as a malformed body instead.
        if not isinstance(body, dict) or "Code" not in body or "Success" not in body:
            logger.warning(f"{prefix}: error body: {res}")
            return False
        if body["Code"] != "OK" or body["Success"] is not True:
            logger.warning(f"{prefix}: sms error: {res}")
            return False
        return True

    async def send_sms(self, phone_number):
        """Request a verification SMS for ``phone_number``.

        Returns:
            True when the API reports success, False on an error response or
            on any exception (which is logged, not raised).
        """
        try:
            req = dypnsapi_models.SendSmsVerifyCodeRequest(
                phone_number=phone_number,
                sign_name="速通互联验证服务",
                template_code="100001",
                template_param=json.dumps({"code": "##code##", "min": "5"}),
                valid_time=300,
            )
            res = await self.client.send_sms_verify_code_with_options_async(req, self.runtime)
            ok = self.check_ok(res.to_map(), "AlibabaCloudClient send sms")
            logger.info(f"AlibabaCloudClient send sms for {phone_number}: {ok}")
            return ok
        except Exception as e:
            # Best-effort boundary: any failure is reported to the caller as False.
            logger.warning(f"AlibabaCloudClient send sms for {phone_number}: {e}")
            return False

    async def check_sms(self, phone_number, verify_code):
        """Check ``verify_code`` against the code issued to ``phone_number``.

        Returns:
            True when the API reports success, False on an error response or
            on any exception (which is logged, not raised).
        """
        # NOTE(review): the verification code is written to the logs in plain
        # text below; confirm this is acceptable for the deployment.
        try:
            req = dypnsapi_models.CheckSmsVerifyCodeRequest(
                phone_number=phone_number,
                verify_code=verify_code,
            )
            res = await self.client.check_sms_verify_code_with_options_async(req, self.runtime)
            ok = self.check_ok(res.to_map(), "AlibabaCloudClient check sms")
            logger.info(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {ok}")
            return ok
        except Exception as e:
            # Best-effort boundary: any failure is reported to the caller as False.
            logger.warning(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {e}")
            return False
async def test(args):
    """Manual smoke test: one argument sends a code, two arguments verify one."""
    assert len(args) in [1, 2], "Usage: python aliyun_sms.py <phone_number> [verify_code]"
    send_only = len(args) == 1
    phone = args[0]
    sms = AlibabaCloudClient()
    if not send_only:
        await sms.check_sms(phone, args[1])
        return
    await sms.send_sms(phone)


# Script entry point: forward the command-line arguments to the smoke test.
if __name__ == "__main__":
    asyncio.run(test(sys.argv[1:]))
......@@ -143,7 +143,7 @@ class VARecorder:
"-ar",
"44100",
"-b:v",
"4M",
"2M",
"-c:v",
"libx264",
"-preset",
......@@ -199,7 +199,7 @@ class VARecorder:
"-ac",
"2",
"-b:v",
"4M",
"2M",
"-c:v",
"libx264",
"-preset",
......
import io
import json
import os
import torch
from PIL import Image
......@@ -9,7 +10,10 @@ from lightx2v.deploy.common.utils import class_try_catch_async
class BaseDataManager:
def __init__(self):
pass
self.template_images_dir = None
self.template_audios_dir = None
self.template_videos_dir = None
self.template_tasks_dir = None
async def init(self):
    """No-op in the base class; does nothing and returns None."""
......@@ -17,6 +21,12 @@ class BaseDataManager:
async def close(self):
    """No-op in the base class; does nothing and returns None."""
def fmt_path(self, base, filename, abs_path=None):
    """Resolve a storage path.

    A truthy ``abs_path`` is returned unchanged; otherwise the result is
    ``filename`` joined onto ``base``.
    """
    return abs_path if abs_path else os.path.join(base, filename)
def to_device(self, data, device):
if isinstance(data, dict):
return {key: self.to_device(value, device) for key, value in data.items()}
......@@ -27,15 +37,18 @@ class BaseDataManager:
else:
return data
async def save_bytes(self, bytes_data, filename):
async def save_bytes(self, bytes_data, filename, abs_path=None):
raise NotImplementedError
async def load_bytes(self, filename):
async def load_bytes(self, filename, abs_path=None):
raise NotImplementedError
async def delete_bytes(self, filename):
async def delete_bytes(self, filename, abs_path=None):
raise NotImplementedError
async def presign_url(self, filename, abs_path=None):
    """Base implementation: always returns None, ignoring both arguments."""
    return None
async def recurrent_save(self, data, prefix):
if isinstance(data, dict):
return {k: await self.recurrent_save(v, f"{prefix}-{k}") for k, v in data.items()}
......@@ -130,6 +143,60 @@ class BaseDataManager:
}
return maps[type]
def get_template_dir(self, template_type):
    """Return the configured directory for a template type.

    Args:
        template_type: One of "audios", "images", "videos" or "tasks".

    Returns:
        The matching ``template_*_dir`` attribute (which may be None).

    Raises:
        ValueError: If ``template_type`` is not one of the four known types.
    """
    attr_by_type = (
        ("audios", "template_audios_dir"),
        ("images", "template_images_dir"),
        ("videos", "template_videos_dir"),
        ("tasks", "template_tasks_dir"),
    )
    for kind, attr_name in attr_by_type:
        if template_type == kind:
            # Only the matching attribute is read, as in an if/elif chain.
            return getattr(self, attr_name)
    raise ValueError(f"Invalid template type: {template_type}")
@class_try_catch_async
async def list_template_files(self, template_type):
    """List files under the directory for ``template_type``; [] when it is unset."""
    base_dir = self.get_template_dir(template_type)
    return [] if base_dir is None else await self.list_files(base_dir=base_dir)
@class_try_catch_async
async def load_template_file(self, template_type, filename):
    """Load ``filename`` from the template directory; None when the directory is unset."""
    base_dir = self.get_template_dir(template_type)
    if base_dir is not None:
        target = os.path.join(base_dir, filename)
        return await self.load_bytes(None, abs_path=target)
    return None
@class_try_catch_async
async def template_file_exists(self, template_type, filename):
    """Check ``filename`` in the template directory; None when the directory is unset."""
    base_dir = self.get_template_dir(template_type)
    if base_dir is not None:
        target = os.path.join(base_dir, filename)
        return await self.file_exists(None, abs_path=target)
    return None
@class_try_catch_async
async def delete_template_file(self, template_type, filename):
    """Delete ``filename`` from the template directory; None when the directory is unset."""
    base_dir = self.get_template_dir(template_type)
    if base_dir is not None:
        target = os.path.join(base_dir, filename)
        return await self.delete_bytes(None, abs_path=target)
    return None
@class_try_catch_async
async def save_template_file(self, template_type, filename, bytes_data):
    """Save ``bytes_data`` as ``filename`` in the template directory; None when it is unset."""
    base_dir = self.get_template_dir(template_type)
    if base_dir is not None:
        target = os.path.join(base_dir, filename)
        return await self.save_bytes(bytes_data, None, abs_path=target)
    return None
@class_try_catch_async
async def presign_template_url(self, template_type, filename):
    """Presign a URL for ``filename`` in the template directory; None when it is unset."""
    base_dir = self.get_template_dir(template_type)
    if base_dir is not None:
        target = os.path.join(base_dir, filename)
        return await self.presign_url(None, abs_path=target)
    return None
# Import data manager implementations
from .local_data_manager import LocalDataManager # noqa
......
......@@ -8,38 +8,58 @@ from lightx2v.deploy.data_manager import BaseDataManager
class LocalDataManager(BaseDataManager):
def __init__(self, local_dir):
def __init__(self, local_dir, template_dir):
super().__init__()
self.local_dir = local_dir
self.name = "local"
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
if template_dir:
self.template_images_dir = os.path.join(template_dir, "images")
self.template_audios_dir = os.path.join(template_dir, "audios")
self.template_videos_dir = os.path.join(template_dir, "videos")
self.template_tasks_dir = os.path.join(template_dir, "tasks")
assert os.path.exists(self.template_images_dir), f"{self.template_images_dir} not exists!"
assert os.path.exists(self.template_audios_dir), f"{self.template_audios_dir} not exists!"
assert os.path.exists(self.template_videos_dir), f"{self.template_videos_dir} not exists!"
assert os.path.exists(self.template_tasks_dir), f"{self.template_tasks_dir} not exists!"
@class_try_catch_async
async def save_bytes(self, bytes_data, filename):
out_path = os.path.join(self.local_dir, filename)
async def save_bytes(self, bytes_data, filename, abs_path=None):
out_path = self.fmt_path(self.local_dir, filename, abs_path)
with open(out_path, "wb") as fout:
fout.write(bytes_data)
return True
@class_try_catch_async
async def load_bytes(self, filename):
inp_path = os.path.join(self.local_dir, filename)
async def load_bytes(self, filename, abs_path=None):
inp_path = self.fmt_path(self.local_dir, filename, abs_path)
with open(inp_path, "rb") as fin:
return fin.read()
@class_try_catch_async
async def delete_bytes(self, filename):
inp_path = os.path.join(self.local_dir, filename)
async def delete_bytes(self, filename, abs_path=None):
inp_path = self.fmt_path(self.local_dir, filename, abs_path)
os.remove(inp_path)
logger.info(f"deleted local file {filename}")
return True
@class_try_catch_async
async def file_exists(self, filename, abs_path=None):
    """Report whether the resolved local path exists on disk."""
    resolved = self.fmt_path(self.local_dir, filename, abs_path)
    return os.path.exists(resolved)
@class_try_catch_async
async def list_files(self, base_dir=None):
    """Return the entry names of ``base_dir``, or of ``self.local_dir`` when it is falsy."""
    return os.listdir(base_dir or self.local_dir)
async def test():
import torch
from PIL import Image
m = LocalDataManager("/data/nvme1/liuliang1/lightx2v/local_data")
m = LocalDataManager("/data/nvme1/liuliang1/lightx2v/local_data", None)
await m.init()
img = Image.open("/data/nvme1/liuliang1/lightx2v/assets/img_lightx2v.png")
......
......@@ -4,6 +4,7 @@ import json
import os
import aioboto3
import tos
from botocore.client import Config
from loguru import logger
......@@ -12,7 +13,8 @@ from lightx2v.deploy.data_manager import BaseDataManager
class S3DataManager(BaseDataManager):
def __init__(self, config_string, max_retries=3):
def __init__(self, config_string, template_dir, max_retries=3):
super().__init__()
self.name = "s3"
self.config = json.loads(config_string)
self.max_retries = max_retries
......@@ -22,16 +24,36 @@ class S3DataManager(BaseDataManager):
self.endpoint_url = self.config["endpoint_url"]
self.base_path = self.config["base_path"]
self.connect_timeout = self.config.get("connect_timeout", 60)
self.read_timeout = self.config.get("read_timeout", 10)
self.read_timeout = self.config.get("read_timeout", 60)
self.write_timeout = self.config.get("write_timeout", 10)
self.addressing_style = self.config.get("addressing_style", None)
self.region = self.config.get("region", None)
self.session = None
self.s3_client = None
self.presign_client = None
if template_dir:
self.template_images_dir = os.path.join(template_dir, "images")
self.template_audios_dir = os.path.join(template_dir, "audios")
self.template_videos_dir = os.path.join(template_dir, "videos")
self.template_tasks_dir = os.path.join(template_dir, "tasks")
async def init_presign_client(self):
    """Create the TOS presign client when the endpoint is a volces.com host.

    For any other endpoint this does nothing and ``self.presign_client`` is
    left untouched.
    """
    if "volces.com" not in self.endpoint_url:
        return
    # init tos client for volces.com; the endpoint has "tos-s3-" rewritten to "tos-"
    tos_endpoint = self.endpoint_url.replace("tos-s3-", "tos-")
    self.presign_client = tos.TosClientV2(
        self.aws_access_key_id,
        self.aws_secret_access_key,
        tos_endpoint,
        self.region,
    )
async def init(self):
for i in range(self.max_retries):
try:
logger.info(f"S3DataManager init with config: {self.config} (attempt {i + 1}/{self.max_retries}) ...")
s3_config = {"payload_signing_enabled": True}
if self.addressing_style:
s3_config["addressing_style"] = self.addressing_style
self.session = aioboto3.Session()
self.s3_client = await self.session.client(
"s3",
......@@ -40,7 +62,7 @@ class S3DataManager(BaseDataManager):
endpoint_url=self.endpoint_url,
config=Config(
signature_version="s3v4",
s3={"payload_signing_enabled": True},
s3=s3_config,
connect_timeout=self.connect_timeout,
read_timeout=self.read_timeout,
parameter_validation=False,
......@@ -55,6 +77,7 @@ class S3DataManager(BaseDataManager):
logger.info(f"check bucket {self.bucket_name} error: {e}, try to create it...")
await self.s3_client.create_bucket(Bucket=self.bucket_name)
await self.init_presign_client()
logger.info(f"Successfully init S3 bucket: {self.bucket_name} with timeouts - connect: {self.connect_timeout}s, read: {self.read_timeout}s, write: {self.write_timeout}s")
return
except Exception as e:
......@@ -68,8 +91,8 @@ class S3DataManager(BaseDataManager):
self.session = None
@class_try_catch_async
async def save_bytes(self, bytes_data, filename):
filename = os.path.join(self.base_path, filename)
async def save_bytes(self, bytes_data, filename, abs_path=None):
filename = self.fmt_path(self.base_path, filename, abs_path)
content_sha256 = hashlib.sha256(bytes_data).hexdigest()
await self.s3_client.put_object(
Bucket=self.bucket_name,
......@@ -81,35 +104,47 @@ class S3DataManager(BaseDataManager):
return True
@class_try_catch_async
async def load_bytes(self, filename):
filename = os.path.join(self.base_path, filename)
async def load_bytes(self, filename, abs_path=None):
filename = self.fmt_path(self.base_path, filename, abs_path)
response = await self.s3_client.get_object(Bucket=self.bucket_name, Key=filename)
return await response["Body"].read()
@class_try_catch_async
async def delete_bytes(self, filename):
filename = os.path.join(self.base_path, filename)
async def delete_bytes(self, filename, abs_path=None):
filename = self.fmt_path(self.base_path, filename, abs_path)
await self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
logger.info(f"deleted s3 file {filename}")
return True
async def file_exists(self, filename):
filename = os.path.join(self.base_path, filename)
@class_try_catch_async
async def file_exists(self, filename, abs_path=None):
filename = self.fmt_path(self.base_path, filename, abs_path)
try:
await self.s3_client.head_object(Bucket=self.bucket_name, Key=filename)
return True
except Exception:
return False
async def list_files(self, prefix=""):
prefix = os.path.join(self.base_path, prefix)
@class_try_catch_async
async def list_files(self, base_dir=None):
prefix = base_dir if base_dir else self.base_path
response = await self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
files = []
if "Contents" in response:
for obj in response["Contents"]:
files.append(obj["Key"])
files.append(obj["Key"].replace(prefix + "/", ""))
return files
@class_try_catch_async
async def presign_url(self, filename, abs_path=None):
    """Return a presigned GET URL for the object, or None without a presign client.

    The expiry is read from the ``presign_expires`` config entry and defaults
    to 24 * 60 * 60. The presign call is run via ``asyncio.to_thread``.
    """
    object_key = self.fmt_path(self.base_path, filename, abs_path)
    if not self.presign_client:
        return None
    ttl = self.config.get("presign_expires", 24 * 60 * 60)
    signed = await asyncio.to_thread(
        self.presign_client.pre_signed_url,
        tos.HttpMethodType.Http_Method_Get,
        self.bucket_name,
        object_key,
        ttl,
    )
    return signed.signed_url
async def test():
import torch
......@@ -126,7 +161,7 @@ async def test():
"write_timeout": 10,
}
m = S3DataManager(json.dumps(s3_config))
m = S3DataManager(json.dumps(s3_config), None)
await m.init()
img = Image.open("../../../assets/img_lightx2v.png")
......
This diff is collapsed.
......@@ -6,6 +6,8 @@ import jwt
from fastapi import HTTPException
from loguru import logger
from lightx2v.deploy.common.aliyun import AlibabaCloudClient
class AuthManager:
def __init__(self):
......@@ -15,12 +17,24 @@ class AuthManager:
# GitHub OAuth
self.github_client_id = os.getenv("GITHUB_CLIENT_ID", "")
self.github_client_secret = os.getenv("GITHUB_CLIENT_SECRET", "")
# Google OAuth
self.google_client_id = os.getenv("GOOGLE_CLIENT_ID", "")
self.google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", "")
self.google_redirect_uri = os.getenv("GOOGLE_REDIRECT_URI", "")
self.jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256")
self.jwt_expiration_hours = os.getenv("JWT_EXPIRATION_HOURS", 24)
self.jwt_secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production")
# Aliyun SMS
self.aliyun_client = AlibabaCloudClient()
logger.info(f"AuthManager: GITHUB_CLIENT_ID: {self.github_client_id}")
logger.info(f"AuthManager: GITHUB_CLIENT_SECRET: {self.github_client_secret}")
logger.info(f"AuthManager: GOOGLE_CLIENT_ID: {self.google_client_id}")
logger.info(f"AuthManager: GOOGLE_CLIENT_SECRET: {self.google_client_secret}")
logger.info(f"AuthManager: GOOGLE_REDIRECT_URI: {self.google_redirect_uri}")
logger.info(f"AuthManager: JWT_SECRET_KEY: {self.jwt_secret_key}")
logger.info(f"AuthManager: WORKER_SECRET_KEY: {self.worker_secret_key}")
......@@ -81,6 +95,74 @@ class AuthManager:
logger.error(f"Authentication error: {e}")
raise HTTPException(status_code=500, detail="Authentication failed")
async def auth_google(self, code):
    """Exchange a Google OAuth authorization code for the user's profile.

    Posts the code to Google's token endpoint, then fetches the user info
    with the returned access token. Both requests use the proxy from the
    ``auth_https_proxy`` environment variable when it is set.

    Args:
        code: Authorization code received from the Google OAuth redirect.

    Returns:
        Dict with keys ``source`` ("google"), ``id``, ``username``, ``email``,
        ``homepage`` and ``avatar_url``.

    Raises:
        HTTPException: 400 when the token response contains an error or no
            access token; 500 when an HTTP request fails or any other error
            occurs.
    """
    try:
        # NOTE(review): the authorization code is logged in plain text;
        # confirm this is acceptable for the deployment.
        logger.info(f"Google OAuth code: {code}")
        token_url = "https://oauth2.googleapis.com/token"
        token_data = {
            "client_id": self.google_client_id,
            "client_secret": self.google_client_secret,
            "code": code,
            "redirect_uri": self.google_redirect_uri,
            "grant_type": "authorization_code",
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        proxy = os.getenv("auth_https_proxy", None)
        if proxy:
            logger.info(f"auth_google use proxy: {proxy}")
        async with aiohttp.ClientSession() as session:
            async with session.post(token_url, data=token_data, headers=headers, proxy=proxy) as response:
                response.raise_for_status()
                token_info = await response.json()
        if "error" in token_info:
            raise HTTPException(status_code=400, detail=f"Google OAuth error: {token_info['error']}")
        access_token = token_info.get("access_token")
        if not access_token:
            raise HTTPException(status_code=400, detail="Failed to get access token")
        # get user info
        user_url = "https://www.googleapis.com/oauth2/v2/userinfo"
        user_headers = {"Authorization": f"Bearer {access_token}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(user_url, headers=user_headers, proxy=proxy) as response:
                response.raise_for_status()
                user_info = await response.json()
        return {
            "source": "google",
            "id": str(user_info["id"]),
            "username": user_info.get("name", user_info.get("email", "")),
            "email": user_info.get("email", ""),
            "homepage": user_info.get("link", ""),
            "avatar_url": user_info.get("picture", ""),
        }
    except HTTPException:
        # Let the deliberate 400 responses raised above through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except aiohttp.ClientError as e:
        logger.error(f"Google API request failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to authenticate with Google")
    except Exception as e:
        logger.error(f"Google authentication error: {e}")
        raise HTTPException(status_code=500, detail="Google authentication failed")
async def send_sms(self, phone_number):
    """Send a verification SMS through the Aliyun client and return its result."""
    sent = await self.aliyun_client.send_sms(phone_number)
    return sent
async def check_sms(self, phone_number, verify_code):
    """Verify an SMS code and build the user-info dict for a phone login.

    Returns:
        None when the Aliyun client rejects the code; otherwise a dict with
        ``source`` set to "phone", the phone number as both ``id`` and
        ``username``, and empty ``email``, ``homepage`` and ``avatar_url``.
    """
    verified = await self.aliyun_client.check_sms(phone_number, verify_code)
    if verified:
        return {
            "source": "phone",
            "id": phone_number,
            "username": phone_number,
            "email": "",
            "homepage": "",
            "avatar_url": "",
        }
    return None
def verify_jwt_token(self, token):
try:
payload = jwt.decode(token, self.jwt_secret_key, algorithms=[self.jwt_algorithm])
......
......@@ -62,6 +62,7 @@ class WorkerClient:
self.fetched_t = None
if cur_cost < self.infer_timeout:
self.infer_cost.append(max(cur_cost, 1))
logger.info(f"Worker {self.identity} {self.queue} avg infer cost update: {self.infer_cost.avg:.2f} s")
elif status == WorkerStatus.FETCHED:
self.fetched_t = time.time()
......@@ -95,18 +96,19 @@ class WorkerClient:
class ServerMonitor:
def __init__(self, model_pipelines, task_manager, queue_manager, interval=1):
def __init__(self, model_pipelines, task_manager, queue_manager):
self.model_pipelines = model_pipelines
self.task_manager = task_manager
self.queue_manager = queue_manager
self.interval = interval
self.stop = False
self.worker_clients = {}
self.identity_to_queue = {}
self.subtask_run_timeouts = {}
self.pending_subtasks = {}
self.all_queues = self.model_pipelines.get_queues()
self.config = self.model_pipelines.get_monitor_config()
self.interval = self.config.get("monitor_interval", 30)
self.fetching_timeout = self.config.get("fetching_timeout", 1000)
for queue in self.all_queues:
self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 60)
......@@ -115,11 +117,7 @@ class ServerMonitor:
self.worker_avg_window = self.config["worker_avg_window"]
self.worker_offline_timeout = self.config["worker_offline_timeout"]
self.worker_min_capacity = self.config["worker_min_capacity"]
self.worker_min_cnt = self.config["worker_min_cnt"]
self.worker_max_cnt = self.config["worker_max_cnt"]
self.task_timeout = self.config["task_timeout"]
self.schedule_ratio_high = self.config["schedule_ratio_high"]
self.schedule_ratio_low = self.config["schedule_ratio_low"]
self.ping_timeout = self.config["ping_timeout"]
self.user_visits = {} # user_id -> last_visit_t
......@@ -130,19 +128,16 @@ class ServerMonitor:
assert self.worker_avg_window > 0
assert self.worker_offline_timeout > 0
assert self.worker_min_capacity > 0
assert self.worker_min_cnt > 0
assert self.worker_max_cnt > 0
assert self.worker_min_cnt <= self.worker_max_cnt
assert self.task_timeout > 0
assert self.schedule_ratio_high > 0 and self.schedule_ratio_high < 1
assert self.schedule_ratio_low > 0 and self.schedule_ratio_low < 1
assert self.schedule_ratio_high >= self.schedule_ratio_low
assert self.ping_timeout > 0
assert self.user_max_active_tasks > 0
assert self.user_max_daily_tasks > 0
assert self.user_visit_frequency > 0
async def init(self):
    """Run setup for the monitor by awaiting ``init_pending_subtasks``."""
    await self.init_pending_subtasks()
async def loop(self):
while True:
if self.stop:
break
......@@ -164,17 +159,15 @@ class ServerMonitor:
if identity not in self.worker_clients[queue]:
infer_timeout = self.subtask_run_timeouts[queue]
self.worker_clients[queue][identity] = WorkerClient(queue, identity, infer_timeout, self.worker_offline_timeout, self.worker_avg_window, self.ping_timeout)
self.identity_to_queue[identity] = queue
return self.worker_clients[queue][identity]
@class_try_catch_async
async def worker_update(self, queue, identity, status):
if queue is None:
queue = self.identity_to_queue[identity]
worker = self.init_worker(queue, identity)
worker.update(status)
logger.info(f"Worker {identity} {queue} update [{status}]")
@class_try_catch_async
async def clean_workers(self):
qs = list(self.worker_clients.keys())
for queue in qs:
......@@ -182,9 +175,9 @@ class ServerMonitor:
for identity in idens:
if not self.worker_clients[queue][identity].check():
self.worker_clients[queue].pop(identity)
self.identity_to_queue.pop(identity)
logger.warning(f"Worker {queue} {identity} out of contact removed, remain {self.worker_clients[queue]}")
@class_try_catch_async
async def clean_subtasks(self):
created_end_t = time.time() - self.subtask_created_timeout
pending_end_t = time.time() - self.subtask_pending_timeout
......@@ -223,7 +216,8 @@ class ServerMonitor:
await self.task_manager.finish_subtasks(t["task_id"], TaskStatus.FAILED, worker_name=t["worker_name"], fail_msg=f"RUNNING timeout: {elapse:.2f} s")
fails.add(t["task_id"])
def get_avg_worker_infer_cost(self, queue):
@class_try_catch_async
async def get_avg_worker_infer_cost(self, queue):
if queue not in self.worker_clients:
self.worker_clients[queue] = {}
infer_costs = []
......@@ -265,8 +259,8 @@ class ServerMonitor:
wait_time = 0
for queue in queues:
avg_cost = self.get_avg_worker_infer_cost(queue)
worker_cnt = len(self.worker_clients[queue])
avg_cost = await self.get_avg_worker_infer_cost(queue)
worker_cnt = await self.get_ready_worker_count(queue)
subtask_pending = await self.queue_manager.pending_num(queue)
capacity = self.task_timeout * max(worker_cnt, 1) // avg_cost
capacity = max(self.worker_min_capacity, capacity)
......@@ -279,42 +273,97 @@ class ServerMonitor:
return wait_time
@class_try_catch_async
async def cal_metrics(self):
data = {}
target_high = self.task_timeout * self.schedule_ratio_high
target_low = self.task_timeout * self.schedule_ratio_low
async def init_pending_subtasks(self):
# query all pending subtasks in task_manager
subtasks = {}
rows = await self.task_manager.list_tasks(status=TaskStatus.PENDING, subtasks=True, sort_by_update_t=True)
for row in rows:
if row["queue"] not in subtasks:
subtasks[row["queue"]] = []
subtasks[row["queue"]].append(row["task_id"])
for queue in self.all_queues:
avg_cost = self.get_avg_worker_infer_cost(queue)
worker_cnt = len(self.worker_clients[queue])
subtask_pending = await self.queue_manager.pending_num(queue)
if queue not in subtasks:
subtasks[queue] = []
data[queue] = {
"avg_cost": avg_cost,
"worker_cnt": worker_cnt,
"subtask_pending": subtask_pending,
"max_worker": 0,
"min_worker": 0,
"need_add_worker": 0,
"need_del_worker": 0,
"del_worker_identities": [],
}
# self.pending_subtasks = {queue: {"consume_count": int, "max_count": int, subtasks: {task_id: order}}
for queue, task_ids in subtasks.items():
pending_num = await self.queue_manager.pending_num(queue)
self.pending_subtasks[queue] = {"consume_count": 0, "max_count": pending_num, "subtasks": {}}
for i, task_id in enumerate(task_ids):
self.pending_subtasks[queue]["subtasks"][task_id] = max(pending_num - i, 1)
logger.info(f"Init pending subtasks: {self.pending_subtasks}")
fix_cnt = subtask_pending // max(self.worker_min_capacity, 1)
min_cnt = min(fix_cnt, subtask_pending * avg_cost // target_high)
max_cnt = min(fix_cnt, subtask_pending * avg_cost // target_low)
data[queue]["min_worker"] = max(self.worker_min_cnt, min_cnt)
data[queue]["max_worker"] = max(self.worker_max_cnt, max_cnt)
if worker_cnt < data[queue]["min_worker"]:
data[queue]["need_add_worker"] = data[queue]["min_worker"] - worker_cnt
if subtask_pending == 0 and worker_cnt > data[queue]["max_worker"]:
data[queue]["need_del_worker"] = worker_cnt - data[queue]["max_worker"]
if data[queue]["need_del_worker"] > 0:
for identity, client in self.worker_clients[queue].items():
if client.status in [WorkerStatus.FETCHING, WorkerStatus.DISCONNECT]:
data[queue]["del_worker_identities"].append(identity)
if len(data[queue]["del_worker_identities"]) >= data[queue]["need_del_worker"]:
break
return data
@class_try_catch_async
async def pending_subtasks_add(self, queue, task_id):
    """Record ``task_id`` as the newest pending subtask of ``queue``.

    The task is given order ``max_count + 1`` and ``max_count`` is advanced
    to that value. An unknown queue only logs a warning.
    """
    if queue not in self.pending_subtasks:
        logger.warning(f"Queue {queue} not found in self.pending_subtasks")
        return
    bucket = self.pending_subtasks[queue]
    next_order = bucket["max_count"] + 1
    bucket["subtasks"][task_id] = next_order
    bucket["max_count"] = next_order
@class_try_catch_async
async def pending_subtasks_sub(self, queue, task_id):
    """Mark one subtask of ``queue`` as consumed and stop tracking ``task_id``.

    ``consume_count`` is incremented whether or not ``task_id`` is currently
    tracked. An unknown queue only logs a warning.
    """
    if queue not in self.pending_subtasks:
        logger.warning(f"Queue {queue} not found in self.pending_subtasks")
        return
    bucket = self.pending_subtasks[queue]
    bucket["consume_count"] += 1
    bucket["subtasks"].pop(task_id, None)
@class_try_catch_async
async def pending_subtasks_get_order(self, queue, task_id):
    """Return the current position of ``task_id`` in the pending order of ``queue``.

    The position is the stored order minus ``consume_count``, floored at 1.
    Returns None (after a warning) when the queue or the task is not tracked.
    """
    if queue not in self.pending_subtasks:
        logger.warning(f"Queue {queue} not found in self.pending_subtasks")
        return None
    bucket = self.pending_subtasks[queue]
    tracked = bucket["subtasks"]
    if task_id not in tracked:
        logger.warning(f"Task {task_id} not found in self.pending_subtasks[queue]['subtasks']")
        return None
    return max(tracked[task_id] - bucket["consume_count"], 1)
@class_try_catch_async
async def get_ready_worker_count(self, queue):
    """Return how many worker clients are registered for ``queue``.

    An unseen queue is first registered with an empty worker map.
    """
    return len(self.worker_clients.setdefault(queue, {}))
@class_try_catch_async
async def format_subtask(self, subtasks):
    """Build a status summary dict for each subtask.

    PENDING and RUNNING subtasks get the queue's average infer cost and its
    worker count. PENDING subtasks with a known order also get the order and
    an estimated wait of ``estimated_running_secs * wait_cycle``. Elapsed
    times and the failure message are copied from ``extra_info`` when it is
    a dict.

    Returns:
        A list with one summary dict per input subtask, in input order.
    """
    active_states = [TaskStatus.PENDING, TaskStatus.RUNNING]
    formatted = []
    for sub in subtasks:
        item = {
            "status": sub["status"].name,
            "worker_name": sub["worker_name"],
            "fail_msg": None,
            "elapses": {},
            "estimated_pending_order": None,
            "estimated_pending_secs": None,
            "estimated_running_secs": None,
            "ready_worker_count": None,
        }
        if sub["status"] in active_states:
            item["estimated_running_secs"] = await self.get_avg_worker_infer_cost(sub["queue"])
            item["ready_worker_count"] = await self.get_ready_worker_count(sub["queue"])
            if sub["status"] == TaskStatus.PENDING:
                order = await self.pending_subtasks_get_order(sub["queue"], sub["task_id"])
                # The tiny floor keeps the division defined when no worker is registered.
                worker_count = max(item["ready_worker_count"], 1e-7)
                if order is not None:
                    item["estimated_pending_order"] = order
                    wait_cycle = (order - 1) // worker_count + 1
                    item["estimated_pending_secs"] = item["estimated_running_secs"] * wait_cycle
        extra = sub["extra_info"]
        if isinstance(extra, dict):
            if "elapses" in extra:
                item["elapses"] = extra["elapses"]
            if "start_t" in extra:
                # Open-ended entry: time since the current status started.
                item["elapses"][f"{item['status']}-"] = time.time() - extra["start_t"]
            if "fail_msg" in extra:
                item["fail_msg"] = extra["fail_msg"]
        formatted.append(item)
    return formatted
import asyncio
import json
import traceback
from loguru import logger
from redis import asyncio as aioredis
from lightx2v.deploy.common.utils import class_try_catch_async
class RedisClient:
def __init__(self, redis_url, retry_times=3):
self.redis_url = redis_url
self.client = None
self.retry_times = retry_times
self.base_key = "lightx2v"
self.init_scriptss()
def init_scriptss(self):
self.script_create_if_not_exists = """
local key = KEYS[1]
local data_json = ARGV[1]
if redis.call('EXISTS', key) == 0 then
local data = cjson.decode(data_json)
for field, value in pairs(data) do
redis.call('HSET', key, field, value)
end
return 1
else
return 0
end
"""
self.script_increment_and_get = """
local key = KEYS[1]
local field = ARGV[1]
local diff = tonumber(ARGV[2])
local new_value = redis.call('HINCRBY', key, field, diff)
return new_value
"""
self.script_correct_pending_info = """
local key = KEYS[1]
local pending_num = tonumber(ARGV[1])
if redis.call('EXISTS', key) ~= 0 then
local consume_count = redis.call('HGET', key, 'consume_count')
local max_count = redis.call('HGET', key, 'max_count')
local redis_pending = tonumber(max_count) - tonumber(consume_count)
if redis_pending > pending_num then
redis.call('HINCRBY', key, 'consume_count', redis_pending - pending_num)
return 'consume_count added ' .. (redis_pending - pending_num)
end
if redis_pending < pending_num then
redis.call('HINCRBY', key, 'max_count', pending_num - redis_pending)
return 'max_count added ' .. (pending_num - redis_pending)
end
return 'pending equal ' .. pending_num .. ' vs ' .. redis_pending
else
return 'key not exists'
end
"""
self.script_list_push = """
local key = KEYS[1]
local value = ARGV[1]
local limit = tonumber(ARGV[2])
redis.call('LPUSH', key, value)
redis.call('LTRIM', key, 0, limit)
return 1
"""
self.script_list_avg = """
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local values = redis.call('LRANGE', key, 0, limit)
local sum = 0.0
local count = 0.0
for _, value in ipairs(values) do
sum = sum + tonumber(value)
count = count + 1
end
if count == 0 then
return "-1"
end
return tostring(sum / count)
"""
async def init(self):
for i in range(self.retry_times):
try:
self.client = aioredis.Redis.from_url(self.redis_url, protocol=3)
ret = await self.client.ping()
logger.info(f"Redis connection initialized, ping: {ret}")
assert ret, "Redis connection failed"
break
except Exception:
logger.warning(f"Redis connection failed, retry {i + 1}/{self.retry_times}: {traceback.format_exc()}")
await asyncio.sleep(1)
def fmt_key(self, key):
return f"{self.base_key}:{key}"
@class_try_catch_async
async def correct_pending_info(self, key, pending_num):
key = self.fmt_key(key)
script = self.client.register_script(self.script_correct_pending_info)
result = await script(keys=[key], args=[pending_num])
logger.warning(f"Redis correct pending info {key} with {pending_num}: {result}")
return result
@class_try_catch_async
async def create_if_not_exists(self, key, value):
key = self.fmt_key(key)
script = self.client.register_script(self.script_create_if_not_exists)
result = await script(keys=[key], args=[json.dumps(value)])
if result == 1:
logger.info(f"Redis key '{key}' created successfully.")
else:
logger.warning(f"Redis key '{key}' already exists, not set.")
@class_try_catch_async
async def increment_and_get(self, key, field, diff):
key = self.fmt_key(key)
script = self.client.register_script(self.script_increment_and_get)
result = await script(keys=[key], args=[field, diff])
return result
@class_try_catch_async
async def hset(self, key, field, value):
key = self.fmt_key(key)
return await self.client.hset(key, field, value)
@class_try_catch_async
async def hget(self, key, field):
    """Return the value of ``field`` in the hash at the namespaced ``key``."""
    return await self.client.hget(self.fmt_key(key), field)
@class_try_catch_async
async def hgetall(self, key):
    """Return every field/value pair of the hash at the namespaced ``key``."""
    return await self.client.hgetall(self.fmt_key(key))
@class_try_catch_async
async def hdel(self, key, field):
    """Delete ``field`` from the hash at the namespaced ``key``."""
    namespaced = self.fmt_key(key)
    return await self.client.hdel(namespaced, field)
@class_try_catch_async
async def hlen(self, key):
    """Return the number of fields in the hash at the namespaced ``key``."""
    return await self.client.hlen(self.fmt_key(key))
@class_try_catch_async
async def set(self, key, value, nx=False):
    """Set the namespaced ``key`` to ``value`` and return the client's reply.

    ``nx`` is forwarded to the client's ``set``. Any reply other than
    ``True`` is logged as a failure but still returned to the caller.
    """
    key = self.fmt_key(key)
    result = await self.client.set(key, value, nx=nx)
    if result is True:
        return result
    logger.warning(f"redis set {key} = {value} failed")
    return result
@class_try_catch_async
async def get(self, key):
    """Return the value stored at the namespaced ``key``."""
    return await self.client.get(self.fmt_key(key))
@class_try_catch_async
async def delete_key(self, key):
    """Delete the namespaced ``key`` and return the client's reply."""
    namespaced = self.fmt_key(key)
    return await self.client.delete(namespaced)
@class_try_catch_async
async def list_push(self, key, value, limit):
    """Push ``value`` onto the head of the list at the namespaced ``key`` and trim it.

    The Lua script does ``LPUSH`` followed by ``LTRIM key 0 limit`` and
    replies ``1``. The trim range is inclusive, so up to ``limit + 1``
    entries are retained.
    """
    namespaced = self.fmt_key(key)
    pusher = self.client.register_script(self.script_list_push)
    return await pusher(keys=[namespaced], args=[value, limit])
@class_try_catch_async
async def list_avg(self, key, limit):
    """Return the mean of the list at the namespaced ``key`` as a float.

    The Lua script averages ``LRANGE key 0 limit`` (inclusive range) and
    replies with the mean as a string, or ``"-1"`` when the list is empty;
    the reply is converted with ``float``.
    """
    namespaced = self.fmt_key(key)
    averager = self.client.register_script(self.script_list_avg)
    reply = await averager(keys=[namespaced], args=[limit])
    return float(reply)
async def close(self):
    """Close the Redis client if there is one; exceptions are logged, not propagated."""
    try:
        if not self.client:
            return
        await self.client.aclose()
        logger.info("Redis connection closed")
    except Exception:
        logger.warning(f"Error closing Redis connection: {traceback.format_exc()}")
async def main():
    """Manual smoke test that exercises every RedisClient helper against a local Redis.

    The URL below uses placeholder credentials; adjust it before running.
    The connection is closed in a ``finally`` block so it is released even
    if one of the steps raises.
    """
    redis_url = "redis://user:password@localhost:6379/1?decode_responses=True&socket_timeout=5"
    r = RedisClient(redis_url)
    await r.init()
    try:
        # plain key set/get with NX semantics
        v1 = await r.set("test2", "1", nx=True)
        logger.info(f"set test2=1: {v1}, {await r.get('test2')}")
        v2 = await r.set("test2", "2", nx=True)
        logger.info(f"set test2=2: {v2}, {await r.get('test2')}")

        # create-if-not-exists script
        await r.create_if_not_exists("test", {"a": 1, "b": 2})
        logger.info(f"create test: {await r.hgetall('test')}")
        await r.create_if_not_exists("test", {"a": 2, "b": 3})
        logger.info(f"create test again: {await r.hgetall('test')}")
        logger.info(f"hlen test: {await r.hlen('test')}")

        # increment-and-get script
        v = await r.increment_and_get("test", "a", 1)
        logger.info(f"a+1: {v}, a={await r.hget('test', 'a')}")
        v = await r.increment_and_get("test", "b", 3)
        logger.info(f"b+3: {v}, b={await r.hget('test', 'b')}")

        # hash helpers
        await r.hset("test", "a", 233)
        logger.info(f"set a=233: a={await r.hget('test', 'a')}")
        await r.hset("test", "c", 456)
        logger.info(f"set c=456: c={await r.hget('test', 'c')}")
        logger.info(f"all: {await r.hgetall('test')}")
        logger.info(f"hlen test: {await r.hlen('test')}")
        logger.info(f"get unknown key: {await r.hget('test', 'unknown')}")

        # bounded list push / average scripts
        await r.list_push("test_list", 1, 20)
        logger.info(f"list push 1: {await r.list_avg('test_list', 20)}")
        await r.list_push("test_list", 2, 20)
        logger.info(f"list push 2: {await r.list_avg('test_list', 20)}")
        await r.list_push("test_list", 3, 20)
        logger.info(f"list push 3: {await r.list_avg('test_list', 20)}")

        # cleanup
        await r.delete_key("test_list")
        logger.info(f"delete test_list: {await r.list_avg('test_list', 20)}")
        await r.delete_key("test2")
        logger.info(f"delete test2: {await r.get('test2')}")
        await r.hdel("test", "a")
        logger.info(f"hdel test a: {await r.hgetall('test')}")
        await r.delete_key("test")
        logger.info(f"delete test: {await r.hgetall('test')}")
        logger.info(f"hlen test: {await r.hlen('test')}")
    finally:
        await r.close()


if __name__ == "__main__":
    asyncio.run(main())
import asyncio
import json
import time
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus
from lightx2v.deploy.server.redis_client import RedisClient
class RedisServerMonitor(ServerMonitor):
    """ServerMonitor variant that keeps worker and pending-subtask state in Redis.

    Redis keys used (each is further namespaced by the RedisClient):
      * ``workers:{queue}:workers``     hash: worker identity -> JSON with
        ``status``, ``fetched_t`` and ``update_t``
      * ``workers:{queue}:infer_cost``  list of recent inference durations
      * ``pendings:{queue}:info``       holds the ``max_count`` and
        ``consume_count`` counters
      * ``pendings:{queue}:subtasks:{task_id}``  order number given to a
        pending subtask
    """

    def __init__(self, model_pipelines, task_manager, queue_manager, redis_url):
        super().__init__(model_pipelines, task_manager, queue_manager)
        self.redis_url = redis_url
        self.redis_client = RedisClient(redis_url)
        # Time of the last pending-info correction; None forces one on the first loop pass.
        self.last_correct = None
        self.correct_interval = 60 * 60 * 24  # seconds

    async def init(self):
        """Connect to Redis, then seed the pending-subtask state."""
        await self.redis_client.init()
        await self.init_pending_subtasks()

    async def loop(self):
        """Run periodic maintenance until ``self.stop`` becomes truthy.

        Every ``self.interval`` seconds stale workers and subtasks are
        cleaned; the pending-info correction runs on the first pass and then
        at most once per ``self.correct_interval`` seconds.
        """
        while not self.stop:
            if self.last_correct is None or time.time() - self.last_correct > self.correct_interval:
                self.last_correct = time.time()
                await self.correct_pending_info()
            await self.clean_workers()
            await self.clean_subtasks()
            await asyncio.sleep(self.interval)
        logger.info("RedisServerMonitor stopped")

    async def close(self):
        """Close the base monitor first, then the Redis connection."""
        await super().close()
        await self.redis_client.close()

    @class_try_catch_async
    async def worker_update(self, queue, identity, status):
        """Record a status change for worker ``identity`` on ``queue``.

        ``status`` is a ``WorkerStatus`` member; its name is what gets stored.
        On REPORT after a FETCHED, the elapsed time is pushed to the queue's
        inference-cost list (only when below the queue's run timeout).
        """
        status = status.name
        key = f"workers:{queue}:workers"
        infer_key = f"workers:{queue}:infer_cost"
        update_t = time.time()

        raw = await self.redis_client.hget(key, identity)
        if raw is None:
            # Unknown worker: start from an empty record. It is written once,
            # at the end of this method, together with the new status.
            worker = {"status": "", "fetched_t": 0, "update_t": update_t}
        else:
            worker = json.loads(raw)

        pre_fetched_t = float(worker["fetched_t"])
        worker["status"] = status
        worker["update_t"] = update_t

        if status == WorkerStatus.REPORT.name and pre_fetched_t > 0:
            cur_cost = update_t - pre_fetched_t
            worker["fetched_t"] = 0.0
            if cur_cost < self.subtask_run_timeouts[queue]:
                # Costs are floored at 1 before entering the averaging window.
                await self.redis_client.list_push(infer_key, max(cur_cost, 1), self.worker_avg_window)
                logger.info(f"Worker {identity} {queue} avg infer cost update: {cur_cost:.2f} s")
        elif status == WorkerStatus.FETCHED.name:
            worker["fetched_t"] = update_t

        await self.redis_client.hset(key, identity, json.dumps(worker))
        logger.info(f"Worker {identity} {queue} update [{status}]")

    @class_try_catch_async
    async def clean_workers(self):
        """Remove worker records that exceeded the timeout matching their state."""
        for queue in self.all_queues:
            key = f"workers:{queue}:workers"
            workers = await self.redis_client.hgetall(key)
            for identity, worker in workers.items():
                worker = json.loads(worker)
                fetched_t = float(worker["fetched_t"])
                update_t = float(worker["update_t"])
                status = worker["status"]

                # infer too long
                if fetched_t > 0:
                    elapse = time.time() - fetched_t
                    if elapse > self.subtask_run_timeouts[queue]:
                        logger.warning(f"Worker {identity} {queue} infer timeout2: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)
                        continue

                elapse = time.time() - update_t
                # no ping too long
                if status in [WorkerStatus.FETCHED.name, WorkerStatus.PING.name]:
                    if elapse > self.ping_timeout:
                        logger.warning(f"Worker {identity} {queue} ping timeout: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)
                # offline too long
                elif status in [WorkerStatus.DISCONNECT.name, WorkerStatus.REPORT.name]:
                    if elapse > self.worker_offline_timeout:
                        logger.warning(f"Worker {identity} {queue} offline timeout2: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)
                # fetching too long
                elif status == WorkerStatus.FETCHING.name:
                    if elapse > self.fetching_timeout:
                        logger.warning(f"Worker {identity} {queue} fetching timeout: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)

    async def get_ready_worker_count(self, queue):
        """Return the number of worker records currently stored for ``queue``."""
        return await self.redis_client.hlen(f"workers:{queue}:workers")

    async def get_avg_worker_infer_cost(self, queue):
        """Return the average recorded inference cost for ``queue``.

        Falls back to the queue's run timeout when no average is available.
        """
        infer_key = f"workers:{queue}:infer_cost"
        infer_cost = await self.redis_client.list_avg(infer_key, self.worker_avg_window)
        # A negative value means "no samples". None is also treated as
        # "no data": list_avg presumably returns it when its error-handling
        # decorator swallows a failure (verify against class_try_catch_async).
        if infer_cost is None or infer_cost < 0:
            return self.subtask_run_timeouts[queue]
        return infer_cost

    async def correct_pending_info(self):
        """Reconcile each queue's pending info in Redis with the queue manager's pending count."""
        for queue in self.all_queues:
            pending_num = await self.queue_manager.pending_num(queue)
            await self.redis_client.correct_pending_info(f"pendings:{queue}:info", pending_num)

    @class_try_catch_async
    async def init_pending_subtasks(self):
        """Build the pending-subtask state via the base class and move it into Redis.

        Existing Redis entries are kept: the info record is only created when
        absent and subtask orders are written with NX.
        """
        await super().init_pending_subtasks()
        # save to redis if not exists
        for queue, v in self.pending_subtasks.items():
            subtasks = v.pop("subtasks", {})
            await self.redis_client.create_if_not_exists(f"pendings:{queue}:info", v)
            for task_id, order_id in subtasks.items():
                await self.redis_client.set(f"pendings:{queue}:subtasks:{task_id}", order_id, nx=True)
        # The in-memory copy is dropped once it has been written to Redis.
        self.pending_subtasks = None
        logger.info("Inited pending subtasks to redis")

    @class_try_catch_async
    async def pending_subtasks_add(self, queue, task_id):
        """Bump the queue's ``max_count`` and store the result as the order of ``task_id``."""
        max_count = await self.redis_client.increment_and_get(f"pendings:{queue}:info", "max_count", 1)
        await self.redis_client.set(f"pendings:{queue}:subtasks:{task_id}", max_count)

    @class_try_catch_async
    async def pending_subtasks_sub(self, queue, task_id):
        """Bump the queue's ``consume_count`` and delete the order entry of ``task_id``."""
        await self.redis_client.increment_and_get(f"pendings:{queue}:info", "consume_count", 1)
        await self.redis_client.delete_key(f"pendings:{queue}:subtasks:{task_id}")

    @class_try_catch_async
    async def pending_subtasks_get_order(self, queue, task_id):
        """Return the current queue position of ``task_id``, or None if unknown.

        The position is the stored order minus the queue's ``consume_count``,
        floored at 1. None is returned when either value is missing.
        """
        order = await self.redis_client.get(f"pendings:{queue}:subtasks:{task_id}")
        if order is None:
            return None
        consume = await self.redis_client.hget(f"pendings:{queue}:info", "consume_count")
        if consume is None:
            return None
        return max(int(order) - int(consume), 1)
<svg width="240" height="120" viewBox="0 0 240 120" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M104.224 78.728C101.52 78.728 99.1278 78.26 97.0478 77.324C94.9678 76.3533 93.3212 74.932 92.1078 73.06C90.9292 71.1533 90.3398 68.8133 90.3398 66.04V64.844H95.2798V66.04C95.2798 68.848 96.0945 70.9453 97.7238 72.332C99.3878 73.684 101.555 74.36 104.224 74.36C106.928 74.36 108.991 73.7706 110.412 72.592C111.833 71.3786 112.544 69.8533 112.544 68.016C112.544 66.768 112.215 65.78 111.556 65.052C110.897 64.2893 109.979 63.6653 108.8 63.1799C107.621 62.6946 106.217 62.2613 104.588 61.88L102.3 61.308C100.081 60.788 98.1398 60.1293 96.4758 59.3319C94.8118 58.5 93.5118 57.4253 92.5758 56.1079C91.6398 54.7906 91.1718 53.0746 91.1718 50.9599C91.1718 48.8799 91.6918 47.0946 92.7318 45.6039C93.7718 44.0786 95.2105 42.9173 97.0478 42.1199C98.9198 41.2879 101.104 40.8719 103.6 40.8719C106.096 40.8719 108.332 41.3053 110.308 42.172C112.284 43.0039 113.844 44.2693 114.988 45.9679C116.132 47.6319 116.704 49.7293 116.704 52.2599V54.2879H111.764V52.2599C111.764 50.5959 111.417 49.2613 110.724 48.2559C110.031 47.2159 109.06 46.4533 107.812 45.9679C106.599 45.4826 105.195 45.2399 103.6 45.2399C101.277 45.2399 99.4398 45.7426 98.0878 46.748C96.7705 47.7186 96.1118 49.1053 96.1118 50.9079C96.1118 52.0866 96.4065 53.0573 96.9958 53.8199C97.5852 54.5479 98.4172 55.1546 99.4918 55.6399C100.601 56.1253 101.936 56.5413 103.496 56.8879L105.784 57.4599C108.037 57.9453 110.031 58.604 111.764 59.4359C113.532 60.2333 114.919 61.308 115.924 62.66C116.964 64.0119 117.484 65.7626 117.484 67.912C117.484 70.096 116.929 72.0026 115.82 73.632C114.745 75.2266 113.203 76.4746 111.192 77.3759C109.216 78.2773 106.893 78.728 104.224 78.728Z" fill="black"/>
<path d="M134.037 78.728C131.472 78.728 129.218 78.1906 127.277 77.116C125.336 76.0066 123.828 74.464 122.753 72.4879C121.678 70.5119 121.141 68.2066 121.141 65.572V64.948C121.141 62.2786 121.678 59.9559 122.753 57.9799C123.828 56.0039 125.318 54.4786 127.225 53.4039C129.132 52.2946 131.333 51.74 133.829 51.74C136.256 51.74 138.388 52.2773 140.225 53.3519C142.097 54.3919 143.553 55.8826 144.593 57.8239C145.633 59.7306 146.153 61.9839 146.153 64.584V66.5079H125.925C125.994 69.004 126.809 70.9799 128.369 72.4359C129.929 73.8573 131.853 74.568 134.141 74.568C136.256 74.568 137.868 74.1 138.977 73.164C140.086 72.1933 140.918 71.0666 141.473 69.784L145.477 71.76C144.992 72.8 144.281 73.8573 143.345 74.9319C142.444 76.0066 141.248 76.908 139.757 77.636C138.266 78.364 136.36 78.728 134.037 78.728ZM125.977 62.764H141.317C141.178 60.6146 140.433 58.9506 139.081 57.7719C137.729 56.5586 135.961 55.9519 133.777 55.9519C131.628 55.9519 129.86 56.5586 128.473 57.7719C127.121 58.9506 126.289 60.6146 125.977 62.764Z" fill="black"/>
<path d="M151.399 78V41.5999H156.131V62.712H156.859L167.727 52.4679H174.175L160.655 64.896L174.591 78H168.247L156.859 67.132H156.131V78H151.399Z" fill="black"/>
<path d="M188.559 78.728C185.993 78.728 183.722 78.1906 181.746 77.116C179.771 76.0413 178.228 74.5333 177.118 72.592C176.009 70.616 175.454 68.2933 175.454 65.624V64.896C175.454 62.1919 176.009 59.8693 177.118 57.9279C178.228 55.9519 179.771 54.4266 181.746 53.3519C183.722 52.2773 185.993 51.74 188.559 51.74C191.124 51.74 193.395 52.2773 195.37 53.3519C197.346 54.4266 198.889 55.9519 199.998 57.9279C201.143 59.8693 201.715 62.1919 201.715 64.896V65.624C201.715 68.2933 201.143 70.616 199.998 72.592C198.889 74.5333 197.346 76.0413 195.37 77.116C193.395 78.1906 191.124 78.728 188.559 78.728ZM188.559 74.5159C191.089 74.5159 193.117 73.7186 194.643 72.124C196.202 70.4946 196.982 68.276 196.982 65.468V65C196.982 62.2266 196.202 60.0426 194.643 58.448C193.117 56.8186 191.089 56.004 188.559 56.004C186.062 56.004 184.035 56.8186 182.475 58.448C180.949 60.0426 180.187 62.2266 180.187 65V65.468C180.187 68.276 180.949 70.4946 182.475 72.124C184.035 73.7186 186.062 74.5159 188.559 74.5159Z" fill="black"/>
<path d="M65.0876 62.7798L54.4894 67.4984L51.9751 68.6177L39.4047 74.2143C39.1896 74.3101 39.0411 74.5091 39.0073 74.7376L39 74.8374V89.3181C39.0002 89.8114 39.5079 90.141 39.9586 89.9405L65.6421 78.5054L65.73 78.4588C65.9252 78.335 66.0469 78.1186 66.0469 77.8829V74.8829V72.6442V63.4022C66.0468 62.9089 65.5384 62.5791 65.0876 62.7798Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M59.2473 53.6442L56.7331 52.5249L45.2273 47.4024L43.1823 46.4919L39.9586 45.0567C39.5079 44.8562 39.0001 45.1858 39 45.6791V60.1598C39 60.4291 39.1588 60.6727 39.4047 60.7822L51.9755 66.3792L64.2559 60.9116C66.0592 60.1087 68.0915 61.4286 68.0915 63.4025V73.5547L71.3151 74.99C71.7658 75.1905 72.2736 74.8608 72.2737 74.3676V59.8869C72.2737 59.6177 72.1156 59.3733 71.8696 59.2638L59.2473 53.6442ZM43.1823 56.5981L43.1883 56.7814C43.1903 56.8114 43.1926 56.8412 43.1955 56.8708C43.1984 56.9008 43.2017 56.9305 43.2056 56.9602C43.1903 56.8423 43.1823 56.7215 43.1823 56.5981Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M72.2734 30.6824V45.1631C72.2734 45.3987 72.1518 45.6152 71.9566 45.7389L71.8687 45.7855L59.2463 51.4055L45.2266 45.1633V42.1176L45.2339 42.0177C45.2676 41.7892 45.4162 41.5903 45.6313 41.4945L71.3142 30.06C71.7649 29.8593 72.2733 30.189 72.2734 30.6824Z" fill="black"/>
</svg>
<svg width="240" height="120" viewBox="0 0 240 120" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M104.224 78.728C101.52 78.728 99.1278 78.26 97.0478 77.324C94.9678 76.3533 93.3212 74.932 92.1078 73.06C90.9292 71.1533 90.3398 68.8133 90.3398 66.04V64.844H95.2798V66.04C95.2798 68.848 96.0945 70.9453 97.7238 72.332C99.3878 73.684 101.555 74.36 104.224 74.36C106.928 74.36 108.991 73.7706 110.412 72.592C111.833 71.3786 112.544 69.8533 112.544 68.016C112.544 66.768 112.215 65.78 111.556 65.052C110.897 64.2893 109.979 63.6653 108.8 63.1799C107.621 62.6946 106.217 62.2613 104.588 61.88L102.3 61.308C100.081 60.788 98.1398 60.1293 96.4758 59.3319C94.8118 58.4999 93.5118 57.4253 92.5758 56.1079C91.6398 54.7906 91.1718 53.0746 91.1718 50.9599C91.1718 48.8799 91.6918 47.0946 92.7318 45.6039C93.7718 44.0786 95.2105 42.9173 97.0478 42.1199C98.9198 41.2879 101.104 40.8719 103.6 40.8719C106.096 40.8719 108.332 41.3053 110.308 42.172C112.284 43.0039 113.844 44.2693 114.988 45.9679C116.132 47.6319 116.704 49.7293 116.704 52.2599V54.2879H111.764V52.2599C111.764 50.5959 111.417 49.2613 110.724 48.2559C110.031 47.2159 109.06 46.4533 107.812 45.9679C106.599 45.4826 105.195 45.2399 103.6 45.2399C101.277 45.2399 99.4398 45.7426 98.0878 46.748C96.7705 47.7186 96.1118 49.1053 96.1118 50.9079C96.1118 52.0866 96.4065 53.0573 96.9958 53.8199C97.5852 54.5479 98.4172 55.1546 99.4918 55.6399C100.601 56.1253 101.936 56.5413 103.496 56.8879L105.784 57.4599C108.037 57.9453 110.031 58.6039 111.764 59.4359C113.532 60.2333 114.919 61.308 115.924 62.66C116.964 64.012 117.484 65.7626 117.484 67.912C117.484 70.096 116.929 72.0026 115.82 73.632C114.745 75.2266 113.203 76.4746 111.192 77.3759C109.216 78.2773 106.893 78.728 104.224 78.728Z" fill="white"/>
<path d="M134.037 78.728C131.472 78.728 129.218 78.1906 127.277 77.116C125.336 76.0066 123.828 74.4639 122.753 72.4879C121.678 70.5119 121.141 68.2066 121.141 65.572V64.948C121.141 62.2786 121.678 59.956 122.753 57.9799C123.828 56.0039 125.318 54.4786 127.225 53.4039C129.132 52.2946 131.333 51.74 133.829 51.74C136.256 51.74 138.388 52.2773 140.225 53.3519C142.097 54.3919 143.553 55.8826 144.593 57.8239C145.633 59.7306 146.153 61.984 146.153 64.584V66.5079H125.925C125.994 69.004 126.809 70.9799 128.369 72.4359C129.929 73.8573 131.853 74.568 134.141 74.568C136.256 74.568 137.868 74.1 138.977 73.164C140.086 72.1933 140.918 71.0666 141.473 69.784L145.477 71.76C144.992 72.8 144.281 73.8573 143.345 74.9319C142.444 76.0066 141.248 76.908 139.757 77.636C138.266 78.364 136.36 78.728 134.037 78.728ZM125.977 62.764H141.317C141.178 60.6146 140.433 58.9506 139.081 57.7719C137.729 56.5586 135.961 55.9519 133.777 55.9519C131.628 55.9519 129.86 56.5586 128.473 57.7719C127.121 58.9506 126.289 60.6146 125.977 62.764Z" fill="white"/>
<path d="M151.399 78V41.5999H156.131V62.712H156.859L167.727 52.4679H174.175L160.655 64.896L174.591 78H168.247L156.859 67.132H156.131V78H151.399Z" fill="white"/>
<path d="M188.559 78.728C185.993 78.728 183.722 78.1906 181.746 77.116C179.77 76.0413 178.228 74.5333 177.118 72.592C176.009 70.616 175.454 68.2933 175.454 65.624V64.896C175.454 62.192 176.009 59.8693 177.118 57.9279C178.228 55.9519 179.77 54.4266 181.746 53.3519C183.722 52.2773 185.993 51.74 188.559 51.74C191.124 51.74 193.394 52.2773 195.37 53.3519C197.346 54.4266 198.889 55.9519 199.998 57.9279C201.142 59.8693 201.715 62.192 201.715 64.896V65.624C201.715 68.2933 201.142 70.616 199.998 72.592C198.889 74.5333 197.346 76.0413 195.37 77.116C193.394 78.1906 191.124 78.728 188.559 78.728ZM188.559 74.5159C191.089 74.5159 193.117 73.7186 194.643 72.124C196.203 70.4946 196.982 68.276 196.982 65.468V65C196.982 62.2266 196.203 60.0426 194.643 58.448C193.117 56.8186 191.089 56.004 188.559 56.004C186.063 56.004 184.035 56.8186 182.475 58.448C180.949 60.0426 180.187 62.2266 180.187 65V65.468C180.187 68.276 180.949 70.4946 182.475 72.124C184.035 73.7186 186.063 74.5159 188.559 74.5159Z" fill="white"/>
<path d="M65.0876 62.7798L54.4894 67.4984L51.9751 68.6177L39.4047 74.2143C39.1896 74.3101 39.0411 74.5091 39.0073 74.7376L39 74.8374V89.3181C39.0002 89.8114 39.5079 90.141 39.9586 89.9405L65.6421 78.5054L65.73 78.4588C65.9252 78.335 66.0469 78.1186 66.0469 77.8829V74.8829V72.6442V63.4022C66.0468 62.9089 65.5384 62.5791 65.0876 62.7798Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M59.2473 53.6442L56.7331 52.5249L45.2273 47.4024L43.1823 46.4919L39.9586 45.0567C39.5079 44.8562 39.0001 45.1858 39 45.6791V60.1598C39 60.4291 39.1588 60.6727 39.4047 60.7822L51.9755 66.3792L64.2559 60.9116C66.0592 60.1087 68.0915 61.4286 68.0915 63.4025V73.5547L71.3151 74.99C71.7658 75.1905 72.2736 74.8608 72.2737 74.3676V59.8869C72.2737 59.6177 72.1156 59.3733 71.8696 59.2638L59.2473 53.6442ZM43.1823 56.5981L43.1883 56.7814C43.1903 56.8114 43.1926 56.8412 43.1955 56.8708C43.1984 56.9008 43.2017 56.9305 43.2056 56.9602C43.1903 56.8423 43.1823 56.7215 43.1823 56.5981Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M72.2734 30.6824V45.1631C72.2734 45.3987 72.1518 45.6152 71.9566 45.7389L71.8687 45.7855L59.2463 51.4055L45.2266 45.1633V42.1176L45.2339 42.0177C45.2676 41.7892 45.4162 41.5903 45.6313 41.4945L71.3142 30.06C71.7649 29.8593 72.2733 30.189 72.2734 30.6824Z" fill="white"/>
</svg>
This diff is collapsed.
......@@ -4,7 +4,7 @@ from re import T
from loguru import logger
from lightx2v.deploy.common.utils import current_time, data_name
from lightx2v.deploy.common.utils import class_try_catch, current_time, data_name
class TaskStatus(Enum):
......@@ -63,6 +63,9 @@ class BaseTaskManager:
async def resume_task(self, task_id, all_subtask=False, user_id=None):
    """Abstract hook: always raises ``NotImplementedError`` in the base manager."""
    raise NotImplementedError()
async def delete_task(self, task_id, user_id=None):
    """Abstract hook: always raises ``NotImplementedError`` in the base manager."""
    raise NotImplementedError()
def fmt_dict(self, data):
for k in ["status"]:
if k in data:
......@@ -74,7 +77,7 @@ class BaseTaskManager:
data[k] = TaskStatus[data[k]]
async def create_user(self, user_info):
assert user_info["source"] == "github", f"do not support {user_info['source']} user!"
assert user_info["source"] in ["github", "google", "phone"], f"do not support {user_info['source']} user!"
cur_t = current_time()
user_id = f"{user_info['source']}_{user_info['id']}"
data = {
......@@ -112,7 +115,8 @@ class BaseTaskManager:
"outputs": {x: data_name(x, task_id) for x in outputs},
"user_id": user_id,
}
self.mark_task_start(task)
records = []
self.mark_task_start(records, task)
subtasks = []
for worker_name, worker_item in workers.items():
subtasks.append(
......@@ -134,25 +138,23 @@ class BaseTaskManager:
"infer_cost": -1.0,
}
)
self.mark_subtask_change(subtasks[-1], None, TaskStatus.CREATED)
self.mark_subtask_change(records, subtasks[-1], None, TaskStatus.CREATED)
ret = await self.insert_task(task, subtasks)
# if insert error
if not ret:
self.mark_task_end(task, TaskStatus.FAILED)
for sub in subtasks:
self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED)
assert ret, f"create task {task_id} failed"
self.metrics_commit(records)
return task_id
async def mark_server_restart(self):
if self.metrics_monitor:
tasks = await self.list_tasks(status=ActiveStatus)
subtasks = await self.list_tasks(status=ActiveStatus, subtasks=True)
logger.warning(f"Mark system restart, {len(tasks)} tasks, {len(subtasks)} subtasks")
self.metrics_monitor.record_task_recover(tasks)
self.metrics_monitor.record_subtask_recover(subtasks)
def mark_task_start(self, task):
pass
# only for start server with active tasks
# if self.metrics_monitor:
# tasks = await self.list_tasks(status=ActiveStatus)
# subtasks = await self.list_tasks(status=ActiveStatus, subtasks=True)
# logger.warning(f"Mark system restart, {len(tasks)} tasks, {len(subtasks)} subtasks")
# self.metrics_monitor.record_task_recover(tasks)
# self.metrics_monitor.record_subtask_recover(subtasks)
def mark_task_start(self, records, task):
t = current_time()
if not isinstance(task["extra_info"], dict):
task["extra_info"] = {}
......@@ -161,9 +163,14 @@ class BaseTaskManager:
task["extra_info"]["start_t"] = t
logger.info(f"Task {task['task_id']} active start")
if self.metrics_monitor:
self.metrics_monitor.record_task_start(task)
records.append(
[
self.metrics_monitor.record_task_start,
[task],
]
)
def mark_task_end(self, task, end_status):
def mark_task_end(self, records, task, end_status):
if "start_t" not in task["extra_info"]:
logger.warning(f"Task {task} has no start time")
else:
......@@ -173,9 +180,14 @@ class BaseTaskManager:
logger.info(f"Task {task['task_id']} active end with [{end_status}], elapse: {elapse}")
if self.metrics_monitor:
self.metrics_monitor.record_task_end(task, end_status, elapse)
def mark_subtask_change(self, subtask, old_status, new_status, fail_msg=None):
records.append(
[
self.metrics_monitor.record_task_end,
[task, end_status, elapse],
]
)
def mark_subtask_change(self, records, subtask, old_status, new_status, fail_msg=None):
t = current_time()
if not isinstance(subtask["extra_info"], dict):
subtask["extra_info"] = {}
......@@ -211,7 +223,17 @@ class BaseTaskManager:
)
if self.metrics_monitor:
self.metrics_monitor.record_subtask_change(subtask, old_status, new_status, elapse_key, elapse)
records.append(
[
self.metrics_monitor.record_subtask_change,
[subtask, old_status, new_status, elapse_key, elapse],
]
)
@class_try_catch
def metrics_commit(self, records):
    """Replay queued metric calls in order.

    ``records`` is an iterable of ``[callable, args]`` pairs; each callable
    is invoked with its positional ``args``.
    """
    for record_fn, record_args in records:
        record_fn(*record_args)
# Import task manager implementations
......
......@@ -93,10 +93,18 @@ class LocalTaskManager(BaseTaskManager):
continue
if "end_ping_t" in kwargs and kwargs["end_ping_t"] < task["ping_t"]:
continue
# 如果不是查询子任务,则添加子任务信息到任务中
if not kwargs.get("subtasks", False):
task["subtasks"] = info.get("subtasks", [])
tasks.append(task)
if "count" in kwargs:
return len(tasks)
tasks = sorted(tasks, key=lambda x: x["create_t"], reverse=True)
sort_key = "update_t" if kwargs.get("sort_by_update_t", False) else "create_t"
tasks = sorted(tasks, key=lambda x: x[sort_key], reverse=True)
if "offset" in kwargs:
tasks = tasks[kwargs["offset"] :]
if "limit" in kwargs:
......@@ -109,6 +117,7 @@ class LocalTaskManager(BaseTaskManager):
@class_try_catch_async
async def next_subtasks(self, task_id):
records = []
task, subtasks = self.load(task_id)
if task["status"] not in ActiveStatus:
return []
......@@ -125,7 +134,7 @@ class LocalTaskManager(BaseTaskManager):
dep_ok = False
break
if dep_ok:
self.mark_subtask_change(sub, sub["status"], TaskStatus.PENDING)
self.mark_subtask_change(records, sub, sub["status"], TaskStatus.PENDING)
sub["params"] = task["params"]
sub["status"] = TaskStatus.PENDING
sub["update_t"] = current_time()
......@@ -134,10 +143,12 @@ class LocalTaskManager(BaseTaskManager):
task["status"] = TaskStatus.PENDING
task["update_t"] = current_time()
self.save(task, subtasks)
self.metrics_commit(records)
return nexts
@class_try_catch_async
async def run_subtasks(self, cands, worker_identity):
records = []
valids = []
for cand in cands:
task_id = cand["task_id"]
......@@ -147,7 +158,7 @@ class LocalTaskManager(BaseTaskManager):
continue
for sub in subtasks:
if sub["worker_name"] == worker_name:
self.mark_subtask_change(sub, sub["status"], TaskStatus.RUNNING)
self.mark_subtask_change(records, sub, sub["status"], TaskStatus.RUNNING)
sub["status"] = TaskStatus.RUNNING
sub["worker_identity"] = worker_identity
sub["update_t"] = current_time()
......@@ -157,6 +168,7 @@ class LocalTaskManager(BaseTaskManager):
self.save(task, subtasks)
valids.append(cand)
break
self.metrics_commit(records)
return valids
@class_try_catch_async
......@@ -173,6 +185,7 @@ class LocalTaskManager(BaseTaskManager):
@class_try_catch_async
async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False):
records = []
task, subtasks = self.load(task_id)
subs = subtasks
......@@ -190,12 +203,13 @@ class LocalTaskManager(BaseTaskManager):
if should_running and sub["status"] != TaskStatus.RUNNING:
print(f"task {task_id} is not running, skip finish subtask: {sub}")
continue
self.mark_subtask_change(sub, sub["status"], status, fail_msg=fail_msg)
self.mark_subtask_change(records, sub, sub["status"], status, fail_msg=fail_msg)
sub["status"] = status
sub["update_t"] = current_time()
if task["status"] == TaskStatus.CANCEL:
self.save(task, subtasks)
self.metrics_commit(records)
return TaskStatus.CANCEL
running_subs = []
......@@ -209,47 +223,53 @@ class LocalTaskManager(BaseTaskManager):
# some subtask failed, we should fail all other subtasks
if failed_sub:
if task["status"] != TaskStatus.FAILED:
self.mark_task_end(task, TaskStatus.FAILED)
self.mark_task_end(records, task, TaskStatus.FAILED)
task["status"] = TaskStatus.FAILED
task["update_t"] = current_time()
for sub in running_subs:
self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed")
self.mark_subtask_change(records, sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed")
sub["status"] = TaskStatus.FAILED
sub["update_t"] = current_time()
self.save(task, subtasks)
self.metrics_commit(records)
return TaskStatus.FAILED
# all subtasks finished and all succeed
elif len(running_subs) == 0:
if task["status"] != TaskStatus.SUCCEED:
self.mark_task_end(task, TaskStatus.SUCCEED)
self.mark_task_end(records, task, TaskStatus.SUCCEED)
task["status"] = TaskStatus.SUCCEED
task["update_t"] = current_time()
self.save(task, subtasks)
self.metrics_commit(records)
return TaskStatus.SUCCEED
self.save(task, subtasks)
self.metrics_commit(records)
return None
@class_try_catch_async
async def cancel_task(self, task_id, user_id=None):
records = []
task, subtasks = self.load(task_id, user_id)
if task["status"] not in ActiveStatus:
return f"Task {task_id} is not in active status (current status: {task['status']}). Only tasks with status CREATED, PENDING, or RUNNING can be cancelled."
for sub in subtasks:
if sub["status"] not in FinishedStatus:
self.mark_subtask_change(sub, sub["status"], TaskStatus.CANCEL)
self.mark_subtask_change(records, sub, sub["status"], TaskStatus.CANCEL)
sub["status"] = TaskStatus.CANCEL
sub["update_t"] = current_time()
self.mark_task_end(task, TaskStatus.CANCEL)
self.mark_task_end(records, task, TaskStatus.CANCEL)
task["status"] = TaskStatus.CANCEL
task["update_t"] = current_time()
self.save(task, subtasks)
self.metrics_commit(records)
return True
@class_try_catch_async
async def resume_task(self, task_id, all_subtask=False, user_id=None):
records = []
task, subtasks = self.load(task_id, user_id)
# the task is not finished
if task["status"] not in FinishedStatus:
......@@ -259,14 +279,27 @@ class LocalTaskManager(BaseTaskManager):
return False
for sub in subtasks:
if all_subtask or sub["status"] != TaskStatus.SUCCEED:
self.mark_subtask_change(sub, None, TaskStatus.CREATED)
self.mark_subtask_change(records, sub, None, TaskStatus.CREATED)
sub["status"] = TaskStatus.CREATED
sub["update_t"] = current_time()
sub["ping_t"] = 0.0
self.mark_task_start(task)
self.mark_task_start(records, task)
task["status"] = TaskStatus.CREATED
task["update_t"] = current_time()
self.save(task, subtasks)
self.metrics_commit(records)
return True
@class_try_catch_async
async def delete_task(self, task_id, user_id=None):
    """Delete the stored file of a finished task.

    Returns False, touching nothing, when the task's status is not in
    ``FinishedStatus``. Otherwise removes the task file if it exists and
    returns True.
    """
    task = self.load(task_id, user_id, only_task=True)
    # only finished tasks may be deleted
    deletable = task["status"] in FinishedStatus
    if deletable:
        task_file = self.get_task_filename(task_id)
        if os.path.exists(task_file):
            os.remove(task_file)
    return deletable
@class_try_catch_async
......
......@@ -9,8 +9,6 @@ from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.task_manager import ActiveStatus, BaseTaskManager, FinishedStatus, TaskStatus
ASYNC_LOCK = asyncio.Lock()
class PostgresSQLTaskManager(BaseTaskManager):
def __init__(self, db_url, metrics_monitor=None):
......@@ -198,6 +196,15 @@ class PostgresSQLTaskManager(BaseTaskManager):
subtasks.append(sub)
return task, subtasks
def check_update_valid(self, ret, prefix, query, params):
    """Parse an ``UPDATE <n>`` status string and require that rows changed.

    Returns the parsed row count. An ``AssertionError`` mentioning
    ``prefix``, ``query`` and ``params`` is raised when the count is 0.
    A status that does not start with ``"UPDATE "`` is logged and reported
    as 0 rows.
    """
    if not ret.startswith("UPDATE "):
        logger.warning(f"parse postsql update ret failed: {ret}")
        return 0
    count = int(ret.split(" ")[1])
    assert count > 0, f"{prefix}: no row changed: {query} {params}"
    return count
async def update_task(self, conn, task_id, **kwargs):
query = f"UPDATE {self.table_tasks} SET "
conds = ["update_t = $1"]
......@@ -211,10 +218,19 @@ class PostgresSQLTaskManager(BaseTaskManager):
param_idx += 1
conds.append(f"extra_info = ${param_idx}")
params.append(json.dumps(kwargs["extra_info"], ensure_ascii=False))
query += " ,".join(conds)
query += f" WHERE task_id = ${param_idx + 1}"
limit_conds = [f"task_id = ${param_idx + 1}"]
param_idx += 1
params.append(task_id)
await conn.execute(query, *params)
if "src_status" in kwargs:
param_idx += 1
limit_conds.append(f"status = ${param_idx}")
params.append(kwargs["src_status"].name)
query += " ,".join(conds) + " WHERE " + " AND ".join(limit_conds)
ret = await conn.execute(query, *params)
return self.check_update_valid(ret, "update_task", query, params)
async def update_subtask(self, conn, task_id, worker_name, **kwargs):
query = f"UPDATE {self.table_subtasks} SET "
......@@ -249,10 +265,19 @@ class PostgresSQLTaskManager(BaseTaskManager):
param_idx += 1
conds.append(f"extra_info = ${param_idx}")
params.append(json.dumps(kwargs["extra_info"], ensure_ascii=False))
query += " ,".join(conds)
query += f" WHERE task_id = ${param_idx + 1} AND worker_name = ${param_idx + 2}"
limit_conds = [f"task_id = ${param_idx + 1}", f"worker_name = ${param_idx + 2}"]
param_idx += 2
params.extend([task_id, worker_name])
await conn.execute(query, *params)
if "src_status" in kwargs:
param_idx += 1
limit_conds.append(f"status = ${param_idx}")
params.append(kwargs["src_status"].name)
query += " ,".join(conds) + " WHERE " + " AND ".join(limit_conds)
ret = await conn.execute(query, *params)
return self.check_update_valid(ret, "update_subtask", query, params)
@class_try_catch_async
async def insert_task(self, task, subtasks):
......@@ -384,7 +409,8 @@ class PostgresSQLTaskManager(BaseTaskManager):
query += " WHERE " + " AND ".join(conds)
if not count:
query += " ORDER BY create_t DESC"
sort_key = "update_t" if kwargs.get("sort_by_update_t", False) else "create_t"
query += f" ORDER BY {sort_key} DESC"
if "limit" in kwargs:
param_idx += 1
......@@ -399,10 +425,26 @@ class PostgresSQLTaskManager(BaseTaskManager):
rows = await conn.fetch(query, *params)
if count:
return rows[0]["count"]
# query subtasks with task
subtasks = {}
if not kwargs.get("subtasks", False):
subtask_query = f"SELECT {self.table_subtasks}.* FROM ({query}) AS t \
JOIN {self.table_subtasks} ON t.task_id = {self.table_subtasks}.task_id"
subtask_rows = await conn.fetch(subtask_query, *params)
for row in subtask_rows:
sub = dict(row)
self.parse_dict(sub)
if sub["task_id"] not in subtasks:
subtasks[sub["task_id"]] = []
subtasks[sub["task_id"]].append(sub)
tasks = []
for row in rows:
task = dict(row)
self.parse_dict(task)
if not kwargs.get("subtasks", False):
task["subtasks"] = subtasks.get(task["task_id"], [])
tasks.append(task)
return tasks
except: # noqa
......@@ -425,8 +467,8 @@ class PostgresSQLTaskManager(BaseTaskManager):
@class_try_catch_async
async def next_subtasks(self, task_id):
conn = await self.get_conn()
records = []
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id)
if task["status"] not in ActiveStatus:
......@@ -445,24 +487,31 @@ class PostgresSQLTaskManager(BaseTaskManager):
break
if dep_ok:
sub["params"] = task["params"]
self.mark_subtask_change(sub, sub["status"], TaskStatus.PENDING)
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.PENDING, extra_info=sub["extra_info"])
self.mark_subtask_change(records, sub, sub["status"], TaskStatus.PENDING)
await self.update_subtask(
conn,
task_id,
sub["worker_name"],
status=TaskStatus.PENDING,
extra_info=sub["extra_info"],
src_status=sub["status"],
)
nexts.append(sub)
if len(nexts) > 0:
await self.update_task(conn, task_id, status=TaskStatus.PENDING)
self.metrics_commit(records)
return nexts
except: # noqa
logger.error(f"next_subtasks error: {traceback.format_exc()}")
return None
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def run_subtasks(self, cands, worker_identity):
conn = await self.get_conn()
records = []
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
valids = []
for cand in cands:
......@@ -473,24 +522,32 @@ class PostgresSQLTaskManager(BaseTaskManager):
if task["status"] in [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL]:
continue
self.mark_subtask_change(subs[0], subs[0]["status"], TaskStatus.RUNNING)
await self.update_subtask(conn, task_id, worker_name, status=TaskStatus.RUNNING, worker_identity=worker_identity, ping_t=True, extra_info=subs[0]["extra_info"])
self.mark_subtask_change(records, subs[0], subs[0]["status"], TaskStatus.RUNNING)
await self.update_subtask(
conn,
task_id,
worker_name,
status=TaskStatus.RUNNING,
worker_identity=worker_identity,
ping_t=True,
extra_info=subs[0]["extra_info"],
src_status=subs[0]["status"],
)
await self.update_task(conn, task_id, status=TaskStatus.RUNNING)
valids.append(cand)
break
self.metrics_commit(records)
return valids
except: # noqa
logger.error(f"run_subtasks error: {traceback.format_exc()}")
return []
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def ping_subtask(self, task_id, worker_name, worker_identity):
conn = await self.get_conn()
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id)
for sub in subtasks:
......@@ -504,14 +561,13 @@ class PostgresSQLTaskManager(BaseTaskManager):
logger.error(f"ping_subtask error: {traceback.format_exc()}")
return False
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False):
conn = await self.get_conn()
records = []
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id)
subs = subtasks
......@@ -529,11 +585,19 @@ class PostgresSQLTaskManager(BaseTaskManager):
if should_running and sub["status"] != TaskStatus.RUNNING:
logger.warning(f"task {task_id} is not running, skip finish subtask: {sub}")
continue
self.mark_subtask_change(sub, sub["status"], status, fail_msg=fail_msg)
await self.update_subtask(conn, task_id, sub["worker_name"], status=status, extra_info=sub["extra_info"])
self.mark_subtask_change(records, sub, sub["status"], status, fail_msg=fail_msg)
await self.update_subtask(
conn,
task_id,
sub["worker_name"],
status=status,
extra_info=sub["extra_info"],
src_status=sub["status"],
)
sub["status"] = status
if task["status"] == TaskStatus.CANCEL:
self.metrics_commit(records)
return TaskStatus.CANCEL
running_subs = []
......@@ -547,57 +611,93 @@ class PostgresSQLTaskManager(BaseTaskManager):
# some subtask failed, we should fail all other subtasks
if failed_sub:
if task["status"] != TaskStatus.FAILED:
self.mark_task_end(task, TaskStatus.FAILED)
await self.update_task(conn, task_id, status=TaskStatus.FAILED, extra_info=task["extra_info"])
self.mark_task_end(records, task, TaskStatus.FAILED)
await self.update_task(
conn,
task_id,
status=TaskStatus.FAILED,
extra_info=task["extra_info"],
src_status=task["status"],
)
for sub in running_subs:
self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed")
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.FAILED, extra_info=sub["extra_info"])
self.mark_subtask_change(records, sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed")
await self.update_subtask(
conn,
task_id,
sub["worker_name"],
status=TaskStatus.FAILED,
extra_info=sub["extra_info"],
src_status=sub["status"],
)
self.metrics_commit(records)
return TaskStatus.FAILED
# all subtasks finished and all succeed
elif len(running_subs) == 0:
if task["status"] != TaskStatus.SUCCEED:
self.mark_task_end(task, TaskStatus.SUCCEED)
await self.update_task(conn, task_id, status=TaskStatus.SUCCEED, extra_info=task["extra_info"])
self.mark_task_end(records, task, TaskStatus.SUCCEED)
await self.update_task(
conn,
task_id,
status=TaskStatus.SUCCEED,
extra_info=task["extra_info"],
src_status=task["status"],
)
self.metrics_commit(records)
return TaskStatus.SUCCEED
self.metrics_commit(records)
return None
except: # noqa
logger.error(f"finish_subtasks error: {traceback.format_exc()}")
return None
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def cancel_task(self, task_id, user_id=None):
conn = await self.get_conn()
records = []
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id, user_id)
if task["status"] not in ActiveStatus:
return f"Task {task_id} is not in active status (current status: {task['status']}). Only tasks with status CREATED, PENDING, or RUNNING can be cancelled."
return f"Task {task_id} is not in active status (current status: {task['status']}). \
Only tasks with status CREATED, PENDING, or RUNNING can be cancelled."
for sub in subtasks:
if sub["status"] not in FinishedStatus:
self.mark_subtask_change(sub, sub["status"], TaskStatus.CANCEL)
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.CANCEL, extra_info=sub["extra_info"])
self.mark_task_end(task, TaskStatus.CANCEL)
await self.update_task(conn, task_id, status=TaskStatus.CANCEL, extra_info=task["extra_info"])
self.mark_subtask_change(records, sub, sub["status"], TaskStatus.CANCEL)
await self.update_subtask(
conn,
task_id,
sub["worker_name"],
status=TaskStatus.CANCEL,
extra_info=sub["extra_info"],
src_status=sub["status"],
)
self.mark_task_end(records, task, TaskStatus.CANCEL)
await self.update_task(
conn,
task_id,
status=TaskStatus.CANCEL,
extra_info=task["extra_info"],
src_status=task["status"],
)
self.metrics_commit(records)
return True
except: # noqa
logger.error(f"cancel_task error: {traceback.format_exc()}")
return "unknown cancel error"
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def resume_task(self, task_id, all_subtask=False, user_id=None):
conn = await self.get_conn()
records = []
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id, user_id)
# the task is not finished
......@@ -609,17 +709,55 @@ class PostgresSQLTaskManager(BaseTaskManager):
for sub in subtasks:
if all_subtask or sub["status"] != TaskStatus.SUCCEED:
self.mark_subtask_change(sub, None, TaskStatus.CREATED)
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.CREATED, reset_ping_t=True, extra_info=sub["extra_info"])
self.mark_task_start(task)
await self.update_task(conn, task_id, status=TaskStatus.CREATED, extra_info=task["extra_info"])
self.mark_subtask_change(records, sub, None, TaskStatus.CREATED)
await self.update_subtask(
conn,
task_id,
sub["worker_name"],
status=TaskStatus.CREATED,
reset_ping_t=True,
extra_info=sub["extra_info"],
src_status=sub["status"],
)
self.mark_task_start(records, task)
await self.update_task(
conn,
task_id,
status=TaskStatus.CREATED,
extra_info=task["extra_info"],
src_status=task["status"],
)
self.metrics_commit(records)
return True
except: # noqa
logger.error(f"resume_task error: {traceback.format_exc()}")
return False
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def delete_task(self, task_id, user_id=None):
    """Remove a finished task and its subtasks from the database.

    The task is loaded and both DELETE statements are issued inside one
    transaction on a connection obtained from ``self.get_conn()``; the
    connection is always handed back through ``self.release_conn``.

    Args:
        task_id: Identifier of the task to delete.
        user_id: Forwarded unchanged to ``self.load``.

    Returns:
        ``True`` after both DELETE statements ran; ``False`` when the task's
        status is not in ``FinishedStatus`` or when any exception occurred
        (the traceback is logged).
    """
    conn = await self.get_conn()
    try:
        async with conn.transaction(isolation="read_uncommitted"):
            record = await self.load(conn, task_id, user_id, only_task=True)

            if record["status"] in FinishedStatus:
                # subtasks first, then the task record itself
                for table in (self.table_subtasks, self.table_tasks):
                    await conn.execute(f"DELETE FROM {table} WHERE task_id = $1", task_id)
                logger.info(f"Task {task_id} and its subtasks deleted successfully")
                return True

            # only finished tasks may be deleted
            logger.warning(f"Cannot delete task {task_id} with status {record['status']}, only finished tasks can be deleted")
            return False
    except:  # noqa
        logger.error(f"delete_task error: {traceback.format_exc()}")
        return False
    finally:
        await self.release_conn(conn)
@class_try_catch_async
......
......@@ -201,13 +201,15 @@ async def sync_subtask():
async def main(args):
if args.model_name == "":
args.model_name = args.model_cls
worker_keys = [args.task, args.model_name, args.stage, args.worker]
if args.task_name == "":
args.task_name = args.task
worker_keys = [args.task_name, args.model_name, args.stage, args.worker]
data_manager = None
if args.data_url.startswith("/"):
data_manager = LocalDataManager(args.data_url)
data_manager = LocalDataManager(args.data_url, None)
elif args.data_url.startswith("{"):
data_manager = S3DataManager(args.data_url)
data_manager = S3DataManager(args.data_url, None)
else:
raise NotImplementedError
await data_manager.init()
......@@ -300,6 +302,7 @@ if __name__ == "__main__":
dft_data_url = os.path.join(base_dir, "local_data")
parser.add_argument("--task", type=str, required=True)
parser.add_argument("--task_name", type=str, default="")
parser.add_argument("--model_cls", type=str, required=True)
parser.add_argument("--model_name", type=str, default="")
parser.add_argument("--stage", type=str, required=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment