Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
dea872a2
Unverified
Commit
dea872a2
authored
Nov 26, 2025
by
PengGao
Committed by
GitHub
Nov 26, 2025
Browse files
Api image (#515)
parent
1892a3db
Changes
36
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
755 additions
and
537 deletions
+755
-537
lightx2v/server/media/image.py
lightx2v/server/media/image.py
+68
-0
lightx2v/server/run_server.py
lightx2v/server/run_server.py
+0
-32
lightx2v/server/schema.py
lightx2v/server/schema.py
+30
-9
lightx2v/server/service.py
lightx2v/server/service.py
+0
-491
lightx2v/server/services/__init__.py
lightx2v/server/services/__init__.py
+11
-0
lightx2v/server/services/distributed_utils.py
lightx2v/server/services/distributed_utils.py
+0
-5
lightx2v/server/services/file_service.py
lightx2v/server/services/file_service.py
+153
-0
lightx2v/server/services/generation/__init__.py
lightx2v/server/services/generation/__init__.py
+9
-0
lightx2v/server/services/generation/base.py
lightx2v/server/services/generation/base.py
+145
-0
lightx2v/server/services/generation/image.py
lightx2v/server/services/generation/image.py
+66
-0
lightx2v/server/services/generation/video.py
lightx2v/server/services/generation/video.py
+22
-0
lightx2v/server/services/inference/__init__.py
lightx2v/server/services/inference/__init__.py
+7
-0
lightx2v/server/services/inference/service.py
lightx2v/server/services/inference/service.py
+81
-0
lightx2v/server/services/inference/worker.py
lightx2v/server/services/inference/worker.py
+109
-0
scripts/server/start_server_i2i.sh
scripts/server/start_server_i2i.sh
+26
-0
scripts/server/start_server_t2i.sh
scripts/server/start_server_t2i.sh
+28
-0
No files found.
lightx2v/server/media/image.py
0 → 100644
View file @
dea872a2
from
typing
import
Dict
from
.base
import
MediaHandler
class ImageHandler(MediaHandler):
    """Singleton media handler for image payloads (PNG/JPEG/GIF/WEBP).

    Provides magic-byte signatures, data-URL parsing metadata, base64
    detection, and extension sniffing on top of the generic MediaHandler.
    """

    _instance = None  # class-level singleton storage

    def __new__(cls):
        # Classic lazy singleton: every call returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_media_signatures(self) -> Dict[bytes, str]:
        """Map leading magic bytes to file extensions for supported formats."""
        return {
            b"\x89PNG\r\n\x1a\n": "png",
            b"\xff\xd8\xff": "jpg",
            b"GIF87a": "gif",
            b"GIF89a": "gif",
        }

    def get_data_url_prefix(self) -> str:
        """Prefix identifying an image data URL."""
        return "data:image/"

    def get_data_url_pattern(self) -> str:
        """Regex capturing (format, payload) from an image data URL."""
        return r"data:image/(\w+);base64,(.+)"

    def get_default_extension(self) -> str:
        """Extension used when the format cannot be sniffed."""
        return "png"

    def is_base64(self, data: str) -> bool:
        """Heuristically decide whether *data* is a base64-encoded image.

        Accepts either a data URL or raw base64 whose decoded prefix matches
        a known image signature (or the RIFF/WEBP layout). Returns False on
        any decode error.
        """
        if data.startswith(self.get_data_url_prefix()):
            return True
        try:
            import base64

            # Strict validation is only possible when length is a multiple
            # of 4; otherwise fall through to the prefix sniff below.
            if len(data) % 4 == 0:
                base64.b64decode(data, validate=True)
            # Decode only a prefix (100 chars -> 75 bytes) to sniff the magic.
            decoded = base64.b64decode(data[:100])
            for signature in self.get_media_signatures().keys():
                if decoded.startswith(signature):
                    return True
            # WEBP: RIFF container with "WEBP" at bytes 8..11; the slice
            # needs 12 bytes, so >= 12 (was "> 12", an off-by-one).
            if len(decoded) >= 12 and decoded[8:12] == b"WEBP":
                return True
        except Exception:
            return False
        return False

    def detect_extension(self, data: bytes) -> str:
        """Sniff a file extension from raw image bytes; default to png."""
        for signature, ext in self.get_media_signatures().items():
            if data.startswith(signature):
                return ext
        # Same WEBP check as is_base64; >= 12 so a minimal header matches.
        if len(data) >= 12 and data[8:12] == b"WEBP":
            return "webp"
        return self.get_default_extension()
# Module-level singleton and thin functional facade over ImageHandler.
_handler = ImageHandler()


def is_base64_image(data: str) -> bool:
    """Return True when *data* looks like a base64-encoded image."""
    return _handler.is_base64(data)


def save_base64_image(base64_data: str, output_dir: str) -> str:
    """Decode *base64_data* and persist it under *output_dir*; return the path."""
    return _handler.save_base64(base64_data, output_dir)
lightx2v/server/run_server.py
deleted
100644 → 0
View file @
1892a3db
#!/usr/bin/env python
"""Example script to run the LightX2V server."""
import
argparse
import
sys
from
pathlib
import
Path
sys
.
path
.
insert
(
0
,
str
(
Path
(
__file__
).
parent
.
parent
.
parent
))
from
lightx2v.server.main
import
run_server
def main():
    """Parse CLI options for the LightX2V server and launch it."""
    arg_parser = argparse.ArgumentParser(description="Run LightX2V inference server")

    # Required model selection.
    arg_parser.add_argument("--model_path", type=str, required=True, help="Path to model")
    arg_parser.add_argument("--model_cls", type=str, required=True, help="Model class name")

    # Optional configuration.
    arg_parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
    arg_parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
    arg_parser.add_argument(
        "--nproc_per_node",
        type=int,
        default=1,
        help="Number of processes per node (GPUs to use)",
    )

    # Network binding.
    arg_parser.add_argument("--port", type=int, default=8000, help="Server port")
    arg_parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")

    run_server(arg_parser.parse_args())


if __name__ == "__main__":
    main()
lightx2v/server/schema.py
View file @
dea872a2
import
random
from
typing
import
Optional
from
pydantic
import
BaseModel
,
Field
...
...
@@ -5,35 +6,55 @@ from pydantic import BaseModel, Field
from
..utils.generate_task_id
import
generate_task_id
def generate_random_seed() -> int:
    """Draw a uniform random seed in the unsigned 32-bit range [0, 2**32 - 1]."""
    # randrange(2**32) is equivalent to randint(0, 2**32 - 1) and consumes
    # the RNG state identically.
    return random.randrange(2**32)
class TalkObject(BaseModel):
    """One speaking subject in a Wan-Audio request: an audio track plus its mask.

    Both fields accept a local path; callers may also supply URLs or base64
    payloads, which the generation services resolve to local paths.
    """

    audio: str = Field(..., description="Audio path")
    mask: str = Field(..., description="Mask path")
class
TaskRequest
(
BaseModel
):
class
Base
TaskRequest
(
BaseModel
):
task_id
:
str
=
Field
(
default_factory
=
generate_task_id
,
description
=
"Task ID (auto-generated)"
)
prompt
:
str
=
Field
(
""
,
description
=
"Generation prompt"
)
use_prompt_enhancer
:
bool
=
Field
(
False
,
description
=
"Whether to use prompt enhancer"
)
negative_prompt
:
str
=
Field
(
""
,
description
=
"Negative prompt"
)
image_path
:
str
=
Field
(
""
,
description
=
"Base64 encoded image or URL"
)
num_fragments
:
int
=
Field
(
1
,
description
=
"Number of fragments"
)
save_result_path
:
str
=
Field
(
""
,
description
=
"Save video path (optional, defaults to task_id.mp4)"
)
save_result_path
:
str
=
Field
(
""
,
description
=
"Save result path (optional, defaults to task_id, suffix auto-detected)"
)
infer_steps
:
int
=
Field
(
5
,
description
=
"Inference steps"
)
target_video_length
:
int
=
Field
(
81
,
description
=
"Target video length"
)
seed
:
int
=
Field
(
42
,
description
=
"Random seed"
)
audio_path
:
str
=
Field
(
""
,
description
=
"Input audio path (Wan-Audio)"
)
video_duration
:
int
=
Field
(
5
,
description
=
"Video duration (Wan-Audio)"
)
talk_objects
:
Optional
[
list
[
TalkObject
]]
=
Field
(
None
,
description
=
"Talk objects (Wan-Audio)"
)
seed
:
int
=
Field
(
default_factory
=
generate_random_seed
,
description
=
"Random seed (auto-generated if not set)"
)
def
__init__
(
self
,
**
data
):
super
().
__init__
(
**
data
)
if
not
self
.
save_result_path
:
self
.
save_result_path
=
f
"
{
self
.
task_id
}
.mp4
"
self
.
save_result_path
=
f
"
{
self
.
task_id
}
"
def
get
(
self
,
key
,
default
=
None
):
return
getattr
(
self
,
key
,
default
)
class VideoTaskRequest(BaseTaskRequest):
    """Request payload for video generation tasks (t2v / i2v / s2v).

    Extends BaseTaskRequest with video-specific knobs; the Wan-Audio fields
    are only consumed by audio-driven pipelines.
    """

    num_fragments: int = Field(1, description="Number of fragments")
    target_video_length: int = Field(81, description="Target video length")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
class ImageTaskRequest(BaseTaskRequest):
    """Request payload for image generation tasks (t2i / i2i)."""

    aspect_ratio: str = Field("16:9", description="Output aspect ratio")
class TaskRequest(BaseTaskRequest):
    """Legacy combined request type kept for backward compatibility.

    Carries the union of video and image fields so older clients that post a
    single task shape keep working; new code should prefer VideoTaskRequest
    or ImageTaskRequest.
    """

    num_fragments: int = Field(1, description="Number of fragments")
    target_video_length: int = Field(81, description="Target video length (video only)")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
    aspect_ratio: str = Field("16:9", description="Output aspect ratio (T2I only)")
class
TaskStatusMessage
(
BaseModel
):
task_id
:
str
=
Field
(...,
description
=
"Task ID"
)
...
...
lightx2v/server/service.py
deleted
100644 → 0
View file @
1892a3db
This diff is collapsed.
Click to expand it.
lightx2v/server/services/__init__.py
0 → 100644
View file @
dea872a2
from
.file_service
import
FileService
from
.generation
import
ImageGenerationService
,
VideoGenerationService
from
.inference
import
DistributedInferenceService
,
TorchrunInferenceWorker
__all__
=
[
"FileService"
,
"DistributedInferenceService"
,
"TorchrunInferenceWorker"
,
"VideoGenerationService"
,
"ImageGenerationService"
,
]
lightx2v/server/distributed_utils.py
→
lightx2v/server/
services/
distributed_utils.py
View file @
dea872a2
...
...
@@ -17,20 +17,15 @@ class DistributedManager:
CHUNK_SIZE
=
1024
*
1024
def
init_process_group
(
self
)
->
bool
:
"""Initialize process group using torchrun environment variables"""
try
:
# torchrun sets these environment variables automatically
self
.
rank
=
int
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
0
))
self
.
world_size
=
int
(
os
.
environ
.
get
(
"WORLD_SIZE"
,
1
))
if
self
.
world_size
>
1
:
# torchrun handles backend, init_method, rank, and world_size
# We just need to call init_process_group without parameters
backend
=
"nccl"
if
torch
.
cuda
.
is_available
()
else
"gloo"
dist
.
init_process_group
(
backend
=
backend
,
init_method
=
"env://"
)
logger
.
info
(
f
"Setup backend:
{
backend
}
"
)
# Set CUDA device for this rank
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
set_device
(
self
.
rank
)
self
.
device
=
f
"cuda:
{
self
.
rank
}
"
...
...
lightx2v/server/services/file_service.py
0 → 100644
View file @
dea872a2
import
asyncio
import
uuid
from
pathlib
import
Path
from
typing
import
Optional
from
urllib.parse
import
urlparse
import
httpx
from
loguru
import
logger
class FileService:
    """Manages input/output files for the server: cached downloads, uploads,
    and output path resolution.

    Owns a lazily-created shared httpx.AsyncClient guarded by an asyncio
    lock, with exponential-backoff retry for transient download failures.
    """

    def __init__(self, cache_dir: Path):
        # Directory layout: <cache>/inputs/imgs, <cache>/inputs/audios,
        # <cache>/outputs. All are created eagerly.
        self.cache_dir = cache_dir
        self.input_image_dir = cache_dir / "inputs" / "imgs"
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"
        self._http_client = None
        self._client_lock = asyncio.Lock()
        # Retry policy: up to 3 attempts, 1s initial delay, doubling,
        # capped at 10s.
        self.max_retries = 3
        self.retry_delay = 1.0
        self.max_retry_delay = 10.0
        for directory in [
            self.input_image_dir,
            self.output_video_dir,
            self.input_audio_dir,
        ]:
            directory.mkdir(parents=True, exist_ok=True)

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Return the shared AsyncClient, (re)creating it if closed or absent."""
        async with self._client_lock:
            if self._http_client is None or self._http_client.is_closed:
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=30.0,
                    write=10.0,
                    pool=5.0,
                )
                limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
                # NOTE(review): verify=False disables TLS certificate
                # verification for all downloads — confirm this is intended.
                self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
            return self._http_client

    async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response:
        """GET *url* with retries on 5xx/network errors; 4xx fails immediately.

        Raises httpx.ConnectError summarizing the last failure once all
        attempts are exhausted.
        """
        if max_retries is None:
            max_retries = self.max_retries
        last_exception = None
        retry_delay = self.retry_delay
        for attempt in range(max_retries):
            try:
                client = await self._get_http_client()
                response = await client.get(url)
                if response.status_code == 200:
                    return response
                elif response.status_code >= 500:
                    # Server-side errors are considered transient: retry.
                    logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
                    last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
                else:
                    # Client errors (4xx) are permanent: re-raised below.
                    raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)
            except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
                logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
                last_exception = e
            except httpx.HTTPStatusError as e:
                # Propagate 4xx immediately; keep retrying anything >= 500.
                if e.response and e.response.status_code < 500:
                    raise
                last_exception = e
            except Exception as e:
                logger.error(f"Unexpected error downloading {url}: {str(e)}")
                last_exception = e
            if attempt < max_retries - 1:
                # Exponential backoff between attempts, capped.
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, self.max_retry_delay)
        error_msg = f"All {max_retries} connection attempts failed for {url}"
        if last_exception:
            error_msg += f": {str(last_exception)}"
        raise httpx.ConnectError(error_msg)

    async def download_media(self, url: str, media_type: str = "image") -> Path:
        """Download *url* into the image or audio input dir and return the path.

        All failures are normalized into ValueError so callers handle a
        single exception type.
        """
        try:
            parsed_url = urlparse(url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {url}")
            response = await self._download_with_retry(url)
            # Derive a filename from the URL path; fall back to a UUID with
            # a type-appropriate default extension.
            media_name = Path(parsed_url.path).name
            if not media_name:
                default_ext = "jpg" if media_type == "image" else "mp3"
                media_name = f"{uuid.uuid4()}.{default_ext}"
            if media_type == "image":
                target_dir = self.input_image_dir
            else:
                target_dir = self.input_audio_dir
            media_path = target_dir / media_name
            media_path.parent.mkdir(parents=True, exist_ok=True)
            with open(media_path, "wb") as f:
                f.write(response.content)
            logger.info(f"Successfully downloaded {media_type} from {url} to {media_path}")
            return media_path
        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"Failed to connect to {url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"Download timeout for {url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"HTTP error for {url}: {str(e)}")
        except ValueError:
            # Already a normalized error (e.g. invalid URL) — re-raise as-is.
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"Failed to download {media_type} from {url}: {str(e)}")

    async def download_image(self, image_url: str) -> Path:
        """Convenience wrapper: download into the image input directory."""
        return await self.download_media(image_url, "image")

    async def download_audio(self, audio_url: str) -> Path:
        """Convenience wrapper: download into the audio input directory."""
        return await self.download_media(audio_url, "audio")

    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
        """Persist an uploaded file under a UUID name, keeping its extension.

        NOTE(review): uploads always land in input_image_dir, even for
        non-image content — confirm that is intended.
        """
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = self.input_image_dir / unique_filename
        with open(file_path, "wb") as f:
            f.write(file_content)
        return file_path

    def get_output_path(self, save_result_path: str) -> Path:
        """Resolve a result path: relative paths are placed under outputs/."""
        video_path = Path(save_result_path)
        if not video_path.is_absolute():
            return self.output_video_dir / save_result_path
        return video_path

    async def cleanup(self):
        """Close the shared HTTP client, if one is open."""
        async with self._client_lock:
            if self._http_client and not self._http_client.is_closed:
                await self._http_client.aclose()
                self._http_client = None
lightx2v/server/services/generation/__init__.py
0 → 100644
View file @
dea872a2
from
.base
import
BaseGenerationService
from
.image
import
ImageGenerationService
from
.video
import
VideoGenerationService
__all__
=
[
"BaseGenerationService"
,
"VideoGenerationService"
,
"ImageGenerationService"
,
]
lightx2v/server/services/generation/base.py
0 → 100644
View file @
dea872a2
import
json
import
uuid
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
Optional
from
loguru
import
logger
from
...media
import
is_base64_audio
,
is_base64_image
,
save_base64_audio
,
save_base64_image
from
...schema
import
TaskResponse
from
..file_service
import
FileService
from
..inference
import
DistributedInferenceService
class BaseGenerationService(ABC):
    """Shared pipeline for generation requests: resolves media inputs
    (URL / base64 / local path), prepares output paths, submits to the
    distributed inference service, and wraps the result in a TaskResponse.

    Subclasses supply the output file extension and the comma-separated
    task types they serve.
    """

    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        self.file_service = file_service
        self.inference_service = inference_service

    @abstractmethod
    def get_output_extension(self) -> str:
        """Return the output file suffix, including the dot (e.g. '.mp4')."""
        pass

    @abstractmethod
    def get_task_type(self) -> str:
        """Return supported task types as a comma-separated string."""
        pass

    def _is_target_task_type(self) -> bool:
        """True when the running worker's configured task matches this service."""
        if self.inference_service.worker and self.inference_service.worker.runner:
            task_type = self.inference_service.worker.runner.config.get("task", "t2v")
            return task_type in self.get_task_type().split(",")
        return False

    async def _process_image_path(self, image_path: str, task_data: Dict[str, Any]) -> None:
        """Normalize an image reference into a local path in task_data.

        URLs are downloaded, base64 payloads are decoded to disk, anything
        else is passed through as an existing local path.
        """
        if not image_path:
            return
        if image_path.startswith("http"):
            downloaded_path = await self.file_service.download_image(image_path)
            task_data["image_path"] = str(downloaded_path)
        elif is_base64_image(image_path):
            saved_path = save_base64_image(image_path, str(self.file_service.input_image_dir))
            task_data["image_path"] = str(saved_path)
        else:
            task_data["image_path"] = image_path

    async def _process_audio_path(self, audio_path: str, task_data: Dict[str, Any]) -> None:
        """Normalize an audio reference into a local path in task_data.

        Same URL / base64 / passthrough scheme as _process_image_path.
        """
        if not audio_path:
            return
        if audio_path.startswith("http"):
            downloaded_path = await self.file_service.download_audio(audio_path)
            task_data["audio_path"] = str(downloaded_path)
        elif is_base64_audio(audio_path):
            saved_path = save_base64_audio(audio_path, str(self.file_service.input_audio_dir))
            task_data["audio_path"] = str(saved_path)
        else:
            task_data["audio_path"] = audio_path

    async def _process_talk_objects(self, talk_objects: list, task_data: Dict[str, Any]) -> None:
        """Resolve each talk object's audio and mask to local paths, then
        write a config.json manifest into a fresh temp dir.

        Side effect: task_data["audio_path"] is repointed at the temp dir
        containing the manifest.
        """
        if not talk_objects:
            return
        task_data["talk_objects"] = [{} for _ in range(len(talk_objects))]
        for index, talk_object in enumerate(talk_objects):
            # Audio: URL / base64 / passthrough.
            if talk_object.audio.startswith("http"):
                audio_path = await self.file_service.download_audio(talk_object.audio)
                task_data["talk_objects"][index]["audio"] = str(audio_path)
            elif is_base64_audio(talk_object.audio):
                audio_path = save_base64_audio(talk_object.audio, str(self.file_service.input_audio_dir))
                task_data["talk_objects"][index]["audio"] = str(audio_path)
            else:
                task_data["talk_objects"][index]["audio"] = talk_object.audio
            # Mask: URL / base64 / passthrough (masks are images).
            if talk_object.mask.startswith("http"):
                mask_path = await self.file_service.download_image(talk_object.mask)
                task_data["talk_objects"][index]["mask"] = str(mask_path)
            elif is_base64_image(talk_object.mask):
                mask_path = save_base64_image(talk_object.mask, str(self.file_service.input_image_dir))
                task_data["talk_objects"][index]["mask"] = str(mask_path)
            else:
                task_data["talk_objects"][index]["mask"] = talk_object.mask
        # Manifest directory named by a short random hex id.
        temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8]
        temp_path.mkdir(parents=True, exist_ok=True)
        task_data["audio_path"] = str(temp_path)
        config_path = temp_path / "config.json"
        with open(config_path, "w") as f:
            json.dump({"talk_objects": task_data["talk_objects"]}, f)

    def _prepare_output_path(self, save_result_path: str, task_data: Dict[str, Any]) -> None:
        """Resolve the output path, appending the default suffix if missing.

        Stores the absolute path under "save_result_path" and the bare
        filename under "video_path".
        """
        actual_save_path = self.file_service.get_output_path(save_result_path)
        if not actual_save_path.suffix:
            actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
        task_data["save_result_path"] = str(actual_save_path)
        task_data["video_path"] = actual_save_path.name

    async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        """Run one generation request end to end.

        Returns a TaskResponse on success, None if cancelled via stop_event,
        and raises RuntimeError on inference failure.
        """
        try:
            # Copy only the explicitly-set pydantic fields; task_id is
            # re-added from the message to guarantee its presence.
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id
            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None
            # Resolve media inputs that the request actually carries.
            if hasattr(message, "image_path") and message.image_path:
                await self._process_image_path(message.image_path, task_data)
                logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}")
            if hasattr(message, "audio_path") and message.audio_path:
                await self._process_audio_path(message.audio_path, task_data)
                logger.info(f"Task {message.task_id} audio path: {task_data.get('audio_path')}")
            if hasattr(message, "talk_objects") and message.talk_objects:
                await self._process_talk_objects(message.talk_objects, task_data)
            self._prepare_output_path(message.save_result_path, task_data)
            task_data["seed"] = message.seed
            result = await self.inference_service.submit_task_async(task_data)
            if result is None:
                # None either means cancellation (stop_event set) or an
                # unexplained failure from the inference service.
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")
            if result.get("status") == "success":
                # Recompute the resolved path to report the final filename.
                actual_save_path = self.file_service.get_output_path(message.save_result_path)
                if not actual_save_path.suffix:
                    actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_result_path=actual_save_path.name,
                )
            else:
                error_msg = result.get("error", "Inference failed")
                raise RuntimeError(error_msg)
        except Exception as e:
            logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
            raise
lightx2v/server/services/generation/image.py
0 → 100644
View file @
dea872a2
from
typing
import
Any
,
Optional
from
loguru
import
logger
from
...schema
import
TaskResponse
from
..file_service
import
FileService
from
..inference
import
DistributedInferenceService
from
.base
import
BaseGenerationService
class ImageGenerationService(BaseGenerationService):
    """Generation service for image tasks (t2i / i2i), producing .png files.

    NOTE(review): generate_with_stop_event largely duplicates the base-class
    flow, adding only aspect_ratio handling and dropping the audio/talk-object
    steps — consider folding back into the base implementation.
    """

    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        super().__init__(file_service, inference_service)

    def get_output_extension(self) -> str:
        """Images are written as PNG."""
        return ".png"

    def get_task_type(self) -> str:
        """Serves text-to-image and image-to-image tasks."""
        return "t2i,i2i"

    async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        """Run one image-generation request end to end.

        Returns a TaskResponse on success, None if cancelled via stop_event,
        raises RuntimeError on inference failure.
        """
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id
            # Always forward aspect_ratio when the request type defines it,
            # even if the client left it at its default.
            if hasattr(message, "aspect_ratio"):
                task_data["aspect_ratio"] = message.aspect_ratio
            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None
            if hasattr(message, "image_path") and message.image_path:
                await self._process_image_path(message.image_path, task_data)
                logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}")
            self._prepare_output_path(message.save_result_path, task_data)
            task_data["seed"] = message.seed
            result = await self.inference_service.submit_task_async(task_data)
            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")
            if result.get("status") == "success":
                actual_save_path = self.file_service.get_output_path(message.save_result_path)
                if not actual_save_path.suffix:
                    actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_result_path=actual_save_path.name,
                )
            else:
                error_msg = result.get("error", "Inference failed")
                raise RuntimeError(error_msg)
        except Exception as e:
            logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
            raise

    async def generate_image_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        """Backward-compatible alias for generate_with_stop_event."""
        return await self.generate_with_stop_event(message, stop_event)
lightx2v/server/services/generation/video.py
0 → 100644
View file @
dea872a2
from
typing
import
Any
,
Optional
from
..file_service
import
FileService
from
..inference
import
DistributedInferenceService
from
.base
import
BaseGenerationService
class VideoGenerationService(BaseGenerationService):
    """Generation service for video tasks (t2v / i2v / s2v), producing .mp4.

    The previous pass-through overrides of __init__ and
    generate_with_stop_event added nothing over the base class and were
    removed; inheritance preserves the exact same behavior and interface.
    """

    def get_output_extension(self) -> str:
        """Videos are written as MP4."""
        return ".mp4"

    def get_task_type(self) -> str:
        """Serves text-to-video, image-to-video, and speech-to-video tasks."""
        return "t2v,i2v,s2v"

    async def generate_video_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        """Backward-compatible alias for generate_with_stop_event."""
        return await self.generate_with_stop_event(message, stop_event)
lightx2v/server/services/inference/__init__.py
0 → 100644
View file @
dea872a2
from
.service
import
DistributedInferenceService
from
.worker
import
TorchrunInferenceWorker
__all__
=
[
"TorchrunInferenceWorker"
,
"DistributedInferenceService"
,
]
lightx2v/server/services/inference/service.py
0 → 100644
View file @
dea872a2
from
typing
import
Optional
from
loguru
import
logger
from
.worker
import
TorchrunInferenceWorker
class DistributedInferenceService:
    """Lifecycle wrapper around a TorchrunInferenceWorker.

    Starts/stops the worker, routes task submissions (rank 0 only), and
    exposes server metadata for the API layer.
    """

    def __init__(self):
        self.worker = None
        self.is_running = False
        self.args = None  # CLI args, populated by start_distributed_inference

    def start_distributed_inference(self, args) -> bool:
        """Create and initialize the worker; returns True on success.

        Idempotent: a second call while running is a no-op returning True.
        """
        self.args = args
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True
        try:
            self.worker = TorchrunInferenceWorker()
            if not self.worker.init(args):
                raise RuntimeError("Worker initialization failed")
            self.is_running = True
            logger.info(f"Rank {self.worker.rank} inference service started successfully")
            return True
        except Exception as e:
            logger.error(f"Error starting inference service: {str(e)}")
            # Roll back any partial startup so state stays consistent.
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
        """Tear down the worker; safe to call when not running."""
        if not self.is_running:
            return
        try:
            if self.worker:
                self.worker.cleanup()
            logger.info("Inference service stopped")
        except Exception as e:
            logger.error(f"Error stopping inference service: {str(e)}")
        finally:
            self.worker = None
            self.is_running = False

    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
        """Submit one task to the worker (rank 0 only).

        Returns the worker's result dict, a failure dict on exception, or
        None when the service is not running or this is a non-zero rank.
        """
        if not self.is_running or not self.worker:
            logger.error("Inference service is not started")
            return None
        if self.worker.rank != 0:
            return None
        try:
            # NOTE(review): this logs but does not actually wait; concurrent
            # submissions still proceed — confirm single-flight is enforced
            # upstream.
            if self.worker.processing:
                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")
            self.worker.processing = True
            result = await self.worker.process_request(task_data)
            self.worker.processing = False
            return result
        except Exception as e:
            self.worker.processing = False
            logger.error(f"Failed to process task: {str(e)}")
            return {
                "task_id": task_data.get("task_id", "unknown"),
                "status": "failed",
                "error": str(e),
                "message": f"Task processing failed: {str(e)}",
            }

    def server_metadata(self):
        """Describe the running deployment (world size, model class/path).

        Raises AssertionError when the service was never started. The old
        guard `hasattr(self, "args")` was always true because __init__ sets
        self.args = None; check the value instead.
        """
        assert self.args is not None, "Distributed inference service has not been started. Call start_distributed_inference() first."
        return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}

    async def run_worker_loop(self):
        """Non-zero ranks block here serving broadcast tasks until stopped."""
        if self.worker and self.worker.rank != 0:
            await self.worker.worker_loop()
lightx2v/server/services/inference/worker.py
0 → 100644
View file @
dea872a2
import
asyncio
import
json
import
os
from
typing
import
Any
,
Dict
import
torch
from
easydict
import
EasyDict
from
loguru
import
logger
from
lightx2v.infer
import
init_runner
from
lightx2v.utils.input_info
import
set_input_info
from
lightx2v.utils.set_config
import
set_config
from
..distributed_utils
import
DistributedManager
class TorchrunInferenceWorker:
    """Per-process inference worker under torchrun.

    Rank/world size come from the LOCAL_RANK / WORLD_SIZE environment
    variables torchrun sets. Rank 0 receives tasks and broadcasts them;
    other ranks spin in worker_loop awaiting broadcasts.
    """

    def __init__(self):
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None
        self.dist_manager = DistributedManager()
        self.processing = False  # set by the service while a task is in flight

    def init(self, args) -> bool:
        """Initialize the process group (multi-GPU) and the model runner.

        Returns True on success; logs and returns False on any failure.
        """
        try:
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                # Single-process mode: fill in the manager fields by hand.
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False
            config = set_config(args)
            if self.rank == 0:
                logger.info(f"Config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")
            return True
        except Exception as e:
            logger.exception(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one inference task; rank 0 returns a result dict, others None.

        On rank 0 in multi-GPU mode the task is first broadcast so every
        rank runs the same pipeline; a barrier re-synchronizes afterwards
        (on both success and failure paths).
        """
        try:
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)
            task_data["task"] = self.runner.config["task"]
            task_data["return_result_tensor"] = False
            task_data["negative_prompt"] = task_data.get("negative_prompt", "")
            task_data = EasyDict(task_data)
            input_info = set_input_info(task_data)
            self.runner.set_config(task_data)
            self.runner.run_pipeline(input_info)
            # Yield to the event loop after the (blocking) pipeline run.
            await asyncio.sleep(0)
            if self.world_size > 1:
                self.dist_manager.barrier()
            if self.rank == 0:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
                    "message": "Inference completed",
                }
            else:
                return None
        except Exception as e:
            logger.exception(f"Rank {self.rank} inference failed: {str(e)}")
            # Keep ranks in lockstep even when this rank failed.
            if self.world_size > 1:
                self.dist_manager.barrier()
            if self.rank == 0:
                return {
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": str(e),
                    "message": f"Inference failed: {str(e)}",
                }
            else:
                return None

    async def worker_loop(self):
        """Non-zero-rank service loop: receive broadcast tasks until a None
        broadcast signals shutdown. Errors are logged and the loop continues.
        """
        while True:
            try:
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break
                await self.process_request(task_data)
            except Exception as e:
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                continue

    def cleanup(self):
        """Release distributed resources (process group)."""
        self.dist_manager.cleanup()
scripts/server/start_server_i2i.sh
0 → 100755
View file @
dea872a2
#!/bin/bash
# set path and first
# NOTE: both variables below are intentionally left blank — fill in the
# repository root and the model checkpoint directory before running.
lightx2v_path=
model_path=

# Single-GPU launch.
export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

# Start API server with distributed inference service
python -m lightx2v.server \
    --model_cls qwen_image \
    --task i2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
    --port 8000

echo "Service stopped"

# Example request body for the i2i endpoint:
# {
#     "prompt": "turn the style of the photo to vintage comic book",
#     "image_path": "assets/inputs/imgs/snake.png",
#     "infer_steps": 50
# }
scripts/server/start_server_t2i.sh
0 → 100755
View file @
dea872a2
#!/bin/bash
# set path and first
# NOTE: both variables below are intentionally left blank — fill in the
# repository root and the model checkpoint directory before running.
lightx2v_path=
model_path=

# Single-GPU launch.
export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

# Start API server with distributed inference service
python -m lightx2v.server \
    --model_cls qwen_image \
    --task t2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
    --port 8000

echo "Service stopped"

# Example request body for the t2i endpoint:
# {
#     "prompt": "a beautiful sunset over the ocean",
#     "aspect_ratio": "16:9",
#     "infer_steps": 50
# }
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment