Unverified Commit 5ffdbeb6 authored by LiangLiu's avatar LiangLiu Committed by GitHub
Browse files

deploy update (#355)

1、frontend vue+vite
2、share task & template
3、x264 rtc stream push
parent 39683e24
<script setup>
import FloatingParticles from '../components/FloatingParticles.vue'
import TopBar from '../components/TopBar.vue'
import LeftBar from '../components/LeftBar.vue'
import { useI18n } from 'vue-i18n'
const { t, locale } = useI18n()
import { loadLanguageAsync, switchLang } from '../utils/i18n'
import {
submitting,
// 任务类型下拉菜单
showTaskTypeMenu,
showModelMenu,
isLoggedIn,
loading,
loginLoading,
initLoading,
downloadLoading,
isLoading,
// 录音相关
isRecording,
recordingDuration,
startRecording,
stopRecording,
formatRecordingDuration,
taskSearchQuery,
currentUser,
models,
tasks,
alert,
showErrorDetails,
showFailureDetails,
confirmDialog,
showConfirmDialog,
showTaskDetailModal,
modalTask,
t2vForm,
i2vForm,
s2vForm,
getCurrentForm,
i2vImagePreview,
s2vImagePreview,
s2vAudioPreview,
getCurrentImagePreview,
getCurrentAudioPreview,
setCurrentImagePreview,
setCurrentAudioPreview,
updateUploadedContentStatus,
availableTaskTypes,
availableModelClasses,
currentTaskHints,
currentHintIndex,
startHintRotation,
stopHintRotation,
filteredTasks,
selectedTaskId,
selectedTask,
selectedTaskFiles,
loadingTaskFiles,
statusFilter,
pagination,
paginationInfo,
currentTaskPage,
taskPageSize,
taskPageInput,
paginationKey,
taskMenuVisible,
toggleTaskMenu,
closeAllTaskMenus,
handleClickOutside,
showAlert,
setLoading,
apiCall,
logout,
loadModels,
sidebarCollapsed,
sidebarWidth,
showExpandHint,
showGlow,
isDefaultStateHidden,
hideDefaultState,
showDefaultState,
isCreationAreaExpanded,
hasUploadedContent,
isContracting,
expandCreationArea,
contractCreationArea,
taskFileCache,
taskFileCacheLoaded,
templateFileCache,
templateFileCacheLoaded,
loadTaskFiles,
downloadFile,
viewFile,
handleImageUpload,
selectTask,
selectModel,
resetForm,
triggerImageUpload,
triggerAudioUpload,
removeImage,
removeAudio,
handleAudioUpload,
loadImageAudioTemplates,
selectImageTemplate,
selectAudioTemplate,
previewAudioTemplate,
getTemplateFile,
imageTemplates,
audioTemplates,
showImageTemplates,
showAudioTemplates,
mediaModalTab,
templatePagination,
templatePaginationInfo,
templateCurrentPage,
templatePageSize,
templatePageInput,
templatePaginationKey,
imageHistory,
audioHistory,
showTemplates,
showHistory,
showPromptModal,
promptModalTab,
submitTask,
fileToBase64,
formatTime,
refreshTasks,
goToPage,
jumpToPage,
getVisiblePages,
goToTemplatePage,
jumpToTemplatePage,
getVisibleTemplatePages,
goToInspirationPage,
jumpToInspirationPage,
getVisibleInspirationPages,
preloadTaskFilesUrl,
preloadTemplateFilesUrl,
loadTaskFilesFromCache,
saveTaskFilesToCache,
getTaskFileFromCache,
setTaskFileToCache,
getTaskFileUrlFromApi,
getTaskFileUrl,
getTaskFileUrlSync,
getTemplateFileUrlFromApi,
getTemplateFileUrl,
getTemplateFileUrlAsync,
loadTemplateFilesFromCache,
saveTemplateFilesToCache,
loadFromCache,
saveToCache,
clearAllCache,
getStatusBadgeClass,
viewSingleResult,
cancelTask,
resumeTask,
deleteTask,
startPollingTask,
stopPollingTask,
reuseTask,
showTaskCreator,
toggleSidebar,
clearPrompt,
getTaskItemClass,
getStatusIndicatorClass,
getTaskTypeBtnClass,
getModelBtnClass,
getTaskTypeIcon,
getTaskTypeName,
getPromptPlaceholder,
getStatusTextClass,
getImagePreview,
getTaskInputUrl,
getTaskInputImage,
getTaskInputAudio,
getHistoryImageUrl,
getUserAvatarUrl,
getCurrentImagePreviewUrl,
getCurrentAudioPreviewUrl,
handleThumbnailError,
handleImageError,
handleImageLoad,
handleAudioError,
handleAudioLoad,
getTaskStatusDisplay,
getTaskStatusColor,
getTaskStatusIcon,
getTaskDuration,
getRelativeTime,
getTaskHistory,
getActiveTasks,
getOverallProgress,
getProgressTitle,
getProgressInfo,
getSubtaskProgress,
getSubtaskStatusText,
formatEstimatedTime,
formatDuration,
searchTasks,
filterTasksByStatus,
filterTasksByType,
getAlertClass,
getAlertBorderClass,
getAlertTextClass,
getAlertIcon,
getAlertIconBgClass,
getPromptTemplates,
selectPromptTemplate,
promptHistory,
getPromptHistory,
addTaskToHistory,
getLocalTaskHistory,
selectPromptHistory,
clearPromptHistory,
getImageHistory,
getAudioHistory,
selectImageHistory,
selectAudioHistory,
previewAudioHistory,
clearImageHistory,
clearAudioHistory,
getAudioMimeType,
getAuthHeaders,
startResize,
sidebar,
switchToCreateView,
switchToProjectsView,
switchToInspirationView,
switchToLoginView,
openTaskDetailModal,
closeTaskDetailModal,
// 灵感广场相关
inspirationSearchQuery,
selectedInspirationCategory,
inspirationItems,
InspirationCategories,
loadInspirationData,
selectInspirationCategory,
handleInspirationSearch,
loadMoreInspiration,
inspirationPagination,
inspirationPaginationInfo,
inspirationCurrentPage,
inspirationPageSize,
inspirationPageInput,
inspirationPaginationKey,
// 工具函数
formatDate,
// 模板详情弹窗相关
showTemplateDetailModal,
selectedTemplate,
previewTemplateDetail,
closeTemplateDetailModal,
useTemplate,
// 图片放大弹窗相关
showImageZoomModal,
zoomedImageUrl,
showImageZoom,
closeImageZoomModal,
// 模板素材应用相关
applyTemplateImage,
applyTemplateAudio,
applyTemplatePrompt,
copyPrompt,
// 视频播放控制
playVideo,
pauseVideo,
toggleVideoPlay,
pauseAllVideos,
updateVideoIcon,
onVideoLoaded,
onVideoError,
onVideoEnded
} from '../utils/other'
import { computed, onMounted} from 'vue';
import {useRouter} from 'vue-router';
import Alert from '../components/Alert.vue'
import Confirm from '../components/Confirm.vue'
import TaskDetails from '../components/TaskDetails.vue'
import TemplateDetails from '../components/TemplateDetails.vue'
import PromptTemplate from '../components/PromptTemplate.vue'
import ImageEdit from '../components/ImageEdit.vue'
import Loading from '../components/Loading.vue'
const router= useRouter();
</script>
<template>
<div class="main-container">
<FloatingParticles />
<div class="main-content-area flex flex-col relative h-full">
<TopBar />
<!-- 下方区域 -->
<div class="flex flex-row relative flex-1 min-h-0">
<LeftBar />
<router-view></router-view>
</div>
</div>
</div>
<Alert />
<Confirm />
<TaskDetails />
<TemplateDetails />
<promptTemplate />
<ImageEdit />
<!-- 全局路由跳转Loading覆盖层 -->
<div v-show="isLoading" class="bg-gradient-main flex items-center justify-center">
<Loading />
</div>
</template>
<script setup>
import FloatingParticles from '../components/FloatingParticles.vue'
import LoginCard from '../components/LoginCard.vue'
import Alert from '../components/Alert.vue'
import Loading from '../components/Loading.vue'
import TemplateDisplay from '../components/TemplateDisplay.vue'
import { isLoading, featuredTemplates, loadFeaturedTemplates, getRandomFeaturedTemplates } from '../utils/other'
import { ref, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
const { t, locale } = useI18n()
import { loadLanguageAsync, switchLang } from '../utils/i18n'
// 当前显示的精选模版
const currentFeaturedTemplates = ref([])
// 获取随机精选模版
// Replace the featured templates currently on screen with a fresh random
// pick of 5. A failed fetch is logged and leaves the current list untouched.
const refreshRandomTemplates = async () => {
  try {
    currentFeaturedTemplates.value = await getRandomFeaturedTemplates(5) // 获取5个模版
  } catch (error) {
    console.error('刷新随机模版失败:', error)
  }
}
// 组件挂载时初始化
onMounted(async () => {
// 加载精选模版数据
isLoading.value = true
await loadFeaturedTemplates(true)
// 获取随机精选模版
const randomTemplates = await getRandomFeaturedTemplates(5) // 获取5个模版
currentFeaturedTemplates.value = randomTemplates
isLoading.value = false
})
</script>
<template>
<div class="login-container w-full min-h-screen flex items-center justify-center p-4">
<FloatingParticles />
<!-- 主卡片容器 -->
<div class="w-full max-w-7xl mx-auto">
<div class="bg-dark-light/80 backdrop-blur-sm rounded-2xl border border-gray-700/50 shadow-2xl overflow-hidden h-auto lg:h-[100vh]" style="box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.8), 0 0 30px rgba(154, 114, 255, 0.3), inset 0 1px 0 rgba(255, 255, 255, 0.1);">
<div class="grid grid-cols-1 lg:grid-cols-2 h-auto lg:h-full">
<!-- 左侧:登录区域 -->
<div class="flex flex-col items-center justify-center">
<LoginCard />
</div>
<!-- 右侧:模版展示区域 -->
<div v-if="currentFeaturedTemplates.length > 0"
class="flex flex-col lg:border-t-0 lg:border-l border-gray-700/50">
<!-- 区域头部 -->
<div class="pt-4 border-gray-700/50">
<div class="flex items-center justify-center">
<p class="text-gray-400 text-sm">{{ t('templatesGeneratedByLightX2V') }}</p>
<button @click="refreshRandomTemplates"
class="w-8 h-8 ml-2 flex items-center justify-center hover:bg-laser-purple/40 text-laser-purple rounded-full transition-all duration-300 hover:scale-110"
:title="t('refreshRandomTemplates')">
<i class="fas fa-random text-sm"></i>
</button>
</div>
</div>
<!-- 可滚动的模版展示区域 -->
<div class="flex-1 overflow-y-auto main-scrollbar">
<TemplateDisplay
:templates="currentFeaturedTemplates"
:show-actions="false"
layout="grid"
:max-templates="4"
/>
</div>
</div>
<!-- 如果没有模版数据,显示占位区域 -->
<div v-else class="flex items-center justify-center border-t lg:border-t-0 lg:border-l border-gray-700/50">
<div class="text-center">
<div class="animate-spin rounded-full h-12 w-12 border-b-2 border-laser-purple mx-auto mb-4"></div>
<p class="text-gray-400">{{ t('loading') }}</p>
</div>
</div>
</div>
</div>
</div>
</div>
<Alert />
<!-- 全局路由跳转Loading覆盖层 -->
<div v-show="isLoading" class="bg-gradient-main flex items-center justify-center">
<Loading />
</div>
</template>
<script setup>
import { ref, onMounted, computed } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import topMenu from '../components/TopBar.vue'
import Loading from '../components/Loading.vue'
import {
isLoading,
selectedTaskId,
getCurrentForm,
setCurrentImagePreview,
setCurrentAudioPreview,
getTemplateFileUrl,
isCreationAreaExpanded,
switchToCreateView,
showAlert,
login
} from '../utils/other'
const { t } = useI18n()
const route = useRoute()
const router = useRouter()
const shareId = computed(() => route.params.shareId)
const shareData = ref(null)
const error = ref(null)
const videoUrl = ref('')
const inputUrls = ref({})
const showDetails = ref(false)
const videoLoading = ref(false)
const videoError = ref(false)
// 获取分享数据
const fetchShareData = async () => {
try {
const response = await fetch(`/api/v1/share/${shareId.value}`)
if (!response.ok) {
throw new Error('分享不存在或已过期')
}
const data = await response.json()
shareData.value = data
// 设置视频URL
if (data.output_video_url) {
videoUrl.value = data.output_video_url
}
// 设置输入素材URL
if (data.input_urls) {
inputUrls.value = data.input_urls
console.log('设置输入素材URL:', data.input_urls)
}
} catch (err) {
error.value = err.message
console.error('获取分享数据失败:', err)
}
}
// 获取分享标题
const getShareTitle = () => {
if (shareData.value?.share_type === 'task') {
// 获取用户名,如果没有则显示默认文本
const username = shareData.value?.username || '用户'
return `${username}${t('userGeneratedVideo')}`
}
return t('templateVideo')
}
// 获取分享描述
const getShareDescription = () => {
return t('description')
}
// 获取分享按钮文本
const getShareButtonText = () => {
switch (shareData.value?.share_type) {
case 'template':
return t('useTemplate')
default:
return t('createSimilar')
}
}
// 视频事件处理
const onVideoLoadStart = () => {
videoLoading.value = true
videoError.value = false
}
const onVideoCanPlay = () => {
videoLoading.value = false
videoError.value = false
}
const onVideoError = () => {
videoLoading.value = false
videoError.value = true
}
// 获取图片素材
const getImageMaterials = () => {
if (!inputUrls.value) return []
const imageMaterials = Object.entries(inputUrls.value).filter(([name, url]) =>
name.includes('image') && url
)
console.log('图片素材:', imageMaterials)
return imageMaterials
}
// 获取音频素材
const getAudioMaterials = () => {
if (!inputUrls.value) return []
const audioMaterials = Object.entries(inputUrls.value).filter(([name, url]) =>
name.includes('audio') && url
)
console.log('音频素材:', audioMaterials)
return audioMaterials
}
// 处理图片加载错误
const handleImageError = (event, inputName, url) => {
console.log('图片加载失败:', inputName, url)
console.log('错误详情:', event)
console.log('图片元素:', event.target)
// 尝试移除crossorigin属性重新加载
const img = event.target
if (img.crossOrigin) {
console.log('尝试移除crossorigin属性重新加载')
img.crossOrigin = null
img.src = url + '?retry=' + Date.now()
}
}
// 做同款功能
const createSimilar = async () => {
const token = localStorage.getItem('accessToken')
if (!token) {
// 未登录,跳转到登录页面,并保存分享ID
localStorage.setItem('shareData', JSON.stringify({ shareId: shareId.value }))
login()
return
}
if (!shareData.value) {
showAlert('分享数据不完整', 'danger')
return
}
console.log('使用分享数据:', shareData.value)
try {
// 先设置任务类型
selectedTaskId.value = shareData.value.task_type
// 获取当前表单
const currentForm = getCurrentForm()
// 设置表单数据
currentForm.prompt = shareData.value.prompt || ''
currentForm.negative_prompt = shareData.value.negative_prompt || ''
currentForm.seed = 42 // 默认种子
currentForm.model_cls = shareData.value.model_cls || ''
currentForm.stage = shareData.value.stage || 'single_stage'
// 如果有输入图片,先设置URL,延迟加载文件
if (shareData.value.inputs && shareData.value.inputs.input_image) {
let imageUrl
if (shareData.value.share_type === 'template') {
// 对于模板,使用模板文件URL
imageUrl = getTemplateFileUrl(shareData.value.inputs.input_image, 'images')
} else {
// 对于任务,使用分享数据中的URL
imageUrl = shareData.value.input_urls?.input_image || shareData.value.input_urls?.[Object.keys(shareData.value.input_urls).find(key => key.includes('image'))]
}
if (imageUrl) {
currentForm.imageUrl = imageUrl
setCurrentImagePreview(imageUrl) // 直接使用URL作为预览
console.log('分享输入图片:', imageUrl)
// 异步加载图片文件(不阻塞UI)
setTimeout(async () => {
try {
const imageResponse = await fetch(imageUrl)
if (imageResponse.ok) {
const blob = await imageResponse.blob()
const filename = shareData.value.inputs.input_image
const file = new File([blob], filename, { type: blob.type })
currentForm.imageFile = file
console.log('分享图片文件已加载')
}
} catch (error) {
console.warn('Failed to load share image file:', error)
}
}, 100)
}
}
// 如果有输入音频,先设置URL,延迟加载文件
if (shareData.value.inputs && shareData.value.inputs.input_audio) {
let audioUrl
if (shareData.value.share_type === 'template') {
// 对于模板,使用模板文件URL
audioUrl = getTemplateFileUrl(shareData.value.inputs.input_audio, 'audios')
} else {
// 对于任务,使用分享数据中的URL
audioUrl = shareData.value.input_urls?.input_audio || shareData.value.input_urls?.[Object.keys(shareData.value.input_urls).find(key => key.includes('audio'))]
}
if (audioUrl) {
currentForm.audioUrl = audioUrl
setCurrentAudioPreview(audioUrl) // 直接使用URL作为预览
console.log('分享输入音频:', audioUrl)
// 异步加载音频文件(不阻塞UI)
setTimeout(async () => {
try {
const audioResponse = await fetch(audioUrl)
if (audioResponse.ok) {
const blob = await audioResponse.blob()
const filename = shareData.value.inputs.input_audio
// 根据文件扩展名确定正确的MIME类型
let mimeType = blob.type
if (!mimeType || mimeType === 'application/octet-stream') {
const ext = filename.toLowerCase().split('.').pop()
const mimeTypes = {
'mp3': 'audio/mpeg',
'wav': 'audio/wav',
'mp4': 'audio/mp4',
'aac': 'audio/aac',
'ogg': 'audio/ogg',
'm4a': 'audio/mp4'
}
mimeType = mimeTypes[ext] || 'audio/mpeg'
}
const file = new File([blob], filename, { type: mimeType })
currentForm.audioFile = file
console.log('分享音频文件已加载')
// 使用FileReader生成data URL,与正常上传保持一致
const reader = new FileReader()
reader.onload = (e) => {
setCurrentAudioPreview(e.target.result)
console.log('分享音频预览已设置:', e.target.result.substring(0, 50) + '...')
}
reader.readAsDataURL(file)
}
} catch (error) {
console.warn('Failed to load share audio file:', error)
}
}, 100)
}
}
// 切换到创建视图
isCreationAreaExpanded.value = true
switchToCreateView()
showAlert(`已应用分享数据`, 'success')
} catch (error) {
console.error('应用分享数据失败:', error)
showAlert(`应用分享数据失败: ${error.message}`, 'danger')
}
}
// On mount: load the share payload (errors are captured inside
// fetchShareData), then drop the global loading overlay.
onMounted(async () => {
  await fetchShareData()
  isLoading.value = false
})
</script>
<template>
<div class="landing-page">
<!-- TopBar -->
<topMenu />
<!-- 主要内容区域 -->
<div class="main-content main-scrollbar overflow-y-auto">
<!-- 错误状态 -->
<div v-if="error" class="error-container">
<div class="error-content">
<div class="error-icon">
<i class="fas fa-exclamation-triangle"></i>
</div>
<h2 class="error-title">{{ t('shareNotFound') }}</h2>
<p class="error-message">{{ error }}</p>
<button @click="router.push('/')" class="error-button">
<i class="fas fa-home mr-2"></i>
{{ t('backToHome') }}
</button>
</div>
</div>
<!-- 分享内容 -->
<div v-else-if="shareData" class="content-grid">
<!-- 左侧视频区域 -->
<div class="video-section">
<div class="video-container">
<!-- 视频加载占位符 -->
<div v-if="!videoUrl" class="video-placeholder">
<div class="loading-spinner">
<i class="fas fa-spinner fa-spin"></i>
</div>
<p class="loading-text">{{ t('loadingVideo') }}...</p>
</div>
<!-- 视频播放器 -->
<video
v-if="videoUrl"
:src="videoUrl"
class="video-player"
controls
autoplay
loop
preload="metadata"
@loadstart="onVideoLoadStart"
@canplay="onVideoCanPlay"
@error="onVideoError">
{{ t('browserNotSupported') }}
</video>
<!-- 视频错误状态 -->
<div v-if="videoError" class="video-error">
<i class="fas fa-exclamation-triangle"></i>
<p>{{ t('videoNotAvailable') }}</p>
</div>
</div>
</div>
<!-- 右侧信息区域 -->
<div class="info-section">
<div class="info-content">
<!-- 标题 -->
<h1 class="main-title">
{{ getShareTitle() }}
</h1>
<!-- 描述 -->
<p class="main-description">
{{ getShareDescription() }}
</p>
<!-- 特性列表 -->
<div class="features-list">
<div class="feature-item">
<i class="fas fa-rocket feature-icon"></i>
<span class="feature-text">{{ t('latestAIModel') }}</span>
</div>
<div class="feature-item">
<i class="fas fa-bolt feature-icon"></i>
<span class="feature-text">{{ t('oneClickReplication') }}</span>
</div>
<div class="feature-item">
<i class="fas fa-user-cog feature-icon"></i>
<span class="feature-text">{{ t('customizableCharacter') }}</span>
</div>
</div>
<!-- 操作按钮 -->
<div class="action-buttons">
<button @click="createSimilar" class="primary-button">
<i class="fas fa-magic mr-2"></i>
{{ getShareButtonText() }}
</button>
<!-- 详细信息按钮 -->
<button @click="showDetails = !showDetails" class="secondary-button">
<i :class="showDetails ? 'fas fa-chevron-up' : 'fas fa-info-circle'" class="mr-2"></i>
{{ showDetails ? t('hideDetails') : t('showDetails') }}
</button>
</div>
<!-- 技术信息 -->
<div class="tech-info">
<p class="tech-text">
<a href="https://github.com/ModelTC/LightX2V" target="_blank" rel="noopener noreferrer" class="tech-link">
{{ t('poweredByLightX2V') }}
</a>
</p>
</div>
</div>
</div>
</div>
</div>
<!-- 详细信息面板 -->
<div v-if="showDetails && shareData" class="details-panel">
<div class="details-content">
<!-- 输入素材标题 -->
<div class="materials-header">
<h2 class="materials-title">
<i class="fas fa-upload mr-2"></i>
{{ t('inputMaterials') }}
</h2>
</div>
<!-- 三个并列的分块卡片 -->
<div class="materials-cards">
<!-- 图片卡片 -->
<div class="material-card">
<div class="card-header">
<i class="fas fa-image card-icon"></i>
<h3 class="card-title">{{ t('image') }}</h3>
</div>
<div class="card-content">
<div v-if="getImageMaterials().length > 0" class="image-grid">
<div v-for="[inputName, url] in getImageMaterials()" :key="inputName" class="image-item">
<div class="image-container">
<img :src="url" :alt="inputName" class="image-preview"
@load="console.log('图片加载成功:', inputName, url)"
@error="handleImageError($event, inputName, url)">
<div class="image-placeholder" v-if="!url">
<i class="fas fa-image"></i>
</div>
<div class="image-error-placeholder" v-if="false">
<i class="fas fa-exclamation-triangle"></i>
<p>图片加载失败</p>
</div>
</div>
</div>
</div>
<div v-else class="empty-state">
<i class="fas fa-image empty-icon"></i>
<p class="empty-text">{{ t('noImage') }}</p>
</div>
</div>
</div>
<!-- 音频卡片 -->
<div class="material-card">
<div class="card-header">
<i class="fas fa-music card-icon"></i>
<h3 class="card-title">{{ t('audio') }}</h3>
</div>
<div class="card-content">
<div v-if="getAudioMaterials().length > 0" class="audio-list">
<div v-for="[inputName, url] in getAudioMaterials()" :key="inputName" class="audio-item">
<audio :src="url" controls class="audio-player"></audio>
</div>
</div>
<div v-else class="empty-state">
<i class="fas fa-music empty-icon"></i>
<p class="empty-text">{{ t('noAudio') }}</p>
</div>
</div>
</div>
<!-- 提示词卡片 -->
<div class="material-card">
<div class="card-header">
<i class="fas fa-file-alt card-icon"></i>
<h3 class="card-title">{{ t('prompt') }}</h3>
</div>
<div class="card-content">
<div v-if="shareData.prompt" class="prompt-content">
<p class="prompt-text">{{ shareData.prompt }}</p>
</div>
<div v-else class="empty-state">
<i class="fas fa-file-alt empty-icon"></i>
<p class="empty-text">{{ t('noPrompt') }}</p>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- 全局路由跳转Loading覆盖层 -->
<div v-show="isLoading" class="loading-overlay">
<Loading />
</div>
</template>
<style scoped>
/* Landing Page 样式 */
.landing-page {
min-height: 100vh;
width: 100%;
background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%);
color: white;
}
.main-content {
width: 100%;
padding: 2rem 0;
min-height: calc(100vh - 80px);
background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%);
}
/* 错误状态 */
.error-container {
display: flex;
align-items: center;
justify-content: center;
min-height: 60vh;
}
.error-content {
text-align: center;
max-width: 500px;
padding: 2rem;
}
.error-icon {
width: 80px;
height: 80px;
margin: 0 auto 1.5rem;
background: rgba(239, 68, 68, 0.1);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-size: 2rem;
color: #ef4444;
}
.error-title {
font-size: 1.5rem;
font-weight: 600;
margin-bottom: 1rem;
color: white;
}
.error-message {
color: #9ca3af;
margin-bottom: 2rem;
line-height: 1.6;
}
.error-button {
padding: 0.75rem 1.5rem;
background: rgba(139, 92, 246, 0.2);
border: 1px solid rgba(139, 92, 246, 0.4);
border-radius: 0.75rem;
color: white;
font-weight: 500;
transition: all 0.2s;
cursor: pointer;
}
.error-button:hover {
background: rgba(139, 92, 246, 0.3);
border-color: rgba(139, 92, 246, 0.6);
transform: translateY(-1px);
}
/* 内容网格布局 */
.content-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 4rem;
width: 100%;
margin: 0 auto;
padding: 0 2rem;
align-items: center;
min-height: 60vh;
}
/* 视频区域 */
.video-section {
display: flex;
justify-content: center;
align-items: center;
}
.video-container {
width: 100%;
max-width: 500px;
aspect-ratio: 9/16;
background: #000;
border-radius: 1rem;
overflow: hidden;
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.4);
position: relative;
}
.video-placeholder {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
background: #1f2937;
}
.loading-spinner {
font-size: 2rem;
color: #8b5cf6;
margin-bottom: 1rem;
}
.loading-text {
color: #9ca3af;
font-size: 0.875rem;
}
.video-player {
width: 100%;
height: 100%;
object-fit: contain;
}
.video-error {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
background: #1f2937;
color: #ef4444;
}
.video-error i {
font-size: 2rem;
margin-bottom: 1rem;
}
/* 信息区域 */
.info-section {
display: flex;
align-items: center;
justify-content: center;
}
.info-content {
max-width: 500px;
padding: 2rem 0;
}
.main-title {
font-size: 3rem;
font-weight: 700;
margin-bottom: 1.5rem;
background: linear-gradient(135deg, #8b5cf6, #a855f7);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
line-height: 1.2;
}
.main-description {
font-size: 1.25rem;
color: #d1d5db;
margin-bottom: 2.5rem;
line-height: 1.6;
}
/* 特性列表 */
.features-list {
margin-bottom: 2.5rem;
}
.feature-item {
display: flex;
align-items: center;
margin-bottom: 1rem;
padding: 0.75rem 0;
}
.feature-icon {
width: 40px;
height: 40px;
background: rgba(139, 92, 246, 0.1);
border-radius: 0.5rem;
display: flex;
align-items: center;
justify-content: center;
margin-right: 1rem;
color: #8b5cf6;
font-size: 1.125rem;
}
.feature-text {
font-size: 1rem;
color: #e5e7eb;
font-weight: 500;
}
/* 操作按钮 */
.action-buttons {
margin-bottom: 2rem;
}
.primary-button {
width: 100%;
padding: 1rem 2rem;
background: linear-gradient(135deg, #8b5cf6, #a855f7);
border: none;
border-radius: 0.75rem;
color: white;
font-size: 1.125rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
margin-bottom: 1rem;
display: flex;
align-items: center;
justify-content: center;
}
.primary-button:hover {
background: linear-gradient(135deg, #7c3aed, #9333ea);
transform: translateY(-2px);
box-shadow: 0 10px 25px -5px rgba(139, 92, 246, 0.4);
}
.secondary-button {
width: 100%;
padding: 0.75rem 1.5rem;
background: rgba(255, 255, 255, 0.1);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 0.5rem;
color: white;
font-size: 0.875rem;
cursor: pointer;
transition: all 0.2s;
display: flex;
align-items: center;
justify-content: center;
}
.secondary-button:hover {
background: rgba(255, 255, 255, 0.15);
border-color: rgba(255, 255, 255, 0.3);
}
/* 技术信息 */
.tech-info {
text-align: center;
padding-top: 1rem;
border-top: 1px solid rgba(255, 255, 255, 0.1);
}
.tech-text {
color: #9ca3af;
font-size: 0.875rem;
font-weight: 500;
}
.tech-link {
color: #9ca3af;
text-decoration: underline;
transition: color 0.3s ease;
}
.tech-link:hover {
color: #8b5cf6;
}
/* 详细信息面板 */
.details-panel {
background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%);
padding: 5rem 0;
}
.details-content {
width: 100%;
margin: 0 auto;
padding: 0 2rem;
}
/* 输入素材标题 */
.materials-header {
text-align: center;
margin-bottom: 2rem;
}
.materials-title {
font-size: 1.5rem;
font-weight: 600;
color: white;
display: flex;
align-items: center;
justify-content: center;
margin: 0;
}
/* 三个并列的卡片 */
.materials-cards {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 2rem;
}
.material-card {
background: rgba(255, 255, 255, 0.08);
border-radius: 1rem;
border: 1px solid rgba(255, 255, 255, 0.1);
overflow: hidden;
transition: all 0.3s ease;
}
.material-card:hover {
background: rgba(255, 255, 255, 0.12);
border-color: rgba(139, 92, 246, 0.3);
transform: translateY(-2px);
}
.card-header {
background: rgba(139, 92, 246, 0.1);
padding: 1rem;
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
display: flex;
align-items: center;
gap: 0.75rem;
}
.card-icon {
font-size: 1.25rem;
color: #8b5cf6;
}
.card-title {
font-size: 1.125rem;
font-weight: 600;
color: white;
margin: 0;
}
.card-content {
padding: 1.5rem;
min-height: 200px;
}
/* 图片网格 */
.image-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
gap: 1rem;
}
.image-item {
text-align: center;
}
.image-container {
position: relative;
width: 100%;
min-height: 120px;
margin-bottom: 0.5rem;
border-radius: 0.5rem;
border: 1px solid rgba(255, 255, 255, 0.1);
overflow: hidden;
background: rgba(255, 255, 255, 0.05);
display: flex;
align-items: center;
justify-content: center;
}
.image-preview {
max-width: 100%;
min-height: 80px;
height: auto;
width: auto;
object-fit: contain;
display: block;
position: relative !important;
}
.image-placeholder {
width: 100%;
height: 80px;
display: flex;
align-items: center;
justify-content: center;
background: rgba(255, 255, 255, 0.1);
color: #9ca3af;
font-size: 1.5rem;
}
.image-error-placeholder {
width: 100%;
height: 80px;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
background: rgba(239, 68, 68, 0.1);
color: #ef4444;
font-size: 0.875rem;
text-align: center;
}
.image-error-placeholder i {
font-size: 1.5rem;
margin-bottom: 0.25rem;
}
.image-label {
font-size: 0.75rem;
color: #9ca3af;
margin: 0;
word-break: break-all;
}
.debug-url {
font-size: 0.6rem;
color: #6b7280;
margin: 0.25rem 0 0 0;
word-break: break-all;
opacity: 0.7;
}
/* 音频列表 */
.audio-list {
display: flex;
flex-direction: column;
gap: 1rem;
}
.audio-item {
text-align: center;
}
.audio-player {
width: 100%;
height: 40px;
margin-bottom: 0.5rem;
}
.audio-label {
font-size: 0.75rem;
color: #9ca3af;
margin: 0;
word-break: break-all;
}
/* 提示词内容 */
.prompt-content {
background: rgba(255, 255, 255, 0.05);
border-radius: 0.5rem;
padding: 1rem;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.prompt-text {
color: #d1d5db;
line-height: 1.6;
margin: 0;
word-break: break-word;
}
/* 空状态 */
.empty-state {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 120px;
color: #6b7280;
}
.empty-icon {
font-size: 2rem;
margin-bottom: 0.5rem;
opacity: 0.5;
}
.empty-text {
font-size: 0.875rem;
margin: 0;
opacity: 0.7;
}
.loading-overlay {
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%);
display: flex;
align-items: center;
justify-content: center;
z-index: 9999;
}
/* 响应式设计 */
@media (max-width: 1024px) {
.content-grid {
gap: 3rem;
padding: 0 1.5rem;
}
.main-title {
font-size: 2.5rem;
}
.video-container {
max-width: 400px;
}
/* 卡片响应式 */
.materials-cards {
grid-template-columns: 1fr;
gap: 1.5rem;
}
}
@media (max-width: 768px) {
.main-content {
padding: 1rem 0;
}
.content-grid {
grid-template-columns: 1fr;
gap: 2rem;
padding: 0 1rem;
}
.main-title {
font-size: 2rem;
}
.main-description {
font-size: 1.125rem;
}
.video-container {
max-width: 300px;
}
.info-content {
padding: 1rem 0;
}
.details-content {
padding: 0 1rem;
}
/* 移动端卡片调整 */
.materials-cards {
gap: 1rem;
}
.card-content {
padding: 1rem;
min-height: 150px;
}
.image-grid {
grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
}
.materials-title {
font-size: 1.25rem;
}
}
</style>
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import tailwindcss from '@tailwindcss/vite'
// https://vite.dev/config/
export default defineConfig({
  // Vue SFC support plus the Tailwind CSS Vite plugin.
  plugins: [vue(), tailwindcss()],
})
......@@ -252,7 +252,7 @@ class ServerMonitor:
# check if user has too many daily tasks
daily_statuses = active_statuses + [TaskStatus.SUCCEED, TaskStatus.CANCEL, TaskStatus.FAILED]
daily_tasks = await self.task_manager.list_tasks(status=daily_statuses, user_id=user_id, start_created_t=cur_t - 86400)
daily_tasks = await self.task_manager.list_tasks(status=daily_statuses, user_id=user_id, start_created_t=cur_t - 86400, include_delete=True)
if len(daily_tasks) >= self.user_max_daily_tasks:
return f"User {user_id} has too many daily tasks, {len(daily_tasks)} vs {self.user_max_daily_tasks}"
......
../frontend/dist/assets
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
../frontend/dist/index.html
\ No newline at end of file
......@@ -66,6 +66,12 @@ class BaseTaskManager:
async def delete_task(self, task_id, user_id=None):
raise NotImplementedError
async def insert_share(self, share_info):
raise NotImplementedError
async def query_share(self, share_id):
raise NotImplementedError
def fmt_dict(self, data):
for k in ["status"]:
if k in data:
......@@ -76,6 +82,29 @@ class BaseTaskManager:
if k in data:
data[k] = TaskStatus[data[k]]
async def create_share(self, task_id, user_id, share_type, valid_days, auth_type, auth_value):
assert share_type in ["task", "template"], f"do not support {share_type} share type!"
assert auth_type in ["public", "login", "user_id"], f"do not support {auth_type} auth type!"
assert valid_days > 0, f"valid_days must be greater than 0!"
share_id = str(uuid.uuid4())
cur_t = current_time()
share_info = {
"share_id": share_id,
"task_id": task_id,
"user_id": user_id,
"share_type": share_type,
"create_t": cur_t,
"update_t": cur_t,
"valid_days": valid_days,
"valid_t": cur_t + valid_days * 24 * 3600,
"auth_type": auth_type,
"auth_value": auth_value,
"extra_info": "",
"tag": "",
}
assert await self.insert_share(share_info), f"create share {share_info} failed"
return share_id
async def create_user(self, user_info):
assert user_info["source"] in ["github", "google", "phone"], f"do not support {user_info['source']} user!"
cur_t = current_time()
......
......@@ -21,13 +21,13 @@ class LocalTaskManager(BaseTaskManager):
def fmt_dict(self, data):
super().fmt_dict(data)
for k in ["create_t", "update_t", "ping_t"]:
for k in ["create_t", "update_t", "ping_t", "valid_t"]:
if k in data:
data[k] = time2str(data[k])
def parse_dict(self, data):
super().parse_dict(data)
for k in ["create_t", "update_t", "ping_t"]:
for k in ["create_t", "update_t", "ping_t", "valid_t"]:
if k in data:
data[k] = str2time(data[k])
......@@ -46,6 +46,8 @@ class LocalTaskManager(BaseTaskManager):
task, subtasks = info["task"], info["subtasks"]
if user_id is not None and task["user_id"] != user_id:
raise Exception(f"Task {task_id} is not belong to user {user_id}")
if task["tag"] == "delete":
raise Exception(f"Task {task_id} is deleted")
self.parse_dict(task)
if only_task:
return task
......@@ -93,6 +95,8 @@ class LocalTaskManager(BaseTaskManager):
continue
if "end_ping_t" in kwargs and kwargs["end_ping_t"] < task["ping_t"]:
continue
if not kwargs.get("include_delete", False) and task.get("tag", "") == "delete":
continue
# 如果不是查询子任务,则添加子任务信息到任务中
if not kwargs.get("subtasks", False):
......@@ -292,16 +296,40 @@ class LocalTaskManager(BaseTaskManager):
@class_try_catch_async
async def delete_task(self, task_id, user_id=None):
task = self.load(task_id, user_id, only_task=True)
task, subtasks = self.load(task_id, user_id)
# only allow to delete finished tasks
if task["status"] not in FinishedStatus:
return False
# delete task file
task_file = self.get_task_filename(task_id)
if os.path.exists(task_file):
os.remove(task_file)
task["tag"] = "delete"
task["update_t"] = current_time()
self.save(task, subtasks)
return True
def get_share_filename(self, share_id):
    """Return the on-disk JSON path used to persist the given share."""
    fname = f"share_{share_id}.json"
    return os.path.join(self.local_dir, fname)
@class_try_catch_async
async def insert_share(self, share_info):
    """Persist a share record as pretty-printed JSON in the local dir."""
    target = self.get_share_filename(share_info["share_id"])
    # normalize field representations (timestamps etc.) before serializing
    self.fmt_dict(share_info)
    with open(target, "w") as fout:
        json.dump(share_info, fout, indent=4, ensure_ascii=False)
    return True
@class_try_catch_async
async def query_share(self, share_id):
    """Load a share record from local disk.

    Returns:
        The parsed share dict, or None when no such share file exists.

    Raises:
        Exception: when the share is tagged deleted or has expired
            (absorbed by the class_try_catch_async decorator).
    """
    fpath = self.get_share_filename(share_id)
    if not os.path.exists(fpath):
        return None
    # use a context manager so the file handle is closed promptly
    # (previously `json.load(open(fpath))` leaked the handle)
    with open(fpath) as fin:
        data = json.load(fin)
    self.parse_dict(data)
    if data["tag"] == "delete":
        raise Exception(f"Share {share_id} is deleted")
    if data["valid_t"] < current_time():
        raise Exception(f"Share {share_id} has expired")
    return data
@class_try_catch_async
async def insert_user_if_not_exists(self, user_info):
fpath = self.get_user_filename(user_info["user_id"])
......
......@@ -17,6 +17,7 @@ class PostgresSQLTaskManager(BaseTaskManager):
self.table_subtasks = "subtasks"
self.table_users = "users"
self.table_versions = "versions"
self.table_shares = "shares"
self.pool = None
self.metrics_monitor = metrics_monitor
......@@ -29,7 +30,7 @@ class PostgresSQLTaskManager(BaseTaskManager):
def fmt_dict(self, data):
super().fmt_dict(data)
for k in ["create_t", "update_t", "ping_t"]:
for k in ["create_t", "update_t", "ping_t", "valid_t"]:
if k in data and isinstance(data[k], float):
data[k] = datetime.fromtimestamp(data[k])
for k in ["params", "extra_info", "inputs", "outputs", "previous"]:
......@@ -41,7 +42,7 @@ class PostgresSQLTaskManager(BaseTaskManager):
for k in ["params", "extra_info", "inputs", "outputs", "previous"]:
if k in data:
data[k] = json.loads(data[k])
for k in ["create_t", "update_t", "ping_t"]:
for k in ["create_t", "update_t", "ping_t", "valid_t"]:
if k in data:
data[k] = data[k].timestamp()
......@@ -69,7 +70,7 @@ class PostgresSQLTaskManager(BaseTaskManager):
async def upgrade_db(self):
versions = [
(1, "Init tables", self.upgrade_v1),
# (2, "Add new fields or indexes", self.upgrade_v2),
(2, "Add shares table", self.upgrade_v2),
]
logger.info(f"upgrade_db: {self.db_url}")
cur_ver = await self.query_version()
......@@ -171,8 +172,40 @@ class PostgresSQLTaskManager(BaseTaskManager):
finally:
await self.release_conn(conn)
async def upgrade_v2(self, version, description):
    """Schema migration v2: add the shares table.

    Creates the shares table (if missing) and records the new version in
    the versions table, inside a single transaction.

    Returns:
        True on success, False on any failure (error is logged).
    """
    conn = await self.get_conn()
    try:
        async with conn.transaction(isolation="read_uncommitted"):
            # create shares table
            await conn.execute(f"""
                CREATE TABLE IF NOT EXISTS {self.table_shares} (
                    share_id VARCHAR(128) PRIMARY KEY,
                    task_id VARCHAR(128),
                    user_id VARCHAR(256),
                    share_type VARCHAR(32),
                    create_t TIMESTAMPTZ,
                    update_t TIMESTAMPTZ,
                    valid_days INTEGER,
                    valid_t TIMESTAMPTZ,
                    auth_type VARCHAR(32),
                    auth_value VARCHAR(256),
                    extra_info JSONB,
                    tag VARCHAR(64),
                    FOREIGN KEY (user_id) REFERENCES {self.table_users}(user_id) ON DELETE CASCADE
                )
            """)
            # update version
            await conn.execute(
                f"INSERT INTO {self.table_versions} (version, description, create_t) VALUES ($1, $2, $3)",
                version,
                description,
                datetime.now(),
            )
        return True
    except:  # noqa
        # fixed: the message previously said "upgrade_v1" (copy-paste bug)
        logger.error(f"upgrade_v2 error: {traceback.format_exc()}")
        return False
    finally:
        await self.release_conn(conn)
async def load(self, conn, task_id, user_id=None, only_task=False, worker_name=None):
query = f"SELECT * FROM {self.table_tasks} WHERE task_id = $1"
query = f"SELECT * FROM {self.table_tasks} WHERE task_id = $1 AND tag != 'delete'"
params = [task_id]
if user_id is not None:
query += " AND user_id = $2"
......@@ -357,6 +390,10 @@ class PostgresSQLTaskManager(BaseTaskManager):
assert "user_id" not in kwargs, "user_id is not allowed when subtasks is True"
else:
query += self.table_tasks
if not kwargs.get("include_delete", False):
param_idx += 1
conds.append(f"tag != ${param_idx}")
params.append("delete")
if "status" in kwargs:
param_idx += 1
......@@ -748,10 +785,9 @@ class PostgresSQLTaskManager(BaseTaskManager):
logger.warning(f"Cannot delete task {task_id} with status {task['status']}, only finished tasks can be deleted")
return False
# delete subtasks & task record
await conn.execute(f"DELETE FROM {self.table_subtasks} WHERE task_id = $1", task_id)
await conn.execute(f"DELETE FROM {self.table_tasks} WHERE task_id = $1", task_id)
logger.info(f"Task {task_id} and its subtasks deleted successfully")
# delete task record
await conn.execute(f"UPDATE {self.table_tasks} SET tag = 'delete', update_t = $1 WHERE task_id = $2", datetime.now(), task_id)
logger.info(f"Task {task_id} deleted successfully")
return True
except: # noqa
......@@ -760,6 +796,53 @@ class PostgresSQLTaskManager(BaseTaskManager):
finally:
await self.release_conn(conn)
@class_try_catch_async
async def insert_share(self, share_info):
    """Insert one row into the shares table; True on success, False on error."""
    conn = await self.get_conn()
    try:
        async with conn.transaction(isolation="read_uncommitted"):
            # convert timestamps / json fields to their DB representations
            self.fmt_dict(share_info)
            columns = (
                "share_id",
                "task_id",
                "user_id",
                "share_type",
                "create_t",
                "update_t",
                "valid_days",
                "valid_t",
                "auth_type",
                "auth_value",
                "extra_info",
                "tag",
            )
            placeholders = ", ".join(f"${i + 1}" for i in range(len(columns)))
            query = f"INSERT INTO {self.table_shares} ({', '.join(columns)}) VALUES ({placeholders})"
            await conn.execute(query, *(share_info[c] for c in columns))
        return True
    except:  # noqa
        logger.error(f"create_share_link error: {traceback.format_exc()}")
        return False
    finally:
        await self.release_conn(conn)
@class_try_catch_async
async def query_share(self, share_id):
    """Fetch a share row that is neither deleted nor expired.

    Returns:
        The parsed share dict, or None when no matching row exists
        or an error occurred (errors are logged).
    """
    conn = await self.get_conn()
    try:
        async with conn.transaction(isolation="read_uncommitted"):
            row = await conn.fetchrow(
                f"SELECT * FROM {self.table_shares} WHERE share_id = $1 AND tag != 'delete' AND valid_t >= $2",
                share_id,
                datetime.now(),
            )
            # fetchrow yields None when no row matches; previously dict(None)
            # raised TypeError and logged a spurious "query_share error"
            if row is None:
                return None
            share = dict(row)
            self.parse_dict(share)
            return share
    except:  # noqa
        logger.error(f"query_share error: {traceback.format_exc()}")
        return None
    finally:
        await self.release_conn(conn)
@class_try_catch_async
async def insert_user_if_not_exists(self, user_info):
conn = await self.get_conn()
......
......@@ -290,6 +290,23 @@ async def shutdown(loop):
loop.call_later(2, force_exit)
# align args like infer.py
def align_args(args):
    """Force service-mode defaults onto the parsed CLI namespace.

    The worker receives per-task inputs at runtime, so the infer.py-style
    input arguments are reset to neutral values here. Mutates ``args``
    in place; returns None.
    """
    defaults = {
        "seed": 42,
        "sf_model_path": "",
        "use_prompt_enhancer": False,
        "prompt": "",
        "negative_prompt": "",
        "image_path": "",
        "last_frame_path": "",
        "audio_path": "",
        "src_ref_images": None,
        "src_video": None,
        "src_mask": None,
        "save_result_path": "",
        "return_result_tensor": False,
    }
    for name, value in defaults.items():
        setattr(args, name, value)
# =========================
# Main Entry
# =========================
......@@ -319,6 +336,7 @@ if __name__ == "__main__":
parser.add_argument("--data_url", type=str, default=dft_data_url)
args = parser.parse_args()
align_args(args)
if args.identity == "":
# TODO: spec worker instance identity by k8s env
args.identity = "worker-" + str(uuid.uuid4())[:8]
......
import asyncio
import copy
import ctypes
import gc
import json
......@@ -14,6 +13,7 @@ from loguru import logger
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.infer import init_runner # noqa
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import set_config, set_parallel_config
......@@ -23,42 +23,39 @@ from lightx2v.utils.utils import seed_all
class BaseWorker:
@ProfilingContext4DebugL1("Init Worker Worker Cost:")
def __init__(self, args):
args.save_result_path = ""
config = set_config(args)
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
seed_all(config.seed)
seed_all(args.seed)
self.rank = 0
self.world_size = 1
if config.parallel:
if config["parallel"]:
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
set_parallel_config(config)
seed_all(config.seed)
# same as va_recorder rank and worker main ping rank
self.out_video_rank = self.world_size - 1
torch.set_grad_enabled(False)
self.runner = RUNNER_REGISTER[config.model_cls](config)
# fixed config
self.fixed_config = copy.deepcopy(self.runner.config)
self.runner = RUNNER_REGISTER[config["model_cls"]](config)
self.input_info = set_input_info(args)
def update_config(self, kwargs):
def update_input_info(self, kwargs):
for k, v in kwargs.items():
setattr(self.runner.config, k, v)
setattr(self.input_info, k, v)
def set_inputs(self, params):
self.runner.config["prompt"] = params["prompt"]
self.runner.config["negative_prompt"] = params.get("negative_prompt", "")
self.runner.config["image_path"] = params.get("image_path", "")
self.runner.config["save_result_path"] = params.get("save_result_path", "")
self.runner.config["seed"] = params.get("seed", self.fixed_config.get("seed", 42))
self.runner.config["audio_path"] = params.get("audio_path", "")
self.input_info.prompt = params["prompt"]
self.input_info.negative_prompt = params.get("negative_prompt", "")
self.input_info.image_path = params.get("image_path", "")
self.input_info.save_result_path = params.get("save_result_path", "")
self.input_info.seed = params.get("seed", self.input_info.seed)
self.input_info.audio_path = params.get("audio_path", "")
async def prepare_input_image(self, params, inputs, tmp_dir, data_manager):
input_image_path = inputs.get("input_image", "")
tmp_image_path = os.path.join(tmp_dir, input_image_path)
# prepare tmp image
if self.runner.config.task == "i2v":
if "image_path" in self.input_info.__dataclass_fields__:
img_data = await data_manager.load_bytes(input_image_path)
with open(tmp_image_path, "wb") as fout:
fout.write(img_data)
......@@ -101,7 +98,7 @@ class BaseWorker:
text_encoder_output = await data_manager.load_object(text_out, device)
image_encoder_output = None
if self.runner.config.task == "i2v":
if "image_path" in self.input_info.__dataclass_fields__:
clip_path = inputs["clip_encoder_output"]
vae_path = inputs["vae_encoder_output"]
clip_encoder_out = await data_manager.load_object(clip_path, device)
......@@ -111,7 +108,7 @@ class BaseWorker:
"vae_encoder_out": vae_encoder_out["vals"],
}
# apploy the config changes by vae encoder
self.update_config(vae_encoder_out["kwargs"])
self.update_input_info(vae_encoder_out["kwargs"])
self.runner.inputs = {
"text_encoder_output": text_encoder_output,
......@@ -130,7 +127,7 @@ class BaseWorker:
await data_manager.save_bytes(video_data, output_video_path)
def is_audio_model(self):
return "audio" in self.runner.config.model_cls or "seko_talk" in self.runner.config.model_cls
return "audio" in self.runner.config["model_cls"] or "seko_talk" in self.runner.config["model_cls"]
class RunnerThread(threading.Thread):
......@@ -207,7 +204,7 @@ class PipelineWorker(BaseWorker):
self.runner.stop_signal = False
future = asyncio.Future()
self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_func, self.rank)
self.thread = RunnerThread(asyncio.get_running_loop(), future, self.run_func, self.rank, input_info=self.input_info)
self.thread.start()
status, _ = await future
if not status:
......@@ -233,7 +230,7 @@ class TextEncoderWorker(BaseWorker):
if self.runner.config["use_prompt_enhancer"]:
prompt = self.runner.config["prompt_enhanced"]
if self.runner.config.task == "i2v" and not self.is_audio_model():
if self.runner.config["task"] == "i2v" and not self.is_audio_model():
img = await data_manager.load_image(input_image_path)
img = self.runner.read_image_input(img)
if isinstance(img, tuple):
......@@ -292,9 +289,9 @@ class VaeEncoderWorker(BaseWorker):
vals = self.runner.run_vae_encoder(img)
out = {"vals": vals, "kwargs": {}}
for key in ["lat_h", "lat_w", "tgt_h", "tgt_w"]:
if hasattr(self.runner.config, key):
out["kwargs"][key] = int(getattr(self.runner.config, key))
for key in ["original_shape", "resized_shape", "latent_shape", "target_shape"]:
if hasattr(self.input_info, key):
out["kwargs"][key] = getattr(self.input_info, key)
if self.rank == 0:
await data_manager.save_object(out, outputs["vae_encoder_output"])
......
......@@ -19,6 +19,7 @@ from torchvision.transforms.functional import resize
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_x64_recorder import VAX64Recorder
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel
......@@ -343,6 +344,9 @@ class WanAudioRunner(WanRunner): # type:ignore
target_fps = self.config.get("target_fps", 16)
self._audio_processor = AudioProcessor(audio_sr, target_fps)
if not isinstance(audio_path, str):
return [], 0, None, 0
# Get audio files from person objects or legacy format
audio_files, mask_files = self.get_audio_files_from_audio_path(audio_path)
......@@ -571,8 +575,9 @@ class WanAudioRunner(WanRunner): # type:ignore
def init_run_segment(self, segment_idx, audio_array=None):
self.segment_idx = segment_idx
if audio_array is not None:
end_idx = audio_array.shape[1] // self._audio_processor.audio_frame_rate - self.prev_frame_length
self.segment = AudioSegment(audio_array, 0, end_idx)
end_idx = audio_array.shape[0] // self._audio_processor.audio_frame_rate - self.prev_frame_length
audio_tensor = torch.Tensor(audio_array).float().unsqueeze(0)
self.segment = AudioSegment(audio_tensor, 0, end_idx)
else:
self.segment = self.inputs["audio_segments"][segment_idx]
......@@ -648,6 +653,16 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_sr = self.config.get("audio_sr", 16000)
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
record_fps = self.config["video_frame_interpolation"]["target_fps"]
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and output_video_path.startswith("http"):
self.va_recorder = VAX64Recorder(
whip_shared_path=whip_shared_path,
livestream_url=output_video_path,
fps=record_fps,
sample_rate=audio_sr,
)
else:
self.va_recorder = VARecorder(
livestream_url=output_video_path,
fps=record_fps,
......@@ -655,7 +670,7 @@ class WanAudioRunner(WanRunner): # type:ignore
)
def init_va_reader(self):
audio_path = self.config.get("audio_path", None)
audio_path = self.input_info.audio_path
self.va_reader = None
if isinstance(audio_path, dict):
assert audio_path["type"] == "stream", f"unexcept audio_path: {audio_path}"
......@@ -683,10 +698,13 @@ class WanAudioRunner(WanRunner): # type:ignore
if self.va_reader is None:
return super().run_main(total_steps)
self.va_reader.start()
rank, world_size = self.get_rank_and_world_size()
if rank == world_size - 1:
assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 2"
self.va_reader.start()
self.va_recorder.start(self.input_info.target_shape[1], self.input_info.target_shape[0])
if world_size > 1:
dist.barrier()
self.init_run()
if self.config.get("compile", False):
......@@ -714,7 +732,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self.init_run_segment(segment_idx, audio_array)
latents = self.run_segment(total_steps=None)
self.gen_video = self.run_vae_decoder(latents)
self.end_run_segment()
self.end_run_segment(segment_idx)
segment_idx += 1
finally:
......@@ -724,7 +742,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self.va_reader.stop()
self.va_reader = None
if self.va_recorder:
self.va_recorder.stop(wait=False)
self.va_recorder.stop()
self.va_recorder = None
@ProfilingContext4DebugL1("Process after vae decoder")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment