Commit ab1b2790 authored by LiangLiu's avatar LiangLiu Committed by GitHub

Deploy server and worker (#284)



* Init deploy: not ok

* Test data_manager & task_manager

* Pipeline is not needed for worker

* Update worker text_encoder

* deploy: submit task

* add apis

* Test pipelineRunner

* Fix pipeline

* Tidy worker & test PipelineWorker ok

* Tidy code

* Fix multi_stage for wan2.1 t2v & i2v

* api query task, get result & report subtasks failed when workers stop

* Add model list functionality to Pipeline and API

* Add task cancel and task resume  to API

* Add RabbitMQ queue manager

* update local task manager atomic

* support postgreSQL task manager, add lifespan async init

* worker print -> logger

* Add S3 data manager, delete temp objects after finished.

* fix worker

* release fetch queue msg when closed, run stuck worker in another thread, stop worker when process down.

* DiTWorker run with thread & tidy logger print

* Init monitor without test

* fix monitor

* github OAuth and jwt token access & static demo html page

* Add user to task, ok for local task manager & update demo ui

* sql task manager support users

* task list with pages

* merge main fix

* Add proxy for auth request

* support wan audio

* worker ping subtask and ping life, fix rabbitmq async get,

* s3 data manager with async api & tidy monitor config

* fix merge main & update req.txt & fix html view video error

* Fix distributed worker

* Limit user visit frequency

* Tidy

* Fix only rank save

* Fix audio input

* Fix worker fetch None

* index.html abs path to rel path

* Fix dist worker stuck

* support publish output video to rtmp & graceful stop running dit step or segment step

* Add VAReader

* Enhance VAReader with torch dist

* Fix audio stream input

* fix merge refactor main, support stream input_audio and output_video

* fix audio read with prev frames & fix end take frames & tidy worker end

* split audio model to 4 workers & fix audio end frame

* fix ping subtask with queue

* Fix audio worker put block & add whep, whip without test ok

* Tidy va recorder & va reader log, thread cancel within 30s

* Fix dist worker stuck: broadcast stop signal

* Tidy

* record task active_elapse & subtask status_elapse

* Design prometheus metrics

* Tidy prometheus metrics

* Fix merge main

* send sigint to ffmpeg process

* Fix gstreamer pull audio by whep & Dockerfile for gstreamer & check params when submitting

* Fix merge main

* Query task with more info & va_reader buffer size = 1

* Fix va_recorder

* Add config for prev_frames

* update frontend

* update frontend

* update frontend

* update frontend & merge

* update frontend & partial backend

* Different rank for va_recorder and va_reader

* Fix mem leak: only one rank publish video, other rank should pop gen vids

* fix task category

* va_reader pre-alloc tensor & va_recorder send frames all & fix dist cancel infer

* Fix prev_frame_length

* Tidy

* Tidy

* update frontend & backend

* Fix lint error

* recover some files

* Tidy

* lint code

---------
Co-authored-by: liuliang1 <liuliang1@sensetime.com>
Co-authored-by: unknown <qinxinyi@sensetime.com>
parent acacd26f
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>LightX2V Text-to-Video Service</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
<style>
.login-container {
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
.main-container {
min-height: 100vh;
background-color: #f8f9fa;
}
.navbar-brand {
font-weight: bold;
color: #667eea !important;
}
.card {
border: none;
border-radius: 15px;
box-shadow: 0 0 20px rgba(0,0,0,0.1);
}
.btn-github {
background-color: #24292e;
border-color: #24292e;
color: white;
}
.btn-github:hover {
background-color: #1a1e22;
border-color: #1a1e22;
color: white;
}
.task-card {
transition: transform 0.2s;
}
.task-card:hover {
transform: translateY(-2px);
}
.status-badge {
font-size: 0.8em;
}
.user-avatar {
width: 32px;
height: 32px;
border-radius: 50%;
}
.loading {
display: none;
}
.loading.show {
display: block;
}
.fade-enter-active, .fade-leave-active {
transition: opacity 0.3s;
}
.fade-enter-from, .fade-leave-to {
opacity: 0;
}
.task-card {
border-left: 4px solid #667eea;
transition: all 0.3s ease;
}
.task-card:hover {
border-left-color: #764ba2;
box-shadow: 0 4px 12px rgba(0,0,0,0.15);
}
.task-detail-section {
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
border-radius: 8px;
margin-top: 15px;
}
.prompt-text {
word-break: break-word;
line-height: 1.6;
}
.output-badge {
font-size: 0.75em;
margin-right: 5px;
}
.btn-group .btn {
margin-left: 5px;
}
.btn-group .btn:first-child {
margin-left: 0;
}
.status-badge i {
font-size: 0.6em;
}
.card-title {
font-weight: 600;
}
.text-truncate {
max-width: 200px;
display: inline-block;
vertical-align: middle;
}
.task-card .card-body {
padding: 1.25rem;
}
.badge {
font-weight: 500;
}
.badge i {
font-size: 0.8em;
}
.task-tags .badge {
font-size: 0.75em;
padding: 0.35em 0.65em;
}
.status-badge {
font-size: 0.8em;
padding: 0.4em 0.8em;
}
.flex-column .btn-group {
width: 100%;
}
.flex-column .btn-group .btn {
flex: 1;
}
/* Pagination styles */
.pagination .page-link {
color: #667eea;
border-color: #dee2e6;
}
.pagination .page-item.active .page-link {
background-color: #667eea;
border-color: #667eea;
color: white;
}
.pagination .page-item.disabled .page-link {
color: #6c757d;
background-color: #fff;
border-color: #dee2e6;
}
.pagination .page-link:hover {
color: #764ba2;
background-color: #e9ecef;
border-color: #dee2e6;
}
.pagination .page-item.active .page-link:hover {
background-color: #764ba2;
border-color: #764ba2;
color: white;
}
</style>
</head>
<body>
<div id="app">
<!-- Login page -->
<div v-if="!isLoggedIn" class="login-container">
<div class="container">
<div class="row justify-content-center">
<div class="col-md-6 col-lg-4">
<div class="card">
<div class="card-body text-center p-5">
<h2 class="mb-4">
<i class="fas fa-brain text-primary"></i>
LightX2V
</h2>
<p class="text-muted mb-4">Text-to-Video Service</p>
<button @click="loginWithGitHub" class="btn btn-github btn-lg w-100" :disabled="loading">
<i class="fab fa-github me-2"></i>
{{ loading ? 'Logging in...' : 'Sign in with GitHub' }}
</button>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Main app page -->
<div v-else class="main-container">
<!-- Navbar -->
<nav class="navbar navbar-expand-lg navbar-light bg-white shadow-sm">
<div class="container">
<a class="navbar-brand" href="/">
<i class="fas fa-brain me-2"></i>
LightX2V
</a>
<div class="navbar-nav ms-auto">
<div class="nav-item dropdown">
<a class="nav-link dropdown-toggle d-flex align-items-center" href="#" id="userDropdown" role="button" data-bs-toggle="dropdown">
<img :src="currentUser.avatar_url" alt="Avatar" class="user-avatar me-2">
<span>{{ currentUser.username }}</span>
</a>
<ul class="dropdown-menu">
<li><a class="dropdown-item" href="#" @click="logout">Log out</a></li>
</ul>
</div>
</div>
</div>
</nav>
<!-- Main content -->
<div class="container mt-4">
<!-- Model list -->
<div class="row mb-4">
<div class="col-12">
<div class="card">
<div class="card-header">
<h5 class="mb-0">
<i class="fas fa-list me-2"></i>
Available Models
</h5>
</div>
<div class="card-body">
<div class="row">
<div v-for="model in models" :key="`${model.task}-${model.model_cls}-${model.stage}`" class="col-md-4 mb-3">
<div class="card task-card h-100">
<div class="card-body">
<h6 class="card-title">Task type: {{ model.task }}</h6>
<p class="card-text text-muted">Model: {{ model.model_cls }}</p>
<small class="text-muted">Inference mode: {{ model.stage }}</small>
</div>
</div>
</div>
<div v-if="models.length === 0" class="col-12 text-center text-muted">
No models available
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Task submission -->
<div class="row mb-4">
<div class="col-12">
<div class="card">
<div class="card-header">
<h5 class="mb-0">
<i class="fas fa-plus me-2"></i>
Submit New Task
</h5>
</div>
<div class="card-body">
<form @submit.prevent="submitTask">
<div class="row">
<div class="col-md-3 mb-3">
<label for="taskType" class="form-label">Task Type</label>
<select class="form-select" v-model="taskForm.task" required>
<option value="">Select task type</option>
<option v-for="taskType in availableTaskTypes" :key="taskType" :value="taskType">
{{ taskType }}
</option>
</select>
</div>
<div class="col-md-3 mb-3">
<label for="modelClass" class="form-label">Model</label>
<select class="form-select" v-model="taskForm.model_cls" required>
<option value="">Select model</option>
<option v-for="modelClass in availableModelClasses" :key="modelClass" :value="modelClass">
{{ modelClass }}
</option>
</select>
</div>
<div class="col-md-3 mb-3">
<label for="stage" class="form-label">Inference Mode</label>
<select class="form-select" v-model="taskForm.stage" required>
<option value="">Select inference mode</option>
<option v-for="stage in availableStages" :key="stage" :value="stage">
{{ stage }}
</option>
</select>
</div>
<div class="col-md-3 mb-3">
<label for="prompt" class="form-label">Prompt</label>
<input type="text" class="form-control" v-model="taskForm.prompt" placeholder="Enter a prompt" required>
</div>
<div class="col-md-3 mb-3">
<label for="seed" class="form-label">Random Seed</label>
<input type="number" class="form-control" v-model="taskForm.seed" placeholder="42" required>
</div>
<div class="col-md-3 mb-3" v-if="taskForm.task === 'i2v'">
<label for="inputFile" class="form-label">Input Image</label>
<input type="file" class="form-control" @change="handleImageUpload" accept="image/*" required>
<div v-if="imagePreview" class="mt-2">
<img :src="imagePreview" alt="Preview" class="img-thumbnail" style="max-width: 100px; max-height: 100px;">
</div>
</div>
<div class="col-md-3 mb-3" v-if="isAudioModel">
<label for="audioFile" class="form-label">Input Audio</label>
<input type="file" class="form-control" @change="handleAudioUpload" accept="audio/*" required>
<div v-if="audioPreview" class="mt-2">
<audio controls style="max-width: 100%; max-height: 60px;">
<source :src="audioPreview" type="audio/mpeg">
Your browser does not support audio playback
</audio>
</div>
</div>
</div>
<button type="submit" class="btn btn-primary" :disabled="submitting">
<i class="fas fa-paper-plane me-2"></i>
{{ submitting ? 'Submitting...' : 'Submit Task' }}
</button>
</form>
</div>
</div>
</div>
</div>
<!-- Task list -->
<div class="row">
<div class="col-12">
<div class="card">
<div class="card-header d-flex justify-content-between align-items-center">
<h5 class="mb-0">
<i class="fas fa-tasks me-2"></i>
Task List
</h5>
<div class="d-flex align-items-center">
<!-- Status filter -->
<select v-model="statusFilter" @change="refreshTasks" class="form-select form-select-sm me-2" style="width: auto;">
<option value="ALL" selected>ALL</option>
<option value="CREATED">CREATED</option>
<option value="PENDING">PENDING</option>
<option value="RUNNING">RUNNING</option>
<option value="SUCCEED">SUCCEED</option>
<option value="FAILED">FAILED</option>
<option value="CANCEL">CANCEL</option>
</select>
<button class="btn btn-outline-primary btn-sm" @click="refreshTasks" :disabled="loading">
<i class="fas fa-sync-alt me-1"></i>
Refresh
</button>
</div>
</div>
<div class="card-body">
<div v-if="tasks.length === 0" class="text-center text-muted">
No tasks yet
</div>
<div v-else>
<div v-for="task in tasks" :key="task.task_id" class="card task-card mb-3">
<div class="card-body">
<!-- Task summary -->
<div class="row align-items-center">
<div class="col-md-8">
<div class="d-flex justify-content-between align-items-center mb-3">
<div class="d-flex align-items-center">
<h6 class="card-title mb-0 text-primary me-3">
<i class="fas fa-hashtag me-1"></i>
{{ task.task_id }}
</h6>
<div class="task-tags">
<span class="badge bg-primary me-1">
<i class="fas fa-tasks me-1"></i>
{{ task.task_type }}
</span>
<span class="badge bg-info me-1">
<i class="fas fa-brain me-1"></i>
{{ task.model_cls }}
</span>
<span class="badge bg-warning">
<i class="fas fa-cog me-1"></i>
{{ task.stage }}
</span>
</div>
</div>
</div>
<div class="row">
<div class="col-md-6">
<p class="card-text mb-2" v-if="task.params.prompt">
<strong>
<i class="fas fa-comment me-1"></i>
Prompt:
</strong>
<span class="text-truncate d-inline-block" style="max-width: 200px;" :title="task.params.prompt">
{{ task.params.prompt.length > 50 ? task.params.prompt.substring(0, 50) + '...' : task.params.prompt }}
</span>
</p>
<p class="card-text mb-1" v-if="task.params.seed">
<strong>
<i class="fas fa-seedling me-1"></i>
Seed:
</strong>
<span class="badge bg-secondary">{{ task.params.seed }}</span>
</p>
</div>
<div class="col-md-6">
<p class="card-text mb-2" v-if="task.create_t">
<strong>
<i class="fas fa-calendar me-1"></i>
Created:
</strong>
<small>{{ formatTime(task.create_t) }}</small>
</p>
<p v-if="task.status === 'SUCCEED' && task.outputs && Object.keys(task.outputs).length > 0" class="mb-2">
<strong>
<i class="fas fa-file-video me-1"></i>
Outputs:
</strong>
<span class="mt-1">
<span v-for="(output, key) in task.outputs" :key="key" class="me-2 mb-1 d-inline-flex align-items-center">
{{ key }}
<i class="fas fa-eye ms-1" style="cursor: pointer; color: #28a745;" @click="viewSingleResult(task.task_id, key, output)" title="View this file"></i>
<i class="fas fa-download ms-1" style="cursor: pointer; color: #007bff;" @click="downloadSingleResult(task.task_id, key, output)" title="Download this file"></i>
</span>
</span>
</p>
</div>
</div>
</div>
<div class="col-md-4 text-end">
<div class="d-flex align-items-center justify-content-end">
<span :class="getStatusBadgeClass(task.status)" class="badge status-badge me-2">
<i class="fas fa-circle me-1"></i>
{{ task.status }}
</span>
<div class="btn-group" role="group">
<button class="btn btn-outline-info btn-sm" @click="toggleTaskDetail(task.task_id)">
<i class="fas fa-info-circle me-1"></i>
{{ expandedTasks.includes(task.task_id) ? 'Collapse' : 'Details' }}
</button>
<button v-if="['CREATED', 'PENDING', 'RUNNING'].includes(task.status)"
class="btn btn-outline-warning btn-sm"
@click="cancelTask(task.task_id)">
<i class="fas fa-stop me-1"></i>
Cancel
</button>
<button v-if="['SUCCEED', 'FAILED', 'CANCEL'].includes(task.status)"
class="btn btn-outline-success btn-sm"
@click="resumeTask(task.task_id)">
<i class="fas fa-redo me-1"></i>
Retry
</button>
</div>
</div>
</div>
</div>
<!-- Expanded details -->
<div v-if="expandedTasks.includes(task.task_id)" class="task-detail-section p-3">
<div class="row">
<div class="col-md-6">
<h6 class="text-muted mb-2">
<i class="fas fa-comment me-1"></i>
Full Prompt
</h6>
<div class="bg-white p-3 rounded border">
<p class="mb-0 prompt-text">{{ task.params.prompt || 'No prompt' }}</p>
</div>
</div>
<div class="col-md-6">
<h6 class="text-muted mb-2">
<i class="fas fa-cog me-1"></i>
Task Parameters
</h6>
<div class="bg-white p-3 rounded border">
<div class="mb-3 d-flex align-items-center">
<strong class="me-2">Task type:</strong>
<span class="badge bg-primary">
<i class="fas fa-tasks me-1"></i>
{{ task.task_type }}
</span>
</div>
<div class="mb-3 d-flex align-items-center">
<strong class="me-2">Model:</strong>
<span class="badge bg-info">
<i class="fas fa-brain me-1"></i>
{{ task.model_cls }}
</span>
</div>
<div class="mb-3 d-flex align-items-center">
<strong>Inference mode:</strong>
<span class="badge bg-warning ms-2">
<i class="fas fa-cog me-1"></i>
{{ task.stage }}
</span>
</div>
<div class="mb-3 d-flex align-items-center" v-if="task.params.seed">
<strong>Seed:</strong>
<span class="badge bg-secondary ms-2">{{ task.params.seed }}</span>
</div>
</div>
</div>
</div>
<div class="row mt-3">
<div class="col-md-6">
<h6 class="text-muted mb-2">
<i class="fas fa-clock me-1"></i>
Timing
</h6>
<div class="bg-white p-3 rounded border">
<ul class="list-unstyled mb-0">
<li><strong>Created:</strong> {{ formatTime(task.create_t) }}</li>
<li><strong>Updated:</strong> {{ formatTime(task.update_t) }}</li>
</ul>
</div>
</div>
<div class="col-md-6" v-if="task.outputs && task.inputs">
<h6 class="text-muted mb-2">
<i class="fas fa-file-video me-1"></i>
Input Files / Outputs
</h6>
<div class="bg-white p-3 rounded border">
<div v-for="(input, key) in task.inputs" :key="key" class="mb-2">
<strong>Input file:</strong>
<span class="badge bg-secondary output-badge">{{ key }}</span>
</div>
<div v-for="(output, key) in task.outputs" :key="key" class="mb-2">
<strong>Output:</strong>
<span class="badge bg-success output-badge">{{ key }}</span>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Pagination -->
<div v-if="pagination && pagination.total_pages > 0" class="d-flex justify-content-between align-items-center mt-4">
<div class="text-muted">
Showing {{ (pagination.page - 1) * pagination.page_size + 1 }} -
{{ Math.min(pagination.page * pagination.page_size, pagination.total) }},
{{ pagination.total }} records in total
</div>
<nav aria-label="Task list pagination">
<ul class="pagination pagination-sm mb-0">
<!-- Previous page -->
<li class="page-item" :class="{ disabled: pagination.page <= 1 }">
<a class="page-link" href="#" @click.prevent="changePage(pagination.page - 1)">
<i class="fas fa-chevron-left"></i>
</a>
</li>
<!-- Page numbers -->
<li v-for="pageNum in visiblePages" :key="pageNum"
class="page-item" :class="{ active: pageNum === pagination.page }">
<a class="page-link" href="#" @click.prevent="changePage(pageNum)">
{{ pageNum }}
</a>
</li>
<!-- Next page -->
<li class="page-item" :class="{ disabled: pagination.page >= pagination.total_pages }">
<a class="page-link" href="#" @click.prevent="changePage(pagination.page + 1)">
<i class="fas fa-chevron-right"></i>
</a>
</li>
</ul>
</nav>
<!-- Page size selector -->
<div class="d-flex align-items-center">
<span class="text-muted me-2">Per page:</span>
<select v-model="pageSize" @change="changePageSize" class="form-select form-select-sm" style="width: auto;">
<option value="5">5</option>
<option value="10">10</option>
<option value="20">20</option>
<option value="50">50</option>
</select>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Loading indicator -->
<div v-if="loading" class="loading position-fixed top-50 start-50 translate-middle show">
<div class="spinner-border text-primary" role="status">
<span class="visually-hidden">Loading...</span>
</div>
</div>
<!-- Alert messages -->
<div v-if="alert.show" class="position-fixed top-0 start-50 translate-middle-x mt-3" style="z-index: 9999;">
<div :class="`alert alert-${alert.type} alert-dismissible fade show`" role="alert">
{{ alert.message }}
<button type="button" class="btn-close" @click="alert.show = false"></button>
</div>
</div>
</div>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
<script src="https://unpkg.com/vue@3/dist/vue.global.js"></script>
<script>
const { createApp, ref, computed, onMounted, watch } = Vue;
createApp({
setup() {
// Reactive state
const isLoggedIn = ref(false);
const loading = ref(false);
const submitting = ref(false);
const currentUser = ref({});
const models = ref([]);
const tasks = ref([]);
const alert = ref({ show: false, message: '', type: 'info' });
const expandedTasks = ref([]); // Controls which task details are expanded
// Pagination state
const pagination = ref(null);
const currentPage = ref(1);
const pageSize = ref(10);
const statusFilter = ref('ALL');
const taskForm = ref({
task: '',
model_cls: '',
stage: '',
imageFile: null,
audioFile: null,
prompt: '',
seed: 42
});
const imagePreview = ref(null);
const audioPreview = ref(null);
// Computed properties
const availableTaskTypes = computed(() => {
return [...new Set(models.value.map(m => m.task))];
});
const availableModelClasses = computed(() => {
if (!taskForm.value.task) return [];
return [...new Set(models.value
.filter(m => m.task === taskForm.value.task)
.map(m => m.model_cls))];
});
const availableStages = computed(() => {
if (!taskForm.value.task || !taskForm.value.model_cls) return [];
return [...new Set(models.value
.filter(m => m.task === taskForm.value.task && m.model_cls === taskForm.value.model_cls)
.map(m => m.stage))];
});
const isAudioModel = computed(() => {
return taskForm.value.model_cls && (taskForm.value.model_cls.includes('audio') || taskForm.value.model_cls.includes('Audio'));
});
// Pagination computed properties
const visiblePages = computed(() => {
if (!pagination.value) return [];
const totalPages = pagination.value.total_pages;
const current = pagination.value.page;
const pages = [];
// Show at most 5 page numbers
let start = Math.max(1, current - 2);
let end = Math.min(totalPages, current + 2);
// Adjust the window so 5 page numbers are shown when possible
if (end - start < 4 && totalPages > 4) {
if (start === 1) {
end = Math.min(totalPages, start + 4);
} else {
start = Math.max(1, end - 4);
}
}
for (let i = start; i <= end; i++) {
pages.push(i);
}
return pages;
});
// Methods
const showAlert = (message, type = 'info') => {
alert.value = { show: true, message, type };
setTimeout(() => {
alert.value.show = false;
}, 5000);
};
const setLoading = (value) => {
loading.value = value;
};
const apiCall = async (endpoint, options = {}) => {
const url = `${endpoint}`;
const headers = {
'Content-Type': 'application/json',
...options.headers
};
if (localStorage.getItem('accessToken')) {
headers['Authorization'] = `Bearer ${localStorage.getItem('accessToken')}`;
}
const response = await fetch(url, {
...options,
headers
});
if (response.status === 401) {
logout();
throw new Error('Authentication failed, please log in again');
}
if (response.status === 400) {
const error = await response.json();
showAlert(error.message, 'danger');
throw new Error(error.message);
}
// Wait 50ms to avoid triggering the server-side rate limit
await new Promise(resolve => setTimeout(resolve, 50));
return response;
};
const loginWithGitHub = async () => {
try {
setLoading(true);
const response = await fetch('./auth/login/github');
const data = await response.json();
window.location.href = data.auth_url;
} catch (error) {
showAlert('Failed to get GitHub auth URL', 'danger');
} finally {
setLoading(false);
}
};
const handleGitHubCallback = async (code) => {
try {
setLoading(true);
const response = await fetch(`./auth/callback/github?code=${code}`);
if (response.ok) {
const data = await response.json();
console.log(data);
localStorage.setItem('accessToken', data.access_token);
localStorage.setItem('currentUser', JSON.stringify(data.user_info));
currentUser.value = data.user_info;
isLoggedIn.value = true;
} else {
const error = await response.json();
showAlert(`Login failed: ${error.detail}`, 'danger');
}
window.location.href = '/';
} catch (error) {
showAlert('An error occurred during login', 'danger');
console.error(error);
} finally {
setLoading(false);
}
};
const logout = () => {
localStorage.removeItem('accessToken');
localStorage.removeItem('currentUser');
currentUser.value = {};
isLoggedIn.value = false;
models.value = [];
tasks.value = [];
};
const loadModels = async () => {
try {
const response = await apiCall('./api/v1/model/list');
if (response.ok) {
const data = await response.json();
console.log(data);
models.value = data.models || [];
// If models exist, auto-select the first model's settings
if (models.value.length > 0) {
const firstModel = models.value[0];
taskForm.value.task = firstModel.task;
taskForm.value.model_cls = firstModel.model_cls;
taskForm.value.stage = firstModel.stage;
}
} else {
showAlert('Failed to load model list', 'danger');
}
} catch (error) {
showAlert(`Failed to load models: ${error.message}`, 'danger');
}
};
const handleImageUpload = (event) => {
const file = event.target.files[0];
taskForm.value.imageFile = file;
// Create image preview
if (file) {
const reader = new FileReader();
reader.onload = (e) => {
imagePreview.value = e.target.result;
};
reader.readAsDataURL(file);
} else {
imagePreview.value = null;
}
};
const handleAudioUpload = (event) => {
const file = event.target.files[0];
taskForm.value.audioFile = file;
// Create audio preview
if (file) {
const reader = new FileReader();
reader.onload = (e) => {
audioPreview.value = e.target.result;
};
reader.readAsDataURL(file);
} else {
audioPreview.value = null;
}
};
const submitTask = async () => {
try {
setLoading(true);
submitting.value = true;
// Prepare the submission payload
const formData = {
task: taskForm.value.task,
model_cls: taskForm.value.model_cls,
stage: taskForm.value.stage,
prompt: taskForm.value.prompt,
seed: taskForm.value.seed
};
if (taskForm.value.model_cls.startsWith('wan2.1')) {
formData.negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
}
// For i2v tasks with an image file, convert the image to base64
if (taskForm.value.task === 'i2v' && taskForm.value.imageFile) {
const base64 = await fileToBase64(taskForm.value.imageFile);
formData.input_image = {
type: 'base64',
data: base64
};
}
// For audio models with an audio file, convert the audio to base64
if (isAudioModel.value && taskForm.value.audioFile) {
const base64 = await fileToBase64(taskForm.value.audioFile);
formData.input_audio = {
type: 'base64',
data: base64
};
formData.negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
}
const response = await apiCall('./api/v1/task/submit', {
method: 'POST',
body: JSON.stringify(formData)
});
if (response.ok) {
const result = await response.json();
showAlert(`Task submitted! Task ID: ${result.task_id}`, 'success');
await refreshTasks();
} else {
const error = await response.json();
showAlert(`Task submission failed: ${error.message}`, 'danger');
}
} catch (error) {
showAlert(`Failed to submit task: ${error.message}`, 'danger');
} finally {
submitting.value = false;
setLoading(false);
}
};
const fileToBase64 = (file) => {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.readAsDataURL(file);
reader.onload = () => {
// Strip the "data:image/xxx;base64," prefix and keep only the base64 payload
const base64 = reader.result.split(',')[1];
resolve(base64);
};
reader.onerror = error => reject(error);
});
};
const formatTime = (timestamp) => {
if (!timestamp) return '';
// Convert the float timestamp to milliseconds (the Date constructor expects ms)
const date = new Date(timestamp * 1000);
return date.toLocaleString('zh-CN');
};
const refreshTasks = async () => {
try {
const params = new URLSearchParams({
page: currentPage.value.toString(),
page_size: pageSize.value.toString()
});
if (statusFilter.value !== 'ALL') {
params.append('status', statusFilter.value);
}
const response = await apiCall(`./api/v1/task/list?${params.toString()}`);
if (response.ok) {
const data = await response.json();
console.log(data);
tasks.value = data.tasks || [];
pagination.value = data.pagination || null;
} else {
showAlert('Failed to refresh task list', 'danger');
}
} catch (error) {
showAlert(`Failed to refresh task list: ${error.message}`, 'danger');
}
};
const changePage = (page) => {
if (page < 1 || (pagination.value && page > pagination.value.total_pages)) {
return;
}
currentPage.value = page;
refreshTasks();
};
const changePageSize = () => {
currentPage.value = 1; // Reset to the first page
refreshTasks();
};
const getStatusBadgeClass = (status) => {
const statusMap = {
'SUCCEED': 'bg-success',
'FAILED': 'bg-danger',
'RUNNING': 'bg-warning',
'PENDING': 'bg-secondary',
'CREATED': 'bg-secondary'
};
return statusMap[status] || 'bg-secondary';
};
const toggleTaskDetail = (taskId) => {
if (expandedTasks.value.includes(taskId)) {
expandedTasks.value = expandedTasks.value.filter(id => id !== taskId);
} else {
expandedTasks.value.push(taskId);
}
};
const downloadSingleResult = async (taskId, key, outputPath) => {
try {
setLoading(true);
const response = await apiCall(`./api/v1/task/result?task_id=${taskId}&name=${key}`);
if (response.ok) {
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `${outputPath}`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
} else {
showAlert('Failed to fetch result', 'danger');
}
} catch (error) {
showAlert(`Failed to download result: ${error.message}`, 'danger');
} finally {
setLoading(false);
}
};
const viewSingleResult = async (taskId, key, outputPath) => {
try {
setLoading(true);
const response = await apiCall(`./api/v1/task/result?task_id=${taskId}&name=${key}`);
if (response.ok) {
const blob = await response.blob();
const videoBlob = new Blob([blob], { type: 'video/mp4' });
const url = window.URL.createObjectURL(videoBlob);
window.open(url, '_blank');
} else {
showAlert('Failed to fetch result', 'danger');
}
} catch (error) {
showAlert(`Failed to view result: ${error.message}`, 'danger');
} finally {
setLoading(false);
}
};
const cancelTask = async (taskId) => {
try {
const response = await apiCall(`./api/v1/task/cancel?task_id=${taskId}`);
if (response.ok) {
showAlert('Task cancelled', 'success');
await refreshTasks();
} else {
const error = await response.json();
showAlert(`Failed to cancel task: ${error.message}`, 'danger');
}
} catch (error) {
showAlert(`Failed to cancel task: ${error.message}`, 'danger');
}
};
const resumeTask = async (taskId) => {
try {
const response = await apiCall(`./api/v1/task/resume?task_id=${taskId}`);
if (response.ok) {
showAlert('Task resumed successfully', 'success');
await refreshTasks();
} else {
const error = await response.json();
showAlert(`Failed to retry task: ${error.message}`, 'danger');
}
} catch (error) {
showAlert(`Failed to retry task: ${error.message}`, 'danger');
}
};
const initModelAndTasks = async () => {
await loadModels();
await refreshTasks();
};
// Watchers
watch(() => taskForm.value.task, () => {
// When the task type changes, clear the model class and inference mode
taskForm.value.model_cls = '';
taskForm.value.stage = '';
// Auto-select the first model available for the new task type
const availableModels = models.value.filter(m => m.task === taskForm.value.task);
if (availableModels.length > 0) {
const firstModel = availableModels[0];
taskForm.value.model_cls = firstModel.model_cls;
taskForm.value.stage = firstModel.stage;
}
// Clear image preview
imagePreview.value = null;
taskForm.value.imageFile = null;
});
watch(() => taskForm.value.model_cls, () => {
// When the model class changes, clear the inference mode
taskForm.value.stage = '';
// Auto-select the first matching inference mode
const availableStages = models.value
.filter(m => m.task === taskForm.value.task && m.model_cls === taskForm.value.model_cls)
.map(m => m.stage);
if (availableStages.length > 0) {
taskForm.value.stage = availableStages[0];
}
});
// Lifecycle
onMounted(() => {
// Check whether the user is already logged in
const savedToken = localStorage.getItem('accessToken');
const savedUser = localStorage.getItem('currentUser');
if (savedToken && savedUser) {
currentUser.value = JSON.parse(savedUser);
isLoggedIn.value = true;
initModelAndTasks();
} else {
// Check for a GitHub OAuth callback
const urlParams = new URLSearchParams(window.location.search);
const code = urlParams.get('code');
if (code) {
handleGitHubCallback(code);
}
}
});
return {
isLoggedIn,
loading,
submitting,
currentUser,
models,
tasks,
alert,
taskForm,
imagePreview,
availableTaskTypes,
availableModelClasses,
availableStages,
expandedTasks,
pagination,
currentPage,
pageSize,
statusFilter,
visiblePages,
isAudioModel,
showAlert,
loginWithGitHub,
logout,
loadModels,
handleImageUpload,
handleAudioUpload,
submitTask,
refreshTasks,
changePage,
changePageSize,
getStatusBadgeClass,
formatTime,
toggleTaskDetail,
downloadSingleResult,
viewSingleResult,
cancelTask,
resumeTask
};
}
}).mount('#app');
</script>
</body>
</html>
import uuid
from enum import Enum
from loguru import logger
from lightx2v.deploy.common.utils import current_time, data_name
class TaskStatus(Enum):
CREATED = 1
PENDING = 2
RUNNING = 3
SUCCEED = 4
FAILED = 5
CANCEL = 6
ActiveStatus = [TaskStatus.CREATED, TaskStatus.PENDING, TaskStatus.RUNNING]
FinishedStatus = [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL]
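ActiveStatus and FinishedStatus are meant to partition the lifecycle: every task is either still in flight or terminal. A minimal standalone sketch (re-declaring the enum, independent of the deployed module) checks that invariant:

```python
from enum import Enum

class TaskStatus(Enum):
    CREATED = 1
    PENDING = 2
    RUNNING = 3
    SUCCEED = 4
    FAILED = 5
    CANCEL = 6

ActiveStatus = [TaskStatus.CREATED, TaskStatus.PENDING, TaskStatus.RUNNING]
FinishedStatus = [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL]

# The two groups cover every status exactly once, with no overlap.
assert set(ActiveStatus) | set(FinishedStatus) == set(TaskStatus)
assert not set(ActiveStatus) & set(FinishedStatus)
```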
class BaseTaskManager:
def __init__(self):
pass
async def init(self):
pass
async def close(self):
pass
async def insert_user_if_not_exists(self, user_info):
raise NotImplementedError
async def query_user(self, user_id):
raise NotImplementedError
async def insert_task(self, task, subtasks):
raise NotImplementedError
async def list_tasks(self, **kwargs):
raise NotImplementedError
async def query_task(self, task_id, user_id=None, only_task=True):
raise NotImplementedError
async def next_subtasks(self, task_id):
raise NotImplementedError
async def run_subtasks(self, subtasks, worker_identity):
raise NotImplementedError
async def ping_subtask(self, task_id, worker_name, worker_identity):
raise NotImplementedError
async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False):
raise NotImplementedError
async def cancel_task(self, task_id, user_id=None):
raise NotImplementedError
async def resume_task(self, task_id, all_subtask=False, user_id=None):
raise NotImplementedError
def fmt_dict(self, data):
for k in ["status"]:
if k in data:
data[k] = data[k].name
def parse_dict(self, data):
for k in ["status"]:
if k in data:
data[k] = TaskStatus[data[k]]
async def create_user(self, user_info):
assert user_info["source"] == "github", f"unsupported user source: {user_info['source']}"
cur_t = current_time()
user_id = f"{user_info['source']}_{user_info['id']}"
data = {
"user_id": user_id,
"source": user_info["source"],
"id": user_info["id"],
"username": user_info["username"],
"email": user_info["email"],
"homepage": user_info["homepage"],
"avatar_url": user_info["avatar_url"],
"create_t": cur_t,
"update_t": cur_t,
"extra_info": "",
"tag": "",
}
assert await self.insert_user_if_not_exists(data), f"create user {data} failed"
return user_id
async def create_task(self, worker_keys, workers, params, inputs, outputs, user_id):
task_type, model_cls, stage = worker_keys
cur_t = current_time()
task_id = str(uuid.uuid4())
task = {
"task_id": task_id,
"task_type": task_type,
"model_cls": model_cls,
"stage": stage,
"params": params,
"create_t": cur_t,
"update_t": cur_t,
"status": TaskStatus.CREATED,
"extra_info": "",
"tag": "",
"inputs": {x: data_name(x, task_id) for x in inputs},
"outputs": {x: data_name(x, task_id) for x in outputs},
"user_id": user_id,
}
self.mark_task_start(task)
subtasks = []
for worker_name, worker_item in workers.items():
subtasks.append(
{
"task_id": task_id,
"worker_name": worker_name,
"inputs": {x: data_name(x, task_id) for x in worker_item["inputs"]},
"outputs": {x: data_name(x, task_id) for x in worker_item["outputs"]},
"queue": worker_item["queue"],
"previous": worker_item["previous"],
"status": TaskStatus.CREATED,
"worker_identity": "",
"result": "",
"fail_time": 0,
"extra_info": "",
"create_t": cur_t,
"update_t": cur_t,
"ping_t": 0.0,
"infer_cost": -1.0,
}
)
self.mark_subtask_change(subtasks[-1], None, TaskStatus.CREATED)
ret = await self.insert_task(task, subtasks)
# roll back task/subtask metric state if the insert failed
if not ret:
self.mark_task_end(task, TaskStatus.FAILED)
for sub in subtasks:
self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED)
assert ret, f"create task {task_id} failed"
return task_id
async def mark_server_restart(self):
if self.metrics_monitor:
tasks = await self.list_tasks(status=ActiveStatus)
subtasks = await self.list_tasks(status=ActiveStatus, subtasks=True)
logger.warning(f"Mark system restart, {len(tasks)} tasks, {len(subtasks)} subtasks")
self.metrics_monitor.record_task_recover(tasks)
self.metrics_monitor.record_subtask_recover(subtasks)
def mark_task_start(self, task):
t = current_time()
if not isinstance(task["extra_info"], dict):
task["extra_info"] = {}
if "active_elapse" in task["extra_info"]:
del task["extra_info"]["active_elapse"]
task["extra_info"]["start_t"] = t
logger.info(f"Task {task['task_id']} active start")
if self.metrics_monitor:
self.metrics_monitor.record_task_start(task)
def mark_task_end(self, task, end_status):
if "start_t" not in task["extra_info"]:
logger.warning(f"Task {task} has no start time")
else:
elapse = current_time() - task["extra_info"]["start_t"]
task["extra_info"]["active_elapse"] = elapse
del task["extra_info"]["start_t"]
logger.info(f"Task {task['task_id']} active end with [{end_status}], elapse: {elapse}")
if self.metrics_monitor:
self.metrics_monitor.record_task_end(task, end_status, elapse)
def mark_subtask_change(self, subtask, old_status, new_status, fail_msg=None):
t = current_time()
if not isinstance(subtask["extra_info"], dict):
subtask["extra_info"] = {}
if isinstance(fail_msg, str) and len(fail_msg) > 0:
subtask["extra_info"]["fail_msg"] = fail_msg
elif "fail_msg" in subtask["extra_info"]:
del subtask["extra_info"]["fail_msg"]
if old_status == new_status:
logger.warning(f"Subtask {subtask} update same status: {old_status} vs {new_status}")
return
elapse, elapse_key = None, None
if old_status in ActiveStatus:
if "start_t" not in subtask["extra_info"]:
logger.warning(f"Subtask {subtask} has no start time, status: {old_status}")
else:
elapse = t - subtask["extra_info"]["start_t"]
elapse_key = f"{old_status.name}-{new_status.name}"
if "elapses" not in subtask["extra_info"]:
subtask["extra_info"]["elapses"] = {}
subtask["extra_info"]["elapses"][elapse_key] = elapse
del subtask["extra_info"]["start_t"]
if new_status in ActiveStatus:
subtask["extra_info"]["start_t"] = t
if new_status == TaskStatus.CREATED and "elapses" in subtask["extra_info"]:
del subtask["extra_info"]["elapses"]
logger.info(
f"Subtask {subtask['task_id']} {subtask['worker_name']} status changed: "
f"[{old_status}] -> [{new_status}], {elapse_key}: {elapse}, fail_msg: {fail_msg}"
)
if self.metrics_monitor:
self.metrics_monitor.record_subtask_change(subtask, old_status, new_status, elapse_key, elapse)
# Import task manager implementations
from .local_task_manager import LocalTaskManager # noqa
from .sql_task_manager import PostgresSQLTaskManager # noqa
__all__ = ["BaseTaskManager", "LocalTaskManager", "PostgresSQLTaskManager"]
import asyncio
import json
import os
from lightx2v.deploy.common.utils import class_try_catch_async, current_time, str2time, time2str
from lightx2v.deploy.task_manager import ActiveStatus, BaseTaskManager, FinishedStatus, TaskStatus
class LocalTaskManager(BaseTaskManager):
def __init__(self, local_dir, metrics_monitor=None):
self.local_dir = local_dir
os.makedirs(self.local_dir, exist_ok=True)
self.metrics_monitor = metrics_monitor
def get_task_filename(self, task_id):
return os.path.join(self.local_dir, f"task_{task_id}.json")
def get_user_filename(self, user_id):
return os.path.join(self.local_dir, f"user_{user_id}.json")
def fmt_dict(self, data):
super().fmt_dict(data)
for k in ["create_t", "update_t", "ping_t"]:
if k in data:
data[k] = time2str(data[k])
def parse_dict(self, data):
super().parse_dict(data)
for k in ["create_t", "update_t", "ping_t"]:
if k in data:
data[k] = str2time(data[k])
def save(self, task, subtasks, with_fmt=True):
info = {"task": task, "subtasks": subtasks}
if with_fmt:
self.fmt_dict(info["task"])
[self.fmt_dict(x) for x in info["subtasks"]]
out_name = self.get_task_filename(task["task_id"])
with open(out_name, "w") as fout:
fout.write(json.dumps(info, indent=4, ensure_ascii=False))
def load(self, task_id, user_id=None, only_task=False):
fpath = self.get_task_filename(task_id)
with open(fpath) as fin:
info = json.load(fin)
task, subtasks = info["task"], info["subtasks"]
if user_id is not None and task["user_id"] != user_id:
raise Exception(f"Task {task_id} does not belong to user {user_id}")
self.parse_dict(task)
if only_task:
return task
for sub in subtasks:
self.parse_dict(sub)
return task, subtasks
@class_try_catch_async
async def insert_task(self, task, subtasks):
self.save(task, subtasks)
return True
@class_try_catch_async
async def list_tasks(self, **kwargs):
tasks = []
for f in os.listdir(self.local_dir):
if not f.startswith("task_"):
continue
fpath = os.path.join(self.local_dir, f)
with open(fpath) as fin:
info = json.load(fin)
if kwargs.get("subtasks", False):
items = info["subtasks"]
assert "user_id" not in kwargs, "user_id is not allowed when subtasks is True"
else:
items = [info["task"]]
for task in items:
self.parse_dict(task)
if "user_id" in kwargs and task["user_id"] != kwargs["user_id"]:
continue
if "status" in kwargs:
if isinstance(kwargs["status"], list) and task["status"] not in kwargs["status"]:
continue
elif kwargs["status"] != task["status"]:
continue
if "start_created_t" in kwargs and kwargs["start_created_t"] > task["create_t"]:
continue
if "end_created_t" in kwargs and kwargs["end_created_t"] < task["create_t"]:
continue
if "start_updated_t" in kwargs and kwargs["start_updated_t"] > task["update_t"]:
continue
if "end_updated_t" in kwargs and kwargs["end_updated_t"] < task["update_t"]:
continue
if "start_ping_t" in kwargs and kwargs["start_ping_t"] > task["ping_t"]:
continue
if "end_ping_t" in kwargs and kwargs["end_ping_t"] < task["ping_t"]:
continue
tasks.append(task)
if "count" in kwargs:
return len(tasks)
tasks = sorted(tasks, key=lambda x: x["create_t"], reverse=True)
if "offset" in kwargs:
tasks = tasks[kwargs["offset"] :]
if "limit" in kwargs:
tasks = tasks[: kwargs["limit"]]
return tasks
@class_try_catch_async
async def query_task(self, task_id, user_id=None, only_task=True):
return self.load(task_id, user_id, only_task)
@class_try_catch_async
async def next_subtasks(self, task_id):
task, subtasks = self.load(task_id)
if task["status"] not in ActiveStatus:
return []
succeeds = set()
for sub in subtasks:
if sub["status"] == TaskStatus.SUCCEED:
succeeds.add(sub["worker_name"])
nexts = []
for sub in subtasks:
if sub["status"] == TaskStatus.CREATED:
dep_ok = True
for prev in sub["previous"]:
if prev not in succeeds:
dep_ok = False
break
if dep_ok:
self.mark_subtask_change(sub, sub["status"], TaskStatus.PENDING)
sub["params"] = task["params"]
sub["status"] = TaskStatus.PENDING
sub["update_t"] = current_time()
nexts.append(sub)
if len(nexts) > 0:
task["status"] = TaskStatus.PENDING
task["update_t"] = current_time()
self.save(task, subtasks)
return nexts
@class_try_catch_async
async def run_subtasks(self, cands, worker_identity):
valids = []
for cand in cands:
task_id = cand["task_id"]
worker_name = cand["worker_name"]
task, subtasks = self.load(task_id)
if task["status"] in FinishedStatus:
continue
for sub in subtasks:
if sub["worker_name"] == worker_name:
self.mark_subtask_change(sub, sub["status"], TaskStatus.RUNNING)
sub["status"] = TaskStatus.RUNNING
sub["worker_identity"] = worker_identity
sub["update_t"] = current_time()
task["status"] = TaskStatus.RUNNING
task["update_t"] = current_time()
task["ping_t"] = current_time()
self.save(task, subtasks)
valids.append(cand)
break
return valids
@class_try_catch_async
async def ping_subtask(self, task_id, worker_name, worker_identity):
task, subtasks = self.load(task_id)
for sub in subtasks:
if sub["worker_name"] == worker_name:
pre = sub["worker_identity"]
assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}"
sub["ping_t"] = current_time()
self.save(task, subtasks)
return True
return False
@class_try_catch_async
async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False):
task, subtasks = self.load(task_id)
subs = subtasks
if worker_name:
subs = [sub for sub in subtasks if sub["worker_name"] == worker_name]
assert len(subs) >= 1, f"no worker task_id={task_id}, name={worker_name}"
if worker_identity:
pre = subs[0]["worker_identity"]
assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}"
assert status in [TaskStatus.SUCCEED, TaskStatus.FAILED], f"invalid finish status: {status}"
for sub in subs:
if sub["status"] not in FinishedStatus:
if should_running and sub["status"] != TaskStatus.RUNNING:
print(f"subtask of task {task_id} is not in RUNNING status, skip finish: {sub}")
continue
self.mark_subtask_change(sub, sub["status"], status, fail_msg=fail_msg)
sub["status"] = status
sub["update_t"] = current_time()
if task["status"] == TaskStatus.CANCEL:
self.save(task, subtasks)
return TaskStatus.CANCEL
running_subs = []
failed_sub = False
for sub in subtasks:
if sub["status"] not in FinishedStatus:
running_subs.append(sub)
if sub["status"] == TaskStatus.FAILED:
failed_sub = True
# a subtask failed, so fail all remaining subtasks
if failed_sub:
if task["status"] != TaskStatus.FAILED:
self.mark_task_end(task, TaskStatus.FAILED)
task["status"] = TaskStatus.FAILED
task["update_t"] = current_time()
for sub in running_subs:
self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed")
sub["status"] = TaskStatus.FAILED
sub["update_t"] = current_time()
self.save(task, subtasks)
return TaskStatus.FAILED
# all subtasks finished and all succeeded
elif len(running_subs) == 0:
if task["status"] != TaskStatus.SUCCEED:
self.mark_task_end(task, TaskStatus.SUCCEED)
task["status"] = TaskStatus.SUCCEED
task["update_t"] = current_time()
self.save(task, subtasks)
return TaskStatus.SUCCEED
self.save(task, subtasks)
return None
@class_try_catch_async
async def cancel_task(self, task_id, user_id=None):
task, subtasks = self.load(task_id, user_id)
if task["status"] not in ActiveStatus:
return f"Task {task_id} is not in active status (current status: {task['status']}). Only tasks with status CREATED, PENDING, or RUNNING can be cancelled."
for sub in subtasks:
if sub["status"] not in FinishedStatus:
self.mark_subtask_change(sub, sub["status"], TaskStatus.CANCEL)
sub["status"] = TaskStatus.CANCEL
sub["update_t"] = current_time()
self.mark_task_end(task, TaskStatus.CANCEL)
task["status"] = TaskStatus.CANCEL
task["update_t"] = current_time()
self.save(task, subtasks)
return True
@class_try_catch_async
async def resume_task(self, task_id, all_subtask=False, user_id=None):
task, subtasks = self.load(task_id, user_id)
# the task is not finished
if task["status"] not in FinishedStatus:
return False
# the task does not need to be resumed
if not all_subtask and task["status"] == TaskStatus.SUCCEED:
return False
for sub in subtasks:
if all_subtask or sub["status"] != TaskStatus.SUCCEED:
self.mark_subtask_change(sub, None, TaskStatus.CREATED)
sub["status"] = TaskStatus.CREATED
sub["update_t"] = current_time()
sub["ping_t"] = 0.0
self.mark_task_start(task)
task["status"] = TaskStatus.CREATED
task["update_t"] = current_time()
self.save(task, subtasks)
return True
@class_try_catch_async
async def insert_user_if_not_exists(self, user_info):
fpath = self.get_user_filename(user_info["user_id"])
if os.path.exists(fpath):
return True
self.fmt_dict(user_info)
with open(fpath, "w") as fout:
fout.write(json.dumps(user_info, indent=4, ensure_ascii=False))
return True
@class_try_catch_async
async def query_user(self, user_id):
fpath = self.get_user_filename(user_id)
if not os.path.exists(fpath):
return None
with open(fpath) as fin:
data = json.load(fin)
self.parse_dict(data)
return data
async def test():
from lightx2v.deploy.common.pipeline import Pipeline
p = Pipeline("/data/nvme1/liuliang1/lightx2v/configs/model_pipeline.json")
m = LocalTaskManager("/data/nvme1/liuliang1/lightx2v/local_task")
await m.init()
keys = ["t2v", "wan2.1", "multi_stage"]
workers = p.get_workers(keys)
inputs = p.get_inputs(keys)
outputs = p.get_outputs(keys)
params = {
"prompt": "fake input prompts",
"resolution": {
"height": 233,
"width": 456,
},
}
user_info = {
"source": "github",
"id": "test-id-233",
"username": "test-username-233",
"email": "test-email-233@test.com",
"homepage": "https://test.com",
"avatar_url": "https://test.com/avatar.png",
}
user_id = await m.create_user(user_info)
print(" - create_user:", user_id)
user = await m.query_user(user_id)
print(" - query_user:", user)
task_id = await m.create_task(keys, workers, params, inputs, outputs, user_id)
print(" - create_task:", task_id)
tasks = await m.list_tasks()
print(" - list_tasks:", tasks)
task = await m.query_task(task_id)
print(" - query_task:", task)
subtasks = await m.next_subtasks(task_id)
print(" - next_subtasks:", subtasks)
await m.run_subtasks(subtasks, "fake-worker")
await m.finish_subtasks(task_id, TaskStatus.FAILED)
await m.cancel_task(task_id)
await m.resume_task(task_id)
for sub in subtasks:
await m.finish_subtasks(sub["task_id"], TaskStatus.SUCCEED, worker_name=sub["worker_name"], worker_identity="fake-worker")
subtasks = await m.next_subtasks(task_id)
print(" - final next_subtasks:", subtasks)
task = await m.query_task(task_id)
print(" - final task:", task)
await m.close()
if __name__ == "__main__":
asyncio.run(test())
import asyncio
import json
import traceback
from datetime import datetime
import asyncpg
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.task_manager import ActiveStatus, BaseTaskManager, FinishedStatus, TaskStatus
ASYNC_LOCK = asyncio.Lock()
class PostgresSQLTaskManager(BaseTaskManager):
def __init__(self, db_url, metrics_monitor=None):
self.db_url = db_url
self.table_tasks = "tasks"
self.table_subtasks = "subtasks"
self.table_users = "users"
self.table_versions = "versions"
self.pool = None
self.metrics_monitor = metrics_monitor
async def init(self):
await self.upgrade_db()
async def close(self):
if self.pool:
await self.pool.close()
def fmt_dict(self, data):
super().fmt_dict(data)
for k in ["create_t", "update_t", "ping_t"]:
if k in data and isinstance(data[k], float):
data[k] = datetime.fromtimestamp(data[k])
for k in ["params", "extra_info", "inputs", "outputs", "previous"]:
if k in data:
data[k] = json.dumps(data[k], ensure_ascii=False)
def parse_dict(self, data):
super().parse_dict(data)
for k in ["params", "extra_info", "inputs", "outputs", "previous"]:
if k in data:
data[k] = json.loads(data[k])
for k in ["create_t", "update_t", "ping_t"]:
if k in data:
data[k] = data[k].timestamp()
async def get_conn(self):
if self.pool is None:
self.pool = await asyncpg.create_pool(self.db_url)
return await self.pool.acquire()
async def release_conn(self, conn):
await self.pool.release(conn)
async def query_version(self):
conn = await self.get_conn()
try:
row = await conn.fetchrow(f"SELECT version FROM {self.table_versions} ORDER BY create_t DESC LIMIT 1")
return row["version"] if row else 0
except: # noqa
logger.error(f"query_version error: {traceback.format_exc()}")
return 0
finally:
await self.release_conn(conn)
@class_try_catch_async
async def upgrade_db(self):
versions = [
(1, "Init tables", self.upgrade_v1),
# (2, "Add new fields or indexes", self.upgrade_v2),
]
logger.info(f"upgrade_db: {self.db_url}")
cur_ver = await self.query_version()
for ver, description, func in versions:
if cur_ver < ver:
logger.info(f"Upgrade to version {ver}: {description}")
if not await func(ver, description):
logger.error(f"Upgrade to version {ver}: {description} func failed")
return False
cur_ver = ver
logger.info(f"upgrade_db: {self.db_url} done")
return True
async def upgrade_v1(self, version, description):
conn = await self.get_conn()
try:
async with conn.transaction(isolation="read_uncommitted"):
# create users table
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_users} (
user_id VARCHAR(256) PRIMARY KEY,
source VARCHAR(32),
id VARCHAR(200),
username VARCHAR(256),
email VARCHAR(256),
homepage VARCHAR(256),
avatar_url VARCHAR(256),
create_t TIMESTAMPTZ,
update_t TIMESTAMPTZ,
extra_info JSONB,
tag VARCHAR(64)
)
""")
# create tasks table
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_tasks} (
task_id VARCHAR(128) PRIMARY KEY,
task_type VARCHAR(64),
model_cls VARCHAR(64),
stage VARCHAR(64),
params JSONB,
create_t TIMESTAMPTZ,
update_t TIMESTAMPTZ,
status VARCHAR(64),
extra_info JSONB,
tag VARCHAR(64),
inputs JSONB,
outputs JSONB,
user_id VARCHAR(256),
FOREIGN KEY (user_id) REFERENCES {self.table_users}(user_id) ON DELETE CASCADE
)
""")
# create subtasks table
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_subtasks} (
task_id VARCHAR(128),
worker_name VARCHAR(128),
inputs JSONB,
outputs JSONB,
queue VARCHAR(128),
previous JSONB,
status VARCHAR(64),
worker_identity VARCHAR(128),
result VARCHAR(128),
fail_time INTEGER,
extra_info JSONB,
create_t TIMESTAMPTZ,
update_t TIMESTAMPTZ,
ping_t TIMESTAMPTZ,
infer_cost FLOAT,
PRIMARY KEY (task_id, worker_name),
FOREIGN KEY (task_id) REFERENCES {self.table_tasks}(task_id) ON DELETE CASCADE
)
""")
# create versions table
await conn.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_versions} (
version INTEGER PRIMARY KEY,
description VARCHAR(255),
create_t TIMESTAMPTZ NOT NULL
)
""")
# create indexes
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_users}_source ON {self.table_users}(source)")
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_users}_id ON {self.table_users}(id)")
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_tasks}_status ON {self.table_tasks}(status)")
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_tasks}_create_t ON {self.table_tasks}(create_t)")
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_tasks}_tag ON {self.table_tasks}(tag)")
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_subtasks}_task_id ON {self.table_subtasks}(task_id)")
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_subtasks}_worker_name ON {self.table_subtasks}(worker_name)")
await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{self.table_subtasks}_status ON {self.table_subtasks}(status)")
# update version
await conn.execute(f"INSERT INTO {self.table_versions} (version, description, create_t) VALUES ($1, $2, $3)", version, description, datetime.now())
return True
except: # noqa
logger.error(f"upgrade_v1 error: {traceback.format_exc()}")
return False
finally:
await self.release_conn(conn)
async def load(self, conn, task_id, user_id=None, only_task=False, worker_name=None):
query = f"SELECT * FROM {self.table_tasks} WHERE task_id = $1"
params = [task_id]
if user_id is not None:
query += " AND user_id = $2"
params.append(user_id)
row = await conn.fetchrow(query, *params)
assert row, f"query_task: task not found: {task_id} {user_id}"
task = dict(row)
self.parse_dict(task)
if only_task:
return task
query2 = f"SELECT * FROM {self.table_subtasks} WHERE task_id = $1"
params2 = [task_id]
if worker_name is not None:
query2 += " AND worker_name = $2"
params2.append(worker_name)
rows = await conn.fetch(query2, *params2)
subtasks = []
for row in rows:
sub = dict(row)
self.parse_dict(sub)
subtasks.append(sub)
return task, subtasks
async def update_task(self, conn, task_id, **kwargs):
query = f"UPDATE {self.table_tasks} SET "
conds = ["update_t = $1"]
params = [datetime.now()]
param_idx = 1
if "status" in kwargs:
param_idx += 1
conds.append(f"status = ${param_idx}")
params.append(kwargs["status"].name)
if "extra_info" in kwargs:
param_idx += 1
conds.append(f"extra_info = ${param_idx}")
params.append(json.dumps(kwargs["extra_info"], ensure_ascii=False))
query += ", ".join(conds)
query += f" WHERE task_id = ${param_idx + 1}"
params.append(task_id)
await conn.execute(query, *params)
async def update_subtask(self, conn, task_id, worker_name, **kwargs):
query = f"UPDATE {self.table_subtasks} SET "
conds = []
params = []
param_idx = 0
if kwargs.get("update_t", True):
param_idx += 1
conds.append(f"update_t = ${param_idx}")
params.append(datetime.now())
if kwargs.get("ping_t", False):
param_idx += 1
conds.append(f"ping_t = ${param_idx}")
params.append(datetime.now())
if kwargs.get("reset_ping_t", False):
param_idx += 1
conds.append(f"ping_t = ${param_idx}")
params.append(datetime.fromtimestamp(0))
if "status" in kwargs:
param_idx += 1
conds.append(f"status = ${param_idx}")
params.append(kwargs["status"].name)
if "worker_identity" in kwargs:
param_idx += 1
conds.append(f"worker_identity = ${param_idx}")
params.append(kwargs["worker_identity"])
if "infer_cost" in kwargs:
param_idx += 1
conds.append(f"infer_cost = ${param_idx}")
params.append(kwargs["infer_cost"])
if "extra_info" in kwargs:
param_idx += 1
conds.append(f"extra_info = ${param_idx}")
params.append(json.dumps(kwargs["extra_info"], ensure_ascii=False))
query += ", ".join(conds)
query += f" WHERE task_id = ${param_idx + 1} AND worker_name = ${param_idx + 2}"
params.extend([task_id, worker_name])
await conn.execute(query, *params)
@class_try_catch_async
async def insert_task(self, task, subtasks):
conn = await self.get_conn()
try:
async with conn.transaction(isolation="read_uncommitted"):
self.fmt_dict(task)
await conn.execute(
f"""
INSERT INTO {self.table_tasks}
(task_id, task_type, model_cls, stage, params, create_t,
update_t, status, extra_info, tag, inputs, outputs, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
""",
task["task_id"],
task["task_type"],
task["model_cls"],
task["stage"],
task["params"],
task["create_t"],
task["update_t"],
task["status"],
task["extra_info"],
task["tag"],
task["inputs"],
task["outputs"],
task["user_id"],
)
for sub in subtasks:
self.fmt_dict(sub)
await conn.execute(
f"""
INSERT INTO {self.table_subtasks}
(task_id, worker_name, inputs, outputs, queue, previous, status,
worker_identity, result, fail_time, extra_info, create_t, update_t,
ping_t, infer_cost)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
""",
sub["task_id"],
sub["worker_name"],
sub["inputs"],
sub["outputs"],
sub["queue"],
sub["previous"],
sub["status"],
sub["worker_identity"],
sub["result"],
sub["fail_time"],
sub["extra_info"],
sub["create_t"],
sub["update_t"],
sub["ping_t"],
sub["infer_cost"],
)
return True
except: # noqa
logger.error(f"insert_task error: {traceback.format_exc()}")
return False
finally:
await self.release_conn(conn)
@class_try_catch_async
async def list_tasks(self, **kwargs):
conn = await self.get_conn()
try:
count = kwargs.get("count", False)
query = "SELECT * FROM "
if count:
query = "SELECT COUNT(*) FROM "
assert "limit" not in kwargs, "limit is not allowed when count is True"
assert "offset" not in kwargs, "offset is not allowed when count is True"
params = []
conds = []
param_idx = 0
if kwargs.get("subtasks", False):
query += self.table_subtasks
assert "user_id" not in kwargs, "user_id is not allowed when subtasks is True"
else:
query += self.table_tasks
if "status" in kwargs:
param_idx += 1
if isinstance(kwargs["status"], list):
next_idx = param_idx + len(kwargs["status"])
placeholders = ",".join([f"${i}" for i in range(param_idx, next_idx)])
conds.append(f"status IN ({placeholders})")
params.extend([x.name for x in kwargs["status"]])
param_idx = next_idx - 1
else:
conds.append(f"status = ${param_idx}")
params.append(kwargs["status"].name)
if "start_created_t" in kwargs:
param_idx += 1
conds.append(f"create_t >= ${param_idx}")
params.append(datetime.fromtimestamp(kwargs["start_created_t"]))
if "end_created_t" in kwargs:
param_idx += 1
conds.append(f"create_t <= ${param_idx}")
params.append(datetime.fromtimestamp(kwargs["end_created_t"]))
if "start_updated_t" in kwargs:
param_idx += 1
conds.append(f"update_t >= ${param_idx}")
params.append(datetime.fromtimestamp(kwargs["start_updated_t"]))
if "end_updated_t" in kwargs:
param_idx += 1
conds.append(f"update_t <= ${param_idx}")
params.append(datetime.fromtimestamp(kwargs["end_updated_t"]))
if "start_ping_t" in kwargs:
param_idx += 1
conds.append(f"ping_t >= ${param_idx}")
params.append(datetime.fromtimestamp(kwargs["start_ping_t"]))
if "end_ping_t" in kwargs:
param_idx += 1
conds.append(f"ping_t <= ${param_idx}")
params.append(datetime.fromtimestamp(kwargs["end_ping_t"]))
if "user_id" in kwargs:
param_idx += 1
conds.append(f"user_id = ${param_idx}")
params.append(kwargs["user_id"])
if conds:
query += " WHERE " + " AND ".join(conds)
if not count:
query += " ORDER BY create_t DESC"
if "limit" in kwargs:
param_idx += 1
query += f" LIMIT ${param_idx}"
params.append(kwargs["limit"])
if "offset" in kwargs:
param_idx += 1
query += f" OFFSET ${param_idx}"
params.append(kwargs["offset"])
rows = await conn.fetch(query, *params)
if count:
return rows[0]["count"]
tasks = []
for row in rows:
task = dict(row)
self.parse_dict(task)
tasks.append(task)
return tasks
except: # noqa
logger.error(f"list_tasks error: {traceback.format_exc()}")
return []
finally:
await self.release_conn(conn)
@class_try_catch_async
async def query_task(self, task_id, user_id=None, only_task=True):
conn = await self.get_conn()
try:
return await self.load(conn, task_id, user_id, only_task=only_task)
except: # noqa
logger.error(f"query_task error: {traceback.format_exc()}")
return None
finally:
await self.release_conn(conn)
@class_try_catch_async
async def next_subtasks(self, task_id):
conn = await self.get_conn()
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id)
if task["status"] not in ActiveStatus:
return []
succeeds = set()
for sub in subtasks:
if sub["status"] == TaskStatus.SUCCEED:
succeeds.add(sub["worker_name"])
nexts = []
for sub in subtasks:
if sub["status"] == TaskStatus.CREATED:
dep_ok = True
for prev in sub["previous"]:
if prev not in succeeds:
dep_ok = False
break
if dep_ok:
sub["params"] = task["params"]
self.mark_subtask_change(sub, sub["status"], TaskStatus.PENDING)
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.PENDING, extra_info=sub["extra_info"])
nexts.append(sub)
if len(nexts) > 0:
await self.update_task(conn, task_id, status=TaskStatus.PENDING)
return nexts
except: # noqa
logger.error(f"next_subtasks error: {traceback.format_exc()}")
return None
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def run_subtasks(self, cands, worker_identity):
conn = await self.get_conn()
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
valids = []
for cand in cands:
task_id = cand["task_id"]
worker_name = cand["worker_name"]
task, subs = await self.load(conn, task_id, worker_name=worker_name)
assert len(subs) == 1, f"task {task_id} has multiple subtasks: {subs} with worker_name: {worker_name}"
if task["status"] in FinishedStatus:
continue
self.mark_subtask_change(subs[0], subs[0]["status"], TaskStatus.RUNNING)
await self.update_subtask(conn, task_id, worker_name, status=TaskStatus.RUNNING, worker_identity=worker_identity, ping_t=True, extra_info=subs[0]["extra_info"])
await self.update_task(conn, task_id, status=TaskStatus.RUNNING)
valids.append(cand)
break
return valids
except: # noqa
logger.error(f"run_subtasks error: {traceback.format_exc()}")
return []
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def ping_subtask(self, task_id, worker_name, worker_identity):
conn = await self.get_conn()
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id)
for sub in subtasks:
if sub["worker_name"] == worker_name:
pre = sub["worker_identity"]
assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}"
await self.update_subtask(conn, task_id, worker_name, ping_t=True, update_t=False)
return True
return False
except: # noqa
logger.error(f"ping_subtask error: {traceback.format_exc()}")
return False
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False):
conn = await self.get_conn()
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id)
subs = subtasks
if worker_name:
subs = [sub for sub in subtasks if sub["worker_name"] == worker_name]
assert len(subs) >= 1, f"no worker task_id={task_id}, name={worker_name}"
if worker_identity:
pre = subs[0]["worker_identity"]
assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}"
assert status in [TaskStatus.SUCCEED, TaskStatus.FAILED], f"invalid finish status: {status}"
for sub in subs:
if sub["status"] not in FinishedStatus:
if should_running and sub["status"] != TaskStatus.RUNNING:
logger.warning(f"subtask of task {task_id} is not in RUNNING status, skip finish: {sub}")
continue
self.mark_subtask_change(sub, sub["status"], status, fail_msg=fail_msg)
await self.update_subtask(conn, task_id, sub["worker_name"], status=status, extra_info=sub["extra_info"])
sub["status"] = status
if task["status"] == TaskStatus.CANCEL:
return TaskStatus.CANCEL
running_subs = []
failed_sub = False
for sub in subtasks:
if sub["status"] not in FinishedStatus:
running_subs.append(sub)
if sub["status"] == TaskStatus.FAILED:
failed_sub = True
# a subtask failed, so fail all remaining subtasks
if failed_sub:
if task["status"] != TaskStatus.FAILED:
self.mark_task_end(task, TaskStatus.FAILED)
await self.update_task(conn, task_id, status=TaskStatus.FAILED, extra_info=task["extra_info"])
for sub in running_subs:
self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed")
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.FAILED, extra_info=sub["extra_info"])
return TaskStatus.FAILED
# all subtasks finished and all succeed
elif len(running_subs) == 0:
if task["status"] != TaskStatus.SUCCEED:
self.mark_task_end(task, TaskStatus.SUCCEED)
await self.update_task(conn, task_id, status=TaskStatus.SUCCEED, extra_info=task["extra_info"])
return TaskStatus.SUCCEED
return None
except: # noqa
logger.error(f"finish_subtasks error: {traceback.format_exc()}")
return None
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def cancel_task(self, task_id, user_id=None):
conn = await self.get_conn()
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id, user_id)
if task["status"] not in ActiveStatus:
return f"Task {task_id} is not in active status (current status: {task['status']}). Only tasks with status CREATED, PENDING, or RUNNING can be cancelled."
for sub in subtasks:
if sub["status"] not in FinishedStatus:
self.mark_subtask_change(sub, sub["status"], TaskStatus.CANCEL)
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.CANCEL, extra_info=sub["extra_info"])
self.mark_task_end(task, TaskStatus.CANCEL)
await self.update_task(conn, task_id, status=TaskStatus.CANCEL, extra_info=task["extra_info"])
return True
except: # noqa
logger.error(f"cancel_task error: {traceback.format_exc()}")
return "unknown cancel error"
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def resume_task(self, task_id, all_subtask=False, user_id=None):
conn = await self.get_conn()
try:
await ASYNC_LOCK.acquire()
async with conn.transaction(isolation="read_uncommitted"):
task, subtasks = await self.load(conn, task_id, user_id)
# the task is not finished
if task["status"] not in FinishedStatus:
return False
# the task does not need to be resumed
if not all_subtask and task["status"] == TaskStatus.SUCCEED:
return False
for sub in subtasks:
if all_subtask or sub["status"] != TaskStatus.SUCCEED:
self.mark_subtask_change(sub, None, TaskStatus.CREATED)
await self.update_subtask(conn, task_id, sub["worker_name"], status=TaskStatus.CREATED, reset_ping_t=True, extra_info=sub["extra_info"])
self.mark_task_start(task)
await self.update_task(conn, task_id, status=TaskStatus.CREATED, extra_info=task["extra_info"])
return True
except: # noqa
logger.error(f"resume_task error: {traceback.format_exc()}")
return False
finally:
ASYNC_LOCK.release()
await self.release_conn(conn)
@class_try_catch_async
async def insert_user_if_not_exists(self, user_info):
conn = await self.get_conn()
try:
async with conn.transaction(isolation="read_uncommitted"):
row = await conn.fetchrow(f"SELECT * FROM {self.table_users} WHERE user_id = $1", user_info["user_id"])
if row:
logger.info(f"user already exists: {user_info['user_id']}")
return True
self.fmt_dict(user_info)
await conn.execute(
f"""
INSERT INTO {self.table_users}
(user_id, source, id, username, email, homepage,
avatar_url, create_t, update_t, extra_info, tag)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
""",
user_info["user_id"],
user_info["source"],
user_info["id"],
user_info["username"],
user_info["email"],
user_info["homepage"],
user_info["avatar_url"],
user_info["create_t"],
user_info["update_t"],
user_info["extra_info"],
user_info["tag"],
)
return True
except: # noqa
logger.error(f"insert_user_if_not_exists error: {traceback.format_exc()}")
return False
finally:
await self.release_conn(conn)
@class_try_catch_async
async def query_user(self, user_id):
conn = await self.get_conn()
try:
row = await conn.fetchrow(f"SELECT * FROM {self.table_users} WHERE user_id = $1", user_id)
user = dict(row)
self.parse_dict(user)
return user
except: # noqa
logger.error(f"query_user error: {traceback.format_exc()}")
return None
finally:
await self.release_conn(conn)
async def test():
from lightx2v.deploy.common.pipeline import Pipeline
p = Pipeline("/data/nvme1/liuliang1/lightx2v/configs/model_pipeline.json")
m = PostgresSQLTaskManager("postgresql://test:test@127.0.0.1:5432/lightx2v_test")
await m.init()
keys = ["t2v", "wan2.1", "multi_stage"]
workers = p.get_workers(keys)
inputs = p.get_inputs(keys)
outputs = p.get_outputs(keys)
params = {
"prompt": "fake input prompts",
"resolution": {
"height": 233,
"width": 456,
},
}
user_info = {
"source": "github",
"id": "4566",
"username": "test-username-233",
"email": "test-email-233@test.com",
"homepage": "https://test.com",
"avatar_url": "https://test.com/avatar.png",
}
user_id = await m.create_user(user_info)
print(" - create_user:", user_id)
user = await m.query_user(user_id)
print(" - query_user:", user)
task_id = await m.create_task(keys, workers, params, inputs, outputs, user_id)
print(" - create_task:", task_id)
tasks = await m.list_tasks()
print(" - list_tasks:", tasks)
task = await m.query_task(task_id)
print(" - query_task:", task)
subtasks = await m.next_subtasks(task_id)
print(" - next_subtasks:", subtasks)
await m.run_subtasks(subtasks, "fake-worker")
await m.finish_subtasks(task_id, TaskStatus.FAILED)
await m.cancel_task(task_id)
await m.resume_task(task_id)
for sub in subtasks:
await m.finish_subtasks(sub["task_id"], TaskStatus.SUCCEED, worker_name=sub["worker_name"], worker_identity="fake-worker")
subtasks = await m.next_subtasks(task_id)
print(" - final next_subtasks:", subtasks)
task = await m.query_task(task_id)
print(" - final task:", task)
await m.close()
if __name__ == "__main__":
asyncio.run(test())
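The finish_subtasks method above reduces to a small aggregation rule over subtask statuses: any FAILED subtask fails the whole task, and the task succeeds only once every subtask has finished without failure. A dependency-free sketch of that rule (status strings and the helper name are illustrative, not part of the codebase):

```python
# Cancel is handled separately in finish_subtasks (the task-level CANCEL
# check runs first); this sketch only aggregates SUCCEED/FAILED outcomes.
FINISHED = {"SUCCEED", "FAILED", "CANCEL"}

def aggregate_task_status(subtask_statuses):
    # any failed subtask fails the whole task
    if any(s == "FAILED" for s in subtask_statuses):
        return "FAILED"
    # every subtask finished without failure -> the task succeeded
    if all(s in FINISHED for s in subtask_statuses):
        return "SUCCEED"
    # otherwise the task is still in flight
    return None

assert aggregate_task_status(["SUCCEED", "RUNNING"]) is None
assert aggregate_task_status(["FAILED", "RUNNING"]) == "FAILED"
assert aggregate_task_status(["SUCCEED", "SUCCEED"]) == "SUCCEED"
```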
import argparse
import asyncio
import json
import os
import signal
import sys
import traceback
import uuid
import aiohttp
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.deploy.data_manager import LocalDataManager, S3DataManager
from lightx2v.deploy.task_manager import TaskStatus
from lightx2v.deploy.worker.hub import DiTWorker, ImageEncoderWorker, PipelineWorker, SegmentDiTWorker, TextEncoderWorker, VaeDecoderWorker, VaeEncoderWorker
RUNNER_MAP = {
"pipeline": PipelineWorker,
"text_encoder": TextEncoderWorker,
"image_encoder": ImageEncoderWorker,
"vae_encoder": VaeEncoderWorker,
"vae_decoder": VaeDecoderWorker,
"dit": DiTWorker,
"segment_dit": SegmentDiTWorker,
}
# {task_id: {"server_url": ..., "worker_name": ..., "worker_identity": ..., "queue": ...}}
RUNNING_SUBTASKS = {}
WORKER_SECRET_KEY = os.getenv("WORKER_SECRET_KEY", "worker-secret-key-change-in-production")
HEADERS = {"Authorization": f"Bearer {WORKER_SECRET_KEY}", "Content-Type": "application/json"}
STOPPED = False
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
TARGET_RANK = WORLD_SIZE - 1
async def ping_life(server_url, worker_identity, keys):
url = server_url + "/api/v1/worker/ping/life"
params = {"worker_identity": worker_identity, "worker_keys": keys}
while True:
try:
logger.info(f"{worker_identity} pinging life ...")
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(params), headers=HEADERS) as ret:
if ret.status == 200:
ret = await ret.json()
logger.info(f"{worker_identity} ping life: {ret}")
if ret["msg"] == "delete":
logger.warning(f"{worker_identity} deleted")
# asyncio.create_task(shutdown(asyncio.get_event_loop()))
return
await asyncio.sleep(10)
else:
error_text = await ret.text()
raise Exception(f"{worker_identity} ping life fail: [{ret.status}], error: {error_text}")
except asyncio.CancelledError:
logger.warning("Ping life cancelled, shutting down...")
raise asyncio.CancelledError
except: # noqa
logger.warning(f"Ping life failed: {traceback.format_exc()}")
await asyncio.sleep(10)
async def ping_subtask(server_url, worker_identity, task_id, worker_name, queue, running_task, ping_interval):
url = server_url + "/api/v1/worker/ping/subtask"
params = {
"worker_identity": worker_identity,
"task_id": task_id,
"worker_name": worker_name,
"queue": queue,
}
while True:
try:
logger.info(f"{worker_identity} pinging subtask {task_id} {worker_name} ...")
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(params), headers=HEADERS) as ret:
if ret.status == 200:
ret = await ret.json()
logger.info(f"{worker_identity} ping subtask {task_id} {worker_name}: {ret}")
if ret["msg"] == "delete":
logger.warning(f"{worker_identity} subtask {task_id} {worker_name} deleted")
running_task.cancel()
return
await asyncio.sleep(ping_interval)
else:
error_text = await ret.text()
raise Exception(f"{worker_identity} ping subtask fail: [{ret.status}], error: {error_text}")
except asyncio.CancelledError:
logger.warning(f"Ping subtask {task_id} {worker_name} cancelled")
raise asyncio.CancelledError
except: # noqa
logger.warning(f"Ping subtask failed: {traceback.format_exc()}")
await asyncio.sleep(10)
async def fetch_subtasks(server_url, worker_keys, worker_identity, max_batch, timeout):
url = server_url + "/api/v1/worker/fetch"
params = {
"worker_keys": worker_keys,
"worker_identity": worker_identity,
"max_batch": max_batch,
"timeout": timeout,
}
try:
logger.info(f"{worker_identity} fetching {worker_keys} with timeout: {timeout}s ...")
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(params), headers=HEADERS, timeout=timeout + 1) as ret:
if ret.status == 200:
ret = await ret.json()
subtasks = ret["subtasks"]
for sub in subtasks:
sub["server_url"] = server_url
sub["worker_identity"] = worker_identity
RUNNING_SUBTASKS[sub["task_id"]] = sub
logger.info(f"{worker_identity} fetch {worker_keys} ok: {subtasks}")
return subtasks
else:
error_text = await ret.text()
logger.warning(f"{worker_identity} fetch {worker_keys} fail: [{ret.status}], error: {error_text}")
return None
except asyncio.CancelledError:
logger.warning("Fetch subtasks cancelled, shutting down...")
raise asyncio.CancelledError
except: # noqa
logger.warning(f"Fetch subtasks failed: {traceback.format_exc()}")
await asyncio.sleep(10)
async def report_task(server_url, task_id, worker_name, status, worker_identity, queue, **kwargs):
url = server_url + "/api/v1/worker/report"
params = {
"task_id": task_id,
"worker_name": worker_name,
"status": status,
"worker_identity": worker_identity,
"queue": queue,
"fail_msg": "" if status == TaskStatus.SUCCEED.name else "worker failed",
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=json.dumps(params), headers=HEADERS) as ret:
if ret.status == 200:
RUNNING_SUBTASKS.pop(task_id, None)
ret = await ret.json()
logger.info(f"{worker_identity} report {task_id} {worker_name} {status} ok")
return True
else:
error_text = await ret.text()
logger.warning(f"{worker_identity} report {task_id} {worker_name} {status} fail: [{ret.status}], error: {error_text}")
return False
except asyncio.CancelledError:
logger.warning("Report task cancelled, shutting down...")
raise asyncio.CancelledError
except: # noqa
logger.warning(f"Report task failed: {traceback.format_exc()}")
async def boradcast_subtasks(subtasks):
subtasks = [] if subtasks is None else subtasks
if WORLD_SIZE <= 1:
return subtasks
try:
if RANK == TARGET_RANK:
subtasks_data = json.dumps(subtasks, ensure_ascii=False).encode("utf-8")
subtasks_tensor = torch.frombuffer(bytearray(subtasks_data), dtype=torch.uint8).to(device="cuda")
data_size = subtasks_tensor.shape[0]
size_tensor = torch.tensor([data_size], dtype=torch.int32).to(device="cuda")
logger.info(f"rank {RANK} send subtasks: {subtasks_tensor.shape}, {size_tensor}")
else:
size_tensor = torch.zeros(1, dtype=torch.int32, device="cuda")
dist.broadcast(size_tensor, src=TARGET_RANK)
if RANK != TARGET_RANK:
subtasks_tensor = torch.zeros(size_tensor.item(), dtype=torch.uint8, device="cuda")
dist.broadcast(subtasks_tensor, src=TARGET_RANK)
if RANK != TARGET_RANK:
subtasks_data = subtasks_tensor.cpu().numpy().tobytes()
subtasks = json.loads(subtasks_data.decode("utf-8"))
logger.info(f"rank {RANK} recv subtasks: {subtasks}")
return subtasks
except: # noqa
logger.error(f"Broadcast subtasks failed: {traceback.format_exc()}")
return []
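boradcast_subtasks above moves the fetched subtask list across ranks by JSON-encoding it into a uint8 tensor and broadcasting the size first, then the payload. A CPU-only sketch of just the serialization round trip (torch.distributed and CUDA omitted; the helper names are illustrative):

```python
import json

def encode_subtasks(subtasks):
    # sender side (TARGET_RANK): JSON -> bytes; the length is broadcast first
    data = json.dumps(subtasks, ensure_ascii=False).encode("utf-8")
    return bytearray(data), len(data)

def decode_subtasks(payload):
    # receiver side: bytes -> JSON after the payload broadcast
    return json.loads(bytes(payload).decode("utf-8"))

subs = [{"task_id": "t1", "worker_name": "dit", "queue": "q"}]
payload, size = encode_subtasks(subs)
assert size == len(payload)
assert decode_subtasks(payload) == subs
```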
async def sync_subtask():
if WORLD_SIZE <= 1:
return
try:
logger.info(f"Sync subtask {RANK}/{WORLD_SIZE} wait barrier")
dist.barrier()
logger.info(f"Sync subtask {RANK}/{WORLD_SIZE} ok")
except: # noqa
logger.error(f"Sync subtask failed: {traceback.format_exc()}")
async def main(args):
if args.model_name == "":
args.model_name = args.model_cls
worker_keys = [args.task, args.model_name, args.stage, args.worker]
data_manager = None
if args.data_url.startswith("/"):
data_manager = LocalDataManager(args.data_url)
elif args.data_url.startswith("{"):
data_manager = S3DataManager(args.data_url)
else:
raise NotImplementedError
await data_manager.init()
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
runner = RUNNER_MAP[args.worker](args)
if WORLD_SIZE > 1:
dist.barrier()
# asyncio.create_task(ping_life(args.server, args.identity, worker_keys))
while True:
subtasks = None
if RANK == TARGET_RANK:
subtasks = await fetch_subtasks(args.server, worker_keys, args.identity, args.max_batch, args.timeout)
subtasks = await boradcast_subtasks(subtasks)
for sub in subtasks:
status = TaskStatus.FAILED.name
ping_task = None
try:
run_task = asyncio.create_task(runner.run(sub["inputs"], sub["outputs"], sub["params"], data_manager))
if RANK == TARGET_RANK:
ping_task = asyncio.create_task(ping_subtask(args.server, sub["worker_identity"], sub["task_id"], sub["worker_name"], sub["queue"], run_task, args.ping_interval))
ret = await run_task
if ret is True:
status = TaskStatus.SUCCEED.name
except asyncio.CancelledError:
if STOPPED:
logger.warning("Main loop cancelled, already stopped, should exit")
return
logger.warning("Main loop cancelled, do not shut down")
finally:
if RANK == TARGET_RANK and sub["task_id"] in RUNNING_SUBTASKS:
try:
await report_task(status=status, **sub)
except: # noqa
logger.warning(f"Report failed: {traceback.format_exc()}")
if ping_task:
ping_task.cancel()
await sync_subtask()
async def shutdown(loop):
logger.warning("Received kill signal")
global STOPPED
STOPPED = True
for t in asyncio.all_tasks():
if t is not asyncio.current_task():
logger.warning(f"Cancel async task {t} ...")
t.cancel()
# Report remaining running subtasks failed
if RANK == TARGET_RANK:
task_ids = list(RUNNING_SUBTASKS.keys())
for task_id in task_ids:
try:
s = RUNNING_SUBTASKS[task_id]
logger.warning(f"Report {task_id} {s['worker_name']} {TaskStatus.FAILED.name} ...")
await report_task(status=TaskStatus.FAILED.name, **s)
except: # noqa
logger.warning(f"Report task {task_id} failed: {traceback.format_exc()}")
if WORLD_SIZE > 1:
dist.destroy_process_group()
# Force exit after a short delay to ensure cleanup
def force_exit():
logger.warning("Force exiting process...")
sys.exit(0)
loop.call_later(2, force_exit)
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
parser = argparse.ArgumentParser()
cur_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(cur_dir, "../../.."))
dft_data_url = os.path.join(base_dir, "local_data")
parser.add_argument("--task", type=str, required=True)
parser.add_argument("--model_cls", type=str, required=True)
parser.add_argument("--model_name", type=str, default="")
parser.add_argument("--stage", type=str, required=True)
parser.add_argument("--worker", type=str, required=True)
parser.add_argument("--identity", type=str, default="")
parser.add_argument("--max_batch", type=int, default=1)
parser.add_argument("--timeout", type=int, default=300)
parser.add_argument("--ping_interval", type=int, default=10)
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--server", type=str, default="http://127.0.0.1:8080")
parser.add_argument("--data_url", type=str, default=dft_data_url)
args = parser.parse_args()
if args.identity == "":
# TODO: spec worker instance identity by k8s env
args.identity = "worker-" + str(uuid.uuid4())[:8]
logger.info(f"args: {args}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
for s in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(s, lambda: asyncio.create_task(shutdown(loop)))
try:
loop.create_task(main(args), name="main")
loop.run_forever()
finally:
loop.close()
logger.warning("Event loop closed")
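The shutdown() coroutine above cancels every pending task and then force-exits after a short grace period. A minimal sketch of just the cancel-all step (signal handlers, subtask reporting, and dist cleanup omitted; names are illustrative):

```python
import asyncio

async def shutdown_all():
    # cancel everything except the task doing the cancelling
    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    for t in tasks:
        t.cancel()
    # gather with return_exceptions so CancelledError does not propagate
    await asyncio.gather(*tasks, return_exceptions=True)
    return len(tasks)

async def main():
    async def sleeper():
        await asyncio.sleep(3600)
    asyncio.create_task(sleeper())
    asyncio.create_task(sleeper())
    await asyncio.sleep(0)  # let the sleepers start
    return await shutdown_all()

cancelled = asyncio.run(main())
assert cancelled == 2
```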
import asyncio
import copy
import ctypes
import gc
import json
import os
import tempfile
import threading
import traceback
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.infer import init_runner # noqa
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.utils.envs import CHECK_ENABLE_GRAPH_MODE
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import set_config, set_parallel_config
from lightx2v.utils.utils import seed_all
class BaseWorker:
@ProfilingContext("Init Worker Cost:")
def __init__(self, args):
config = set_config(args)
config["mode"] = ""
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
seed_all(config.seed)
self.rank = 0
if config.parallel:
self.rank = dist.get_rank()
set_parallel_config(config)
self.runner = RUNNER_REGISTER[config.model_cls](config)
# fixed config
self.fixed_config = copy.deepcopy(self.runner.config)
def update_config(self, kwargs):
for k, v in kwargs.items():
setattr(self.runner.config, k, v)
def set_inputs(self, params):
self.runner.config["prompt"] = params["prompt"]
self.runner.config["negative_prompt"] = params.get("negative_prompt", "")
self.runner.config["image_path"] = params.get("image_path", "")
self.runner.config["save_video_path"] = params.get("save_video_path", "")
self.runner.config["seed"] = params.get("seed", self.fixed_config.get("seed", 42))
self.runner.config["audio_path"] = params.get("audio_path", "")
async def prepare_input_image(self, params, inputs, tmp_dir, data_manager):
input_image_path = inputs.get("input_image", "")
tmp_image_path = os.path.join(tmp_dir, input_image_path)
# prepare tmp image
if self.runner.config.task == "i2v":
img_data = await data_manager.load_bytes(input_image_path)
with open(tmp_image_path, "wb") as fout:
fout.write(img_data)
params["image_path"] = tmp_image_path
async def prepare_input_audio(self, params, inputs, tmp_dir, data_manager):
input_audio_path = inputs.get("input_audio", "")
tmp_audio_path = os.path.join(tmp_dir, input_audio_path)
# for stream audio input, value is dict
stream_audio_path = params.get("input_audio", None)
if stream_audio_path is not None:
tmp_audio_path = stream_audio_path
if input_audio_path and self.is_audio_model() and isinstance(tmp_audio_path, str):
audio_data = await data_manager.load_bytes(input_audio_path)
with open(tmp_audio_path, "wb") as fout:
fout.write(audio_data)
params["audio_path"] = tmp_audio_path
def prepare_output_video(self, params, outputs, tmp_dir, data_manager):
output_video_path = outputs.get("output_video", "")
tmp_video_path = os.path.join(tmp_dir, output_video_path)
if data_manager.name == "local":
tmp_video_path = os.path.join(data_manager.local_dir, output_video_path)
# for stream video output, value is dict
stream_video_path = params.get("output_video", None)
if stream_video_path is not None:
tmp_video_path = stream_video_path
params["save_video_path"] = tmp_video_path
return tmp_video_path, output_video_path
async def prepare_dit_inputs(self, inputs, data_manager):
device = torch.device("cuda", self.rank)
text_out = inputs["text_encoder_output"]
text_encoder_output = await data_manager.load_object(text_out, device)
image_encoder_output = None
if self.runner.config.task == "i2v":
clip_path = inputs["clip_encoder_output"]
vae_path = inputs["vae_encoder_output"]
clip_encoder_out = await data_manager.load_object(clip_path, device)
vae_encoder_out = await data_manager.load_object(vae_path, device)
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encoder_out": vae_encoder_out["vals"],
}
# apply the config changes made by the vae encoder
self.update_config(vae_encoder_out["kwargs"])
self.runner.inputs = {
"text_encoder_output": text_encoder_output,
"image_encoder_output": image_encoder_output,
}
if self.is_audio_model():
audio_segments, expected_frames = self.runner.read_audio_input()
self.runner.inputs["audio_segments"] = audio_segments
self.runner.inputs["expected_frames"] = expected_frames
async def save_output_video(self, tmp_video_path, output_video_path, data_manager):
# save output video
if data_manager.name != "local" and self.rank == 0 and isinstance(tmp_video_path, str):
with open(tmp_video_path, "rb") as fin:
video_data = fin.read()
await data_manager.save_bytes(video_data, output_video_path)
def is_audio_model(self):
return "audio" in self.runner.config.model_cls or "seko_talk" in self.runner.config.model_cls
class RunnerThread(threading.Thread):
def __init__(self, loop, future, run_func, rank, *args, **kwargs):
super().__init__(daemon=True)
self.loop = loop
self.future = future
self.run_func = run_func
self.args = args
self.kwargs = kwargs
self.rank = rank
def run(self):
try:
# cuda device bind for each thread
torch.cuda.set_device(self.rank)
res = self.run_func(*self.args, **self.kwargs)
status = True
except: # noqa
logger.error(f"RunnerThread run failed: {traceback.format_exc()}")
res = None
status = False
finally:
async def set_future_result():
self.future.set_result((status, res))
# add the task of setting future to the loop queue
asyncio.run_coroutine_threadsafe(set_future_result(), self.loop)
def stop(self):
if self.is_alive():
try:
logger.warning(f"Force terminate thread {self.ident} ...")
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(self.ident), ctypes.py_object(SystemExit))
except Exception as e:
logger.error(f"Force terminate thread failed: {e}")
def class_try_catch_async_with_thread(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except asyncio.CancelledError:
logger.warning(f"RunnerThread inside {func.__name__} cancelled")
if hasattr(self, "thread"):
# self.thread.stop()
self.runner.stop_signal = True
self.thread.join()
raise asyncio.CancelledError
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
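The RunnerThread pattern above boils down to: run a blocking function on a daemon thread, then hand the (status, result) pair back to the event loop through run_coroutine_threadsafe. A minimal CPU-only sketch (no CUDA device binding or force-stop; names are illustrative):

```python
import asyncio
import threading

def run_blocking_in_thread(loop, future, fn, *args):
    def worker():
        try:
            res, ok = fn(*args), True
        except Exception:
            res, ok = None, False
        async def set_result():
            future.set_result((ok, res))
        # resolve the future on the event loop's own thread
        asyncio.run_coroutine_threadsafe(set_result(), loop)
    threading.Thread(target=worker, daemon=True).start()

async def main():
    loop = asyncio.get_running_loop()
    future = asyncio.Future()
    run_blocking_in_thread(loop, future, lambda x: x * 2, 21)
    return await future

assert asyncio.run(main()) == (True, 42)
```

Setting the future from the worker thread directly would not be thread-safe; scheduling the assignment back onto the loop is the reason for the run_coroutine_threadsafe hop.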
class PipelineWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
self.runner.init_modules()
if CHECK_ENABLE_GRAPH_MODE():
self.init_temp_params()
self.graph_runner = GraphRunner(self.runner)
self.run_func = self.graph_runner.run_pipeline
else:
self.run_func = self.runner.run_pipeline
def init_temp_params(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.abspath(os.path.join(cur_dir, "../../.."))
self.runner.config["prompt"] = "The video features an old lady saying something and knitting a sweater."
if self.runner.config.task == "i2v":
self.runner.config["image_path"] = os.path.join(base_dir, "assets", "inputs", "audio", "15.png")
if self.is_audio_model():
self.runner.config["audio_path"] = os.path.join(base_dir, "assets", "inputs", "audio", "15.wav")
@class_try_catch_async_with_thread
async def run(self, inputs, outputs, params, data_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
await self.prepare_input_image(params, inputs, tmp_dir, data_manager)
await self.prepare_input_audio(params, inputs, tmp_dir, data_manager)
tmp_video_path, output_video_path = self.prepare_output_video(params, outputs, tmp_dir, data_manager)
logger.info(f"run params: {params}, {inputs}, {outputs}")
self.set_inputs(params)
self.runner.stop_signal = False
future = asyncio.Future()
self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_func, self.rank)
self.thread.start()
status, _ = await future
if not status:
return False
await self.save_output_video(tmp_video_path, output_video_path, data_manager)
return True
class TextEncoderWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
self.runner.text_encoders = self.runner.load_text_encoder()
@class_try_catch_async
async def run(self, inputs, outputs, params, data_manager):
logger.info(f"run params: {params}, {inputs}, {outputs}")
input_image_path = inputs.get("input_image", "")
self.set_inputs(params)
prompt = self.runner.config["prompt"]
img = None
if self.runner.config["use_prompt_enhancer"]:
prompt = self.runner.config["prompt_enhanced"]
if self.runner.config.task == "i2v" and not self.is_audio_model():
img = await data_manager.load_image(input_image_path)
img = self.runner.read_image_input(img)
if isinstance(img, tuple):
img = img[0]
out = self.runner.run_text_encoder(prompt, img)
if self.rank == 0:
await data_manager.save_object(out, outputs["text_encoder_output"])
del out
torch.cuda.empty_cache()
gc.collect()
return True
class ImageEncoderWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
self.runner.image_encoder = self.runner.load_image_encoder()
@class_try_catch_async
async def run(self, inputs, outputs, params, data_manager):
logger.info(f"run params: {params}, {inputs}, {outputs}")
self.set_inputs(params)
img = await data_manager.load_image(inputs["input_image"])
img = self.runner.read_image_input(img)
if isinstance(img, tuple):
img = img[0]
out = self.runner.run_image_encoder(img)
if self.rank == 0:
await data_manager.save_object(out, outputs["clip_encoder_output"])
del out
torch.cuda.empty_cache()
gc.collect()
return True
class VaeEncoderWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
self.runner.vae_encoder, vae_decoder = self.runner.load_vae()
del vae_decoder
@class_try_catch_async
async def run(self, inputs, outputs, params, data_manager):
logger.info(f"run params: {params}, {inputs}, {outputs}")
self.set_inputs(params)
img = await data_manager.load_image(inputs["input_image"])
# could change config.lat_h, lat_w, tgt_h, tgt_w
img = self.runner.read_image_input(img)
if isinstance(img, tuple):
img = img[1] if self.runner.vae_encoder_need_img_original else img[0]
# running the vae encoder may change the config; pass those changes via kwargs
vals = self.runner.run_vae_encoder(img)
out = {"vals": vals, "kwargs": {}}
for key in ["lat_h", "lat_w", "tgt_h", "tgt_w"]:
if hasattr(self.runner.config, key):
out["kwargs"][key] = int(getattr(self.runner.config, key))
if self.rank == 0:
await data_manager.save_object(out, outputs["vae_encoder_output"])
del out, img, vals
torch.cuda.empty_cache()
gc.collect()
return True
class DiTWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
self.runner.model = self.runner.load_transformer()
@class_try_catch_async_with_thread
async def run(self, inputs, outputs, params, data_manager):
logger.info(f"run params: {params}, {inputs}, {outputs}")
self.set_inputs(params)
await self.prepare_dit_inputs(inputs, data_manager)
self.runner.stop_signal = False
future = asyncio.Future()
self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_dit, self.rank)
self.thread.start()
status, res = await future
if not status:
return False
out, _ = res
if self.rank == 0:
await data_manager.save_tensor(out, outputs["latents"])
del out
torch.cuda.empty_cache()
gc.collect()
return True
def run_dit(self):
self.runner.init_run()
assert self.runner.video_segment_num == 1, "DiTWorker only supports a single segment"
latents, generator = self.runner.run_segment(total_steps=None)
self.runner.end_run()
return latents, generator
class VaeDecoderWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
vae_encoder, self.runner.vae_decoder = self.runner.load_vae()
self.runner.vfi_model = self.runner.load_vfi_model() if "video_frame_interpolation" in self.runner.config else None
del vae_encoder
@class_try_catch_async
async def run(self, inputs, outputs, params, data_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_video_path, output_video_path = self.prepare_output_video(params, outputs, tmp_dir, data_manager)
logger.info(f"run params: {params}, {inputs}, {outputs}")
self.set_inputs(params)
device = torch.device("cuda", self.rank)
latents = await data_manager.load_tensor(inputs["latents"], device)
self.runner.gen_video = self.runner.run_vae_decoder(latents)
self.runner.process_images_after_vae_decoder(save_video=True)
await self.save_output_video(tmp_video_path, output_video_path, data_manager)
del latents
torch.cuda.empty_cache()
gc.collect()
return True
class SegmentDiTWorker(BaseWorker):
def __init__(self, args):
super().__init__(args)
self.runner.model = self.runner.load_transformer()
self.runner.vae_encoder, self.runner.vae_decoder = self.runner.load_vae()
self.runner.vfi_model = self.runner.load_vfi_model() if "video_frame_interpolation" in self.runner.config else None
if self.is_audio_model():
self.runner.audio_encoder = self.runner.load_audio_encoder()
self.runner.audio_adapter = self.runner.load_audio_adapter()
self.runner.model.set_audio_adapter(self.runner.audio_adapter)
@class_try_catch_async_with_thread
async def run(self, inputs, outputs, params, data_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_video_path, output_video_path = self.prepare_output_video(params, outputs, tmp_dir, data_manager)
await self.prepare_input_audio(params, inputs, tmp_dir, data_manager)
logger.info(f"run params: {params}, {inputs}, {outputs}")
self.set_inputs(params)
await self.prepare_dit_inputs(inputs, data_manager)
self.runner.stop_signal = False
future = asyncio.Future()
self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_dit, self.rank)
self.thread.start()
status, _ = await future
if not status:
return False
await self.save_output_video(tmp_video_path, output_video_path, data_manager)
torch.cuda.empty_cache()
gc.collect()
return True
def run_dit(self):
self.runner.run_main()
self.runner.process_images_after_vae_decoder(save_video=True)
from abc import ABC
import torch
import torch.distributed as dist
from lightx2v.utils.utils import save_videos_grid
@@ -147,3 +150,28 @@ class BaseRunner(ABC):
def end_run(self):
pass
def check_stop(self):
"""Check if the stop signal is received"""
rank, world_size = 0, 1
if dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
signal_rank = world_size - 1
stopped = 0
if rank == signal_rank and hasattr(self, "stop_signal") and self.stop_signal:
stopped = 1
if world_size > 1:
if rank == signal_rank:
t = torch.tensor([stopped], dtype=torch.int32).to(device="cuda")
else:
t = torch.zeros(1, dtype=torch.int32, device="cuda")
dist.broadcast(t, src=signal_rank)
stopped = t.item()
print(f"rank {rank} recv stopped: {stopped}")
if stopped == 1:
raise Exception(f"rank {rank} received stop_signal, stopping run (expected behavior)")
@@ -111,6 +111,9 @@ class DefaultRunner(BaseRunner):
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps):
# only for single segment, check stop signal every step
if self.video_segment_num == 1:
self.check_stop()
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
with ProfilingContext4Debug("step_pre"):
@@ -145,7 +148,10 @@ class DefaultRunner(BaseRunner):
gc.collect()
def read_image_input(self, img_path):
- img_ori = Image.open(img_path).convert("RGB")
+ if isinstance(img_path, Image.Image):
+ img_ori = img_path
+ else:
+ img_ori = Image.open(img_path).convert("RGB")
img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
return img, img_ori
@@ -219,6 +225,7 @@ class DefaultRunner(BaseRunner):
for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext(f"segment end2end {segment_idx}"):
self.check_stop()
# 1. default do nothing
self.init_run_segment(segment_idx)
# 2. main inference loop
...
@@ -15,6 +15,8 @@ from loguru import logger
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel
@@ -221,7 +223,7 @@ class AudioProcessor:
    def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
        """Calculate audio range for given frame range"""
        audio_frame_rate = self.audio_sr / self.target_fps
-        return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
+        return round(start_frame * audio_frame_rate), round(end_frame * audio_frame_rate)

    def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
        """Segment audio based on frame requirements"""
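The fixed `get_audio_range` treats `end_frame` as exclusive, so consecutive frame ranges now map to audio sample ranges that tile the stream instead of overlapping by one frame. A standalone sketch of the math with the config defaults (16 kHz audio, 16 fps video), written as a free function for illustration:

```python
def get_audio_range(start_frame: int, end_frame: int,
                    audio_sr: int = 16000, target_fps: int = 16):
    """Map a half-open video frame range [start, end) to audio sample indices."""
    samples_per_frame = audio_sr / target_fps  # 1000.0 samples per frame here
    return round(start_frame * samples_per_frame), round(end_frame * samples_per_frame)
```

With the exclusive end, the segment covering frames [0, 81) maps to samples [0, 81000), and a following segment starting at frame 81 begins exactly at sample 81000 with no duplicated audio.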
@@ -299,6 +301,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
        audio_sr = self.config.get("audio_sr", 16000)
        target_fps = self.config.get("target_fps", 16)
        self._audio_processor = AudioProcessor(audio_sr, target_fps)
        if not isinstance(self.config["audio_path"], str):
            return [], 0
        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
        video_duration = self.config.get("video_duration", 5)
@@ -312,7 +316,10 @@ class WanAudioRunner(WanRunner):  # type:ignore
        return audio_segments, expected_frames

    def read_image_input(self, img_path):
-        ref_img = Image.open(img_path).convert("RGB")
+        if isinstance(img_path, Image.Image):
+            ref_img = img_path
+        else:
+            ref_img = Image.open(img_path).convert("RGB")
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        ref_img, h, w = resize_image(ref_img, resize_mode=self.config.get("resize_mode", "adaptive"), fixed_area=self.config.get("fixed_area", None), fixed_shape=self.config.get("fixed_shape", None))
@@ -449,10 +456,12 @@ class WanAudioRunner(WanRunner):  # type:ignore
        self.prev_video = None

    @ProfilingContext4Debug("Init run segment")
-    def init_run_segment(self, segment_idx):
+    def init_run_segment(self, segment_idx, audio_array=None):
        self.segment_idx = segment_idx
-        self.segment = self.inputs["audio_segments"][segment_idx]
+        if audio_array is not None:
+            self.segment = AudioSegment(audio_array, 0, audio_array.shape[0], False)
+        else:
+            self.segment = self.inputs["audio_segments"][segment_idx]
        self.config.seed = self.config.seed + segment_idx
        torch.manual_seed(self.config.seed)
@@ -477,7 +486,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
        # Extract relevant frames
        start_frame = 0 if self.segment_idx == 0 else self.prev_frame_length
-        start_audio_frame = 0 if self.segment_idx == 0 else int((self.prev_frame_length + 1) * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
+        start_audio_frame = 0 if self.segment_idx == 0 else int(self.prev_frame_length * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
        if self.segment.is_last and self.segment.useful_length:
            end_frame = self.segment.end_frame - self.segment.start_frame
@@ -490,6 +499,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
        self.gen_video_list.append(self.gen_video[:, :, start_frame:].cpu())
        self.cut_audio_list.append(self.segment.audio_array[start_audio_frame:])

        if self.va_recorder:
            cur_video = vae_to_comfyui_image(self.gen_video_list[-1])
            self.va_recorder.pub_livestream(cur_video, self.cut_audio_list[-1])
        if self.va_reader:
            self.gen_video_list.pop()
            self.cut_audio_list.pop()

        # Update prev_video for next iteration
        self.prev_video = self.gen_video
@@ -497,6 +514,102 @@ class WanAudioRunner(WanRunner):  # type:ignore
        del self.gen_video
        torch.cuda.empty_cache()

    def get_rank_and_world_size(self):
        rank = 0
        world_size = 1
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        return rank, world_size
    def init_va_recorder(self):
        output_video_path = self.config.get("save_video_path", None)
        self.va_recorder = None
        if isinstance(output_video_path, dict):
            assert output_video_path["type"] == "stream", f"unexpected save_video_path: {output_video_path}"
            rank, world_size = self.get_rank_and_world_size()
            if rank == 2 % world_size:
                record_fps = self.config.get("target_fps", 16)
                audio_sr = self.config.get("audio_sr", 16000)
                if "video_frame_interpolation" in self.config and self.vfi_model is not None:
                    record_fps = self.config["video_frame_interpolation"]["target_fps"]
                self.va_recorder = VARecorder(
                    livestream_url=output_video_path["data"],
                    fps=record_fps,
                    sample_rate=audio_sr,
                )
    def init_va_reader(self):
        audio_path = self.config.get("audio_path", None)
        self.va_reader = None
        if isinstance(audio_path, dict):
            assert audio_path["type"] == "stream", f"unexpected audio_path: {audio_path}"
            rank, world_size = self.get_rank_and_world_size()
            target_fps = self.config.get("target_fps", 16)
            max_num_frames = self.config.get("target_video_length", 81)
            audio_sr = self.config.get("audio_sr", 16000)
            prev_frames = self.config.get("prev_frame_length", 5)
            self.va_reader = VAReader(
                rank=rank,
                world_size=world_size,
                stream_url=audio_path["data"],
                sample_rate=audio_sr,
                segment_duration=max_num_frames / target_fps,
                prev_duration=prev_frames / target_fps,
                target_rank=1,
            )
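The `2 % world_size` guard pins the recorder to rank 2 on multi-GPU runs while degrading gracefully when the group is smaller, so a single-GPU run still gets a recorder on rank 0. A tiny illustration of that role assignment (a hypothetical helper, not in the source):

```python
def recorder_rank(world_size: int) -> int:
    """Rank that owns the VARecorder: rank 2 on large groups, wrapping
    around via modulo when fewer ranks exist."""
    return 2 % world_size
```

For example, with 1 or 2 ranks the recorder lands on rank 0; with 3 or more it lands on rank 2, keeping it off the rank that feeds the VAReader (`target_rank=1`).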
    def run_main(self, total_steps=None):
        try:
            self.init_va_recorder()
            self.init_va_reader()
            logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}")
            if self.va_reader is None:
                return super().run_main(total_steps)

            rank, world_size = self.get_rank_and_world_size()
            if rank == 2 % world_size:
                assert self.va_recorder is not None, "va_recorder is required on the recorder rank for stream audio input"
            self.va_reader.start()
            self.init_run()
            self.video_segment_num = "unlimited"

            fetch_timeout = self.va_reader.segment_duration + 1
            segment_idx = 0
            fail_count = 0
            max_fail_count = 10
            while True:
                with ProfilingContext4Debug(f"stream segment get audio segment {segment_idx}"):
                    self.check_stop()
                    audio_array = self.va_reader.get_audio_segment(timeout=fetch_timeout)
                    if audio_array is None:
                        fail_count += 1
                        logger.warning(f"Failed to get audio chunk {fail_count} times")
                        if fail_count > max_fail_count:
                            raise Exception(f"Failed to get audio chunk {fail_count} times, stopping reader")
                        continue
                with ProfilingContext4Debug(f"stream segment end2end {segment_idx}"):
                    fail_count = 0
                    self.init_run_segment(segment_idx, audio_array)
                    latents, generator = self.run_segment(total_steps=None)
                    self.gen_video = self.run_vae_decoder(latents)
                    self.end_run_segment()
                    segment_idx += 1
        finally:
            if hasattr(self.model, "scheduler"):
                self.end_run()
            if self.va_reader:
                self.va_reader.stop()
                self.va_reader = None
            if self.va_recorder:
                self.va_recorder.stop(wait=False)
                self.va_recorder = None
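The streaming loop tolerates transient empty reads from the reader and only gives up after a run of consecutive misses, resetting the counter on every successful fetch. A simplified, torch-free sketch of that retry policy, with hypothetical `fetch`/`process` callables standing in for `va_reader.get_audio_segment` and the per-segment inference:

```python
def consume_segments(fetch, process, max_fail_count=10, timeout=1.0):
    """Process stream segments until the source is exhausted.

    `fetch(timeout)` returns the next segment, None on a transient miss,
    or raises StopIteration when the stream ends. More than
    `max_fail_count` consecutive misses is treated as a dead stream.
    """
    segment_idx = 0
    fail_count = 0
    while True:
        try:
            segment = fetch(timeout)
        except StopIteration:
            return segment_idx  # number of segments processed
        if segment is None:
            fail_count += 1
            if fail_count > max_fail_count:
                raise RuntimeError(f"no segment after {fail_count} empty reads")
            continue
        fail_count = 0  # a successful read resets the failure streak
        process(segment_idx, segment)
        segment_idx += 1
```

Counting only consecutive failures (rather than total) lets a live stream stall briefly without killing the worker, while still bounding how long a dead stream can hold the GPU.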
    @ProfilingContext4Debug("Process after vae decoder")
    def process_images_after_vae_decoder(self, save_video=True):
        # Merge results
@@ -515,7 +628,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
            target_fps=target_fps,
        )
-        if save_video:
+        if save_video and isinstance(self.config["save_video_path"], str):
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
...