xuwx1 / LightX2V · Commits · 3c778aee

Commit 3c778aee authored Sep 16, 2025 by PengGao, committed by GitHub on Sep 16, 2025

Gp/dev (#310)

Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>

parent 32fd1c52
Changes: 11 changed files with 368 additions and 713 deletions (+368 -713)
lightx2v/api_multi_servers.py          +0   -172
lightx2v/server/README.md              +114 -86
lightx2v/server/__main__.py            +34  -0
lightx2v/server/api.py                 +7   -6
lightx2v/server/config.py              +1   -22
lightx2v/server/distributed_utils.py   +23  -46
lightx2v/server/gpu_manager.py         +0   -116
lightx2v/server/main.py                +26  -11
lightx2v/server/service.py             +149 -239
scripts/server/start_multi_servers.sh  +8   -8
scripts/server/start_server.sh         +6   -7
lightx2v/api_multi_servers.py deleted (100644 → 0), view file @ 32fd1c52

```python
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import requests
from loguru import logger


@dataclass
class ServerConfig:
    port: int
    gpu_id: int
    model_cls: str
    task: str
    model_path: str
    config_json: str


def get_node_ip() -> str:
    """Get the IP address of the current node"""
    try:
        # Create a UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Connect to an external address (no actual connection needed)
        s.connect(("8.8.8.8", 80))
        # Get local IP
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception as e:
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"


def is_port_in_use(port: int) -> bool:
    """Check if a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port"""
    port = start_port
    while port < start_port + 1000:  # Try up to 1000 ports
        if not is_port_in_use(port):
            return port
        port += 1
    return None


def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance"""
    try:
        # Set GPU
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)

        # Start server
        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )

        # Wait for server to start, up to 600 seconds
        node_ip = get_node_ip()
        service_url = f"http://{node_ip}:{config.port}/v1/service/status"

        # Check once per second, up to 600 times
        for _ in range(600):
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, f"http://{node_ip}:{config.port}"
            except (requests.RequestException, ConnectionError) as e:
                pass
            time.sleep(1)

        # If timeout, terminate the process
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        process.terminate()
        return None
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()

    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port

    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue
        config = ServerConfig(
            port=port,
            gpu_id=gpu,
            model_cls=args.model_cls,
            task=args.task,
            model_path=args.model_path,
            config_json=args.config_json,
        )
        server_configs.append(config)
        current_port = port + 1

    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")

    # Print all server URLs
    logger.info("\nAll server URLs:")
    for url in urls:
        logger.info(url)

    # Print node information
    node_ip = get_node_ip()
    logger.info(f"\nCurrent node IP: {node_ip}")
    logger.info(f"Number of servers started: {len(urls)}")

    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        for process in processes:
            process.terminate()


if __name__ == "__main__":
    main()
```
lightx2v/server/README.md, view file @ 3c778aee

````diff
@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI
 ### System Architecture
 
 ```mermaid
-graph TB
-    subgraph "Client Layer"
-        Client[HTTP Client]
-    end
+flowchart TB
+    Client[Client] -->|Send API Request| Router[FastAPI Router]
-    subgraph "API Layer"
-        FastAPI[FastAPI Application]
-        ApiServer[ApiServer]
-        Router1[Tasks Router<br/>/v1/tasks]
-        Router2[Files Router<br/>/v1/files]
-        Router3[Service Router<br/>/v1/service]
-    end
+    subgraph API Layer
+        Router --> TaskRoutes[Task APIs]
+        Router --> FileRoutes[File APIs]
+        Router --> ServiceRoutes[Service Status APIs]
-    subgraph "Service Layer"
-        TaskManager[TaskManager<br/>Thread-safe Task Queue]
-        FileService[FileService<br/>File I/O & Downloads]
-        VideoService[VideoGenerationService]
-    end
-    subgraph "Processing Layer"
-        Thread[Processing Thread<br/>Sequential Task Loop]
-    end
+        TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
+        TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
+        TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
+        TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
+        TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
+        TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]
-    subgraph "Distributed Inference Layer"
-        DistService[DistributedInferenceService]
-        SharedData[(Shared Data<br/>mp.Manager.dict)]
-        TaskEvent[Task Event<br/>mp.Manager.Event]
-        ResultEvent[Result Event<br/>mp.Manager.Event]
+        FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]
-        subgraph "Worker Processes"
-            W0[Worker 0<br/>Master/Rank 0]
-            W1[Worker 1<br/>Rank 1]
-            WN[Worker N<br/>Rank N]
-        end
+        ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
+        ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
     end
-    subgraph "Resource Management"
-        GPUManager[GPUManager<br/>GPU Detection & Allocation]
-        DistManager[DistributedManager<br/>PyTorch Distributed]
-        Config[ServerConfig<br/>Configuration]
+    subgraph Task Management
+        TaskManager[Task Manager]
+        TaskQueue[Task Queue]
+        TaskStatus[Task Status]
+        TaskResult[Task Result]
+        CreateTask --> TaskManager
+        CreateTaskForm --> TaskManager
+        TaskManager --> TaskQueue
+        TaskManager --> TaskStatus
+        TaskManager --> TaskResult
     end
-    Client -->|HTTP Request| FastAPI
-    FastAPI --> ApiServer
-    ApiServer --> Router1
-    ApiServer --> Router2
-    ApiServer --> Router3
-    Router1 -->|Create/Manage Tasks| TaskManager
-    Router1 -->|Process Tasks| Thread
-    Router2 -->|File Operations| FileService
-    Router3 -->|Service Status| TaskManager
-    Thread -->|Get Pending Tasks| TaskManager
-    Thread -->|Generate Video| VideoService
+    subgraph File Service
+        FileService[File Service]
+        DownloadImage[Download Image]
+        DownloadAudio[Download Audio]
+        SaveFile[Save File]
+        GetOutputPath[Get Output Path]
+        FileService --> DownloadImage
+        FileService --> DownloadAudio
+        FileService --> SaveFile
+        FileService --> GetOutputPath
     end
-    VideoService -->|Download Images| FileService
-    VideoService -->|Submit Task| DistService
+    subgraph Processing Thread
+        ProcessingThread[Processing Thread]
+        NextTask[Get Next Task]
+        ProcessTask[Process Single Task]
-    DistService -->|Update| SharedData
-    DistService -->|Signal| TaskEvent
-    TaskEvent -->|Notify| W0
-    W0 -->|Broadcast| W1
-    W0 -->|Broadcast| WN
+        ProcessingThread --> NextTask
+        ProcessingThread --> ProcessTask
     end
-    W0 -->|Update Result| SharedData
-    W0 -->|Signal| ResultEvent
-    ResultEvent -->|Notify| DistService
+    subgraph Video Generation Service
+        VideoService[Video Service]
+        GenerateVideo[Generate Video]
-    W0 -.->|Uses| GPUManager
-    W1 -.->|Uses| GPUManager
-    WN -.->|Uses| GPUManager
+        VideoService --> GenerateVideo
     end
-    W0 -.->|Setup| DistManager
-    W1 -.->|Setup| DistManager
-    WN -.->|Setup| DistManager
+    subgraph Distributed Inference Service
+        InferenceService[Distributed Inference Service]
+        SubmitTask[Submit Task]
+        Worker[Inference Worker Node]
+        ProcessRequest[Process Request]
+        RunPipeline[Run Inference Pipeline]
+        InferenceService --> SubmitTask
+        SubmitTask --> Worker
+        Worker --> ProcessRequest
+        ProcessRequest --> RunPipeline
     end
-    DistService -.->|Reads| Config
-    ApiServer -.->|Reads| Config
+    %% ====== Connect Modules ======
+    TaskQueue --> ProcessingThread
+    ProcessTask --> VideoService
+    GenerateVideo --> InferenceService
+    GetTaskResult --> FileService
+    DownloadFile --> FileService
+    VideoService --> FileService
+    InferenceService --> TaskManager
+    TaskManager --> TaskStatus
 ```
````
## Task Processing Flow

```diff
@@ -100,9 +106,9 @@ sequenceDiagram
     participant PT as Processing Thread
     participant VS as VideoService
     participant FS as FileService
-    participant DIS as Distributed<br/>Inference Service
-    participant W0 as Worker 0<br/>(Master)
-    participant W1 as Worker 1..N
+    participant DIS as DistributedInferenceService
+    participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
+    participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)
     C->>API: POST /v1/tasks<br/>(Create Task)
     API->>TM: create_task()
@@ -127,32 +133,54 @@ sequenceDiagram
     else Image is Base64
         VS->>FS: save_base64_image()
         FS-->>VS: image_path
     else Image is Upload
         VS->>FS: validate_file()
         FS-->>VS: image_path
     else Image is local path
         VS->>VS: use existing path
     end
     alt Audio is URL
         VS->>FS: download_audio()
         FS->>FS: HTTP download<br/>with retry
         FS-->>VS: audio_path
     else Audio is Base64
         VS->>FS: save_base64_audio()
         FS-->>VS: audio_path
     else Audio is local path
         VS->>VS: use existing path
     end
-    VS->>DIS: submit_task(task_data)
-    DIS->>DIS: shared_data["current_task"] = task_data
-    DIS->>DIS: task_event.set()
+    VS->>DIS: submit_task_async(task_data)
+    DIS->>TIW0: process_request(task_data)
-    Note over W0,W1: Distributed Processing
-    W0->>W0: task_event.wait()
-    W0->>W0: Get task from shared_data
-    W0->>W1: broadcast_task_data()
+    Note over TIW0,TIW1: Torchrun-based Distributed Processing
+    TIW0->>TIW0: Check if processing
+    TIW0->>TIW0: Set processing = True
+    alt Multi-GPU Mode (world_size > 1)
+        TIW0->>TIW1: broadcast_task_data()<br/>(via DistributedManager)
+        Note over TIW1: worker_loop() listens for broadcasts
+        TIW1->>TIW1: Receive task_data
+    end
-    par Parallel Inference
-        W0->>W0: run_pipeline()
+    par Parallel Inference across all ranks
+        TIW0->>TIW0: runner.set_inputs(task_data)
+        TIW0->>TIW0: runner.run_pipeline()
     and
-        W1->>W1: run_pipeline()
+        Note over TIW1: If world_size > 1
+        TIW1->>TIW1: runner.set_inputs(task_data)
+        TIW1->>TIW1: runner.run_pipeline()
     end
-    W0->>W0: barrier() for sync
-    W0->>W0: shared_data["result"] = result
-    W0->>DIS: result_event.set()
+    Note over TIW0,TIW1: Synchronization
+    alt Multi-GPU Mode
+        TIW0->>TIW1: barrier() for sync
+        TIW1->>TIW0: barrier() response
+    end
+    TIW0->>TIW0: Set processing = False
+    TIW0->>DIS: Return result (only rank 0)
+    TIW1->>TIW1: Return None (non-rank 0)
-    DIS->>DIS: result_event.wait()
-    DIS->>VS: return result
+    DIS-->>VS: TaskResponse
     VS-->>PT: TaskResponse
     PT->>TM: complete_task()<br/>(status: COMPLETED)
```
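The routes named in the diagrams above (POST /v1/tasks/, GET /v1/tasks/{id}/status, GET /v1/tasks/{id}/result) suggest a simple client flow. A minimal sketch, assuming JSON in and out; the payload fields, status strings, and the presence of `task_id` in the create response are illustrative assumptions, not taken from the server's schema (see lightx2v/server/schema.py for the real TaskRequest/TaskResponse models):

```python
import time

import requests

BASE = "http://localhost:8000"  # host/port assumed from the start scripts below

# Illustrative payload; the real field names live in the TaskRequest model.
payload = {
    "prompt": "a cat playing piano",
    "image_path": "https://example.com/input.jpg",
    "save_video_path": "output.mp4",
}

task = requests.post(f"{BASE}/v1/tasks/", json=payload, timeout=30).json()
task_id = task["task_id"]  # assumes the create response echoes the task id

# Poll GET /v1/tasks/{id}/status until the task settles (status names are assumed).
while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status", timeout=30).json()
    if status.get("status") not in ("pending", "processing"):
        break
    time.sleep(2)

print(requests.get(f"{BASE}/v1/tasks/{task_id}/result", timeout=30).json())
```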
lightx2v/api_server.py → lightx2v/server/__main__.py (100755 → 100644), view file @ 3c778aee

```diff
-#!/usr/bin/env python
 import argparse
-import sys
-from pathlib import Path

-sys.path.insert(0, str(Path(__file__).parent.parent))

-from lightx2v.server.main import run_server
+from .main import run_server


 def main():
-    parser = argparse.ArgumentParser(description="Run LightX2V inference server")
+    parser = argparse.ArgumentParser(description="LightX2V Server")

     # Model arguments
     parser.add_argument("--model_path", type=str, required=True, help="Path to model")
     parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
     parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
     parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
     parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")

     # Server arguments
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
     parser.add_argument("--port", type=int, default=8000, help="Server port")
+    parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")

-    args = parser.parse_args()
+    # Parse any additional arguments that might be passed
+    args, unknown = parser.parse_known_args()
+
+    # Add any unknown arguments as attributes to args
+    # This allows flexibility for model-specific arguments
+    for i in range(0, len(unknown), 2):
+        if unknown[i].startswith("--"):
+            key = unknown[i][2:]
+            if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
+                value = unknown[i + 1]
+                setattr(args, key, value)

     # Run the server
     run_server(args)
```
lightx2v/server/api.py, view file @ 3c778aee

```diff
@@ -314,7 +314,6 @@ class ApiServer:
         return False

     def _ensure_processing_thread_running(self):
-        """Ensure the processing thread is running."""
         if self.processing_thread is None or not self.processing_thread.is_alive():
             self.stop_processing.clear()
             self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
@@ -322,9 +321,11 @@ class ApiServer:
             logger.info("Started task processing thread")

     def _task_processing_loop(self):
         """Main loop that processes tasks from the queue one by one."""
         logger.info("Task processing loop started")
+        asyncio.set_event_loop(asyncio.new_event_loop())
+        loop = asyncio.get_event_loop()
         while not self.stop_processing.is_set():
             task_id = task_manager.get_next_pending_task()
@@ -335,12 +336,12 @@ class ApiServer:
                 task_info = task_manager.get_task(task_id)
                 if task_info and task_info.status == TaskStatus.PENDING:
                     logger.info(f"Processing task {task_id}")
-                    self._process_single_task(task_info)
+                    loop.run_until_complete(self._process_single_task(task_info))
+        loop.close()
         logger.info("Task processing loop stopped")

-    def _process_single_task(self, task_info: Any):
-        """Process a single task."""
+    async def _process_single_task(self, task_info: Any):
         assert self.video_service is not None, "Video service is not initialized"
         task_id = task_info.task_id
@@ -360,7 +361,7 @@ class ApiServer:
             task_manager.fail_task(task_id, "Task cancelled")
             return
-        result = asyncio.run(self.video_service.generate_video_with_stop_event(message, task_info.stop_event))
+        result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)
         if result:
             task_manager.complete_task(task_id, result.save_video_path)
```
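The api.py change drives the now-async `_process_single_task` from a background thread through one long-lived event loop instead of calling `asyncio.run` per task. A standalone sketch of that pattern, with generic names that are not the project's:

```python
import asyncio
import threading


async def process(task_id: str) -> str:
    # Stand-in for an awaitable such as generate_video_with_stop_event().
    await asyncio.sleep(0.1)
    return f"done: {task_id}"


def processing_loop(tasks):
    # One private event loop owned by the worker thread, same shape as the hunk above:
    # create it once, drive each coroutine with run_until_complete, close it on exit.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        for task_id in tasks:
            print(loop.run_until_complete(process(task_id)))
    finally:
        loop.close()


thread = threading.Thread(target=processing_loop, args=(["t1", "t2"],), daemon=True)
thread.start()
thread.join()
```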
lightx2v/server/config.py, view file @ 3c778aee

```diff
@@ -11,9 +11,6 @@ class ServerConfig:
     port: int = 8000
     max_queue_size: int = 10
-    master_addr: str = "127.0.0.1"
-    master_port_range: tuple = (29500, 29600)
     task_timeout: int = 300
     task_history_limit: int = 1000
@@ -42,31 +39,13 @@ class ServerConfig:
             except ValueError:
                 logger.warning(f"Invalid max queue size: {env_queue_size}")

-        if env_master_addr := os.environ.get("MASTER_ADDR"):
-            config.master_addr = env_master_addr
+        # MASTER_ADDR is now managed by torchrun, no need to set manually

         if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
             config.cache_dir = env_cache_dir

         return config

-    def find_free_master_port(self) -> str:
-        import socket
-
-        for port in range(self.master_port_range[0], self.master_port_range[1]):
-            try:
-                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                    s.bind((self.master_addr, port))
-                    logger.info(f"Found free port for master: {port}")
-                    return str(port)
-            except OSError:
-                continue
-        raise RuntimeError(
-            f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
-            f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
-        )

     def validate(self) -> bool:
         valid = True
```
lightx2v/server/distributed_utils.py, view file @ 3c778aee

```diff
@@ -6,8 +6,6 @@ import torch
 import torch.distributed as dist
 from loguru import logger

-from .gpu_manager import gpu_manager


 class DistributedManager:
     def __init__(self):
@@ -18,29 +16,35 @@ class DistributedManager:
     CHUNK_SIZE = 1024 * 1024

-    def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
+    def init_process_group(self) -> bool:
+        """Initialize process group using torchrun environment variables"""
         try:
-            os.environ["RANK"] = str(rank)
-            os.environ["WORLD_SIZE"] = str(world_size)
-            os.environ["MASTER_ADDR"] = master_addr
-            os.environ["MASTER_PORT"] = master_port
+            # torchrun sets these environment variables automatically
+            self.rank = int(os.environ.get("LOCAL_RANK", 0))
+            self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+            if self.world_size > 1:
+                # torchrun handles backend, init_method, rank, and world_size
+                # We just need to call init_process_group without parameters
                 backend = "nccl" if torch.cuda.is_available() else "gloo"
-            dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
+                dist.init_process_group(backend=backend, init_method="env://")
                 logger.info(f"Setup backend: {backend}")
-            self.device = gpu_manager.set_device_for_rank(rank, world_size)
+                # Set CUDA device for this rank
+                if torch.cuda.is_available():
+                    torch.cuda.set_device(self.rank)
+                    self.device = f"cuda:{self.rank}"
+                else:
+                    self.device = "cpu"
+            else:
+                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
             self.is_initialized = True
-            self.rank = rank
-            self.world_size = world_size
-            logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
+            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
             return True
         except Exception as e:
-            logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
+            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
             return False

     def cleanup(self):
@@ -143,30 +147,3 @@ class DistributedManager:
         task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
         task_data = pickle.loads(task_bytes)
         return task_data
-
-
-class DistributedWorker:
-    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
-        self.rank = rank
-        self.world_size = world_size
-        self.master_addr = master_addr
-        self.master_port = master_port
-        self.dist_manager = DistributedManager()
-
-    def init(self) -> bool:
-        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)
-
-    def cleanup(self):
-        self.dist_manager.cleanup()
-
-    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
-        self.dist_manager.barrier()
-        if self.dist_manager.is_rank_zero():
-            result = {"task_id": task_id, "status": status, **kwargs}
-            result_queue.put(result)
-            logger.info(f"Task {task_id} {status}")
-
-
-def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
-    return DistributedWorker(rank, world_size, master_addr, master_port)
```
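For reference, a minimal standalone sketch of the torchrun-style initialization the rewritten `init_process_group` relies on: `RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `MASTER_ADDR` and `MASTER_PORT` are injected by torchrun, so only `init_method="env://"` is needed. The script name and structure here are illustrative, not part of the repository:

```python
# Launch with: torchrun --nproc_per_node <N> init_sketch.py
import os

import torch
import torch.distributed as dist


def init_from_torchrun() -> str:
    """env:// initialization; torchrun supplies the rendezvous variables, so no
    addresses or ranks are passed explicitly."""
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        return f"cuda:{local_rank}"
    return "cpu"


if __name__ == "__main__":
    device = init_from_torchrun()
    print(f"rank {os.environ.get('RANK', 0)} using {device}")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
```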
lightx2v/server/gpu_manager.py deleted (100644 → 0), view file @ 32fd1c52

```python
import os
from typing import List, Optional, Tuple

import torch
from loguru import logger


class GPUManager:
    def __init__(self):
        self.available_gpus = self._detect_gpus()
        self.gpu_count = len(self.available_gpus)

    def _detect_gpus(self) -> List[int]:
        if not torch.cuda.is_available():
            logger.warning("No CUDA devices available, will use CPU")
            return []
        gpu_count = torch.cuda.device_count()
        cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if cuda_visible:
            try:
                visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
                logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
                return list(range(len(visible_devices)))
            except ValueError:
                logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")
        available_gpus = list(range(gpu_count))
        logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
        return available_gpus

    def get_device_for_rank(self, rank: int, world_size: int) -> str:
        if not self.available_gpus:
            logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
            return "cpu"
        if self.gpu_count == 1:
            device = f"cuda:{self.available_gpus[0]}"
            logger.info(f"Rank {rank}: Using single GPU {device}")
            return device
        if self.gpu_count >= world_size:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
            return device
        else:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
            return device

    def set_device_for_rank(self, rank: int, world_size: int) -> str:
        device = self.get_device_for_rank(rank, world_size)
        if device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            torch.cuda.set_device(gpu_id)
            logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
        return device

    def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
        if not torch.cuda.is_available():
            return (0, 0)
        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
        else:
            gpu_id = torch.cuda.current_device()
        try:
            used = torch.cuda.memory_allocated(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory
            return (used, total)
        except Exception as e:
            logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
            return (0, 0)

    def clear_cache(self, device: Optional[str] = None):
        if not torch.cuda.is_available():
            return
        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        else:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        logger.info(f"GPU cache cleared for device: {device or 'current'}")

    @staticmethod
    def get_optimal_world_size(requested_world_size: int) -> int:
        if not torch.cuda.is_available():
            logger.warning("No GPUs available, using single process")
            return 1
        gpu_count = torch.cuda.device_count()
        if requested_world_size <= 0:
            optimal_size = gpu_count
            logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
        elif requested_world_size > gpu_count:
            logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
            optimal_size = requested_world_size
        else:
            optimal_size = requested_world_size
        return optimal_size


gpu_manager = GPUManager()
```
lightx2v/server/main.py, view file @ 3c778aee

```diff
+import os
 import sys
 from pathlib import Path
@@ -10,9 +11,14 @@ from .service import DistributedInferenceService

 def run_server(args):
+    """Run server with torchrun support"""
     inference_service = None
     try:
-        logger.info("Starting LightX2V server...")
+        # Get rank from environment (set by torchrun)
+        rank = int(os.environ.get("LOCAL_RANK", 0))
+        world_size = int(os.environ.get("WORLD_SIZE", 1))
+        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")

         if hasattr(args, "host") and args.host:
             server_config.host = args.host
@@ -22,11 +28,14 @@ def run_server(args):
         if not server_config.validate():
             raise RuntimeError("Invalid server configuration")

         # Initialize inference service
         inference_service = DistributedInferenceService()
         if not inference_service.start_distributed_inference(args):
             raise RuntimeError("Failed to start distributed inference service")
-        logger.info("Inference service started successfully")
+        logger.info(f"Rank {rank}: Inference service started successfully")

+        if rank == 0:
+            # Only rank 0 runs the FastAPI server
             cache_dir = Path(server_config.cache_dir)
             cache_dir.mkdir(parents=True, exist_ok=True)
@@ -35,15 +44,21 @@ def run_server(args):
             app = api_server.get_app()

-            logger.info(f"Starting server on {server_config.host}:{server_config.port}")
+            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
             uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
+        else:
+            # Non-rank-0 processes run the worker loop
+            logger.info(f"Rank {rank}: Starting worker loop")
+            import asyncio
+
+            asyncio.run(inference_service.run_worker_loop())

     except KeyboardInterrupt:
-        logger.info("Server interrupted by user")
+        logger.info(f"Server rank {rank} interrupted by user")
         if inference_service:
             inference_service.stop_distributed_inference()
     except Exception as e:
-        logger.error(f"Server failed: {e}")
+        logger.error(f"Server rank {rank} failed: {e}")
         if inference_service:
             inference_service.stop_distributed_inference()
         sys.exit(1)
```
lightx2v/server/service.py, view file @ 3c778aee

```diff
+import asyncio
 import threading
 import time
+import json
 import os
 import uuid
 from pathlib import Path
-from typing import Optional
+from typing import Any, Dict, Optional
 from urllib.parse import urlparse

 import httpx
-import torch.multiprocessing as mp
+import torch
 from loguru import logger

 from ..infer import init_runner
 from ..utils.set_config import set_config
 from .audio_utils import is_base64_audio, save_base64_audio
 from .config import server_config
-from .distributed_utils import create_distributed_worker
+from .distributed_utils import DistributedManager
 from .image_utils import is_base64_image, save_base64_image
 from .schema import TaskRequest, TaskResponse

-mp.set_start_method("spawn", force=True)


 class FileService:
     def __init__(self, cache_dir: Path):
@@ -196,280 +193,191 @@ class FileService:
         self._http_client = None


-def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, shared_data, task_event, result_event):
-    task_data = None
-    worker = None
+class TorchrunInferenceWorker:
+    """Worker class for torchrun-based distributed inference"""
+
+    def __init__(self):
+        self.rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+        self.runner = None
+        self.dist_manager = DistributedManager()
+        self.request_queue = asyncio.Queue() if self.rank == 0 else None
+        self.processing = False  # Track if currently processing a request
+
+    def init(self, args) -> bool:
+        """Initialize the worker with model and distributed setup"""
         try:
-            logger.info(f"Process {rank}/{world_size - 1} initializing distributed inference service...")
-            worker = create_distributed_worker(rank, world_size, master_addr, master_port)
-            if not worker.init():
-                raise RuntimeError(f"Rank {rank} distributed environment initialization failed")
+            # Initialize distributed process group using torchrun env vars
+            if self.world_size > 1:
+                if not self.dist_manager.init_process_group():
+                    raise RuntimeError("Failed to initialize distributed process group")
+            else:
+                # Single GPU mode
+                self.dist_manager.rank = 0
+                self.dist_manager.world_size = 1
+                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+                self.dist_manager.is_initialized = False

             # Initialize model
             config = set_config(args)
-            logger.info(f"Rank {rank} config: {config}")
-            runner = init_runner(config)
-            logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")
+            if self.rank == 0:
+                logger.info(f"Config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")

-            while True:
-                if not task_event.wait(timeout=1.0):
-                    continue
-
-                if rank == 0:
-                    if shared_data.get("stop", False):
-                        logger.info(f"Process {rank} received stop signal, exiting inference service")
-                        worker.dist_manager.broadcast_task_data(None)
-                        break
-
-                    task_data = shared_data.get("current_task")
-                    if task_data:
-                        worker.dist_manager.broadcast_task_data(task_data)
-                        shared_data["current_task"] = None
-                        try:
-                            task_event.clear()
-                        except Exception:
-                            pass
-                    else:
-                        continue
-                else:
-                    task_data = worker.dist_manager.broadcast_task_data()
-                    if task_data is None:
-                        logger.info(f"Process {rank} received stop signal, exiting inference service")
-                        break
-
-                if task_data is not None:
-                    logger.info(f"Process {rank} received inference task: {task_data['task_id']}")
+            self.runner = init_runner(config)
+            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")
+            return True
+        except Exception as e:
+            logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
+            return False

+    async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a single inference request
+
+        Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
+        The async wrapper allows FastAPI to handle other requests while this runs.
+        """
         try:
-                        runner.set_inputs(task_data)  # type: ignore
-                        runner.run_pipeline()
-                        worker.dist_manager.barrier()
-                        if rank == 0:
-                            # Only rank 0 updates the result
-                            shared_data["result"] = {
+            # Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
+            if self.world_size > 1 and self.rank == 0:
+                task_data = self.dist_manager.broadcast_task_data(task_data)
+
+            # Run inference directly - torchrun handles the parallelization
+            # Using asyncio.to_thread would be risky with NCCL operations
+            # Instead, we rely on FastAPI's async handling and queue management
+            self.runner.set_inputs(task_data)
+            self.runner.run_pipeline()
+
+            # Small yield to allow other async operations if needed
+            await asyncio.sleep(0)
+
+            # Synchronize all ranks
+            if self.world_size > 1:
+                self.dist_manager.barrier()
+
+            # Only rank 0 returns the result
+            if self.rank == 0:
+                return {
                     "task_id": task_data["task_id"],
                     "status": "success",
-                                "save_video_path": task_data.get("video_path", task_data["save_video_path"]),
+                    # Return original path for API
+                    "save_video_path": task_data.get("video_path", task_data["save_video_path"]),
                     "message": "Inference completed",
                 }
-                            result_event.set()
-                        logger.info(f"Task {task_data['task_id']} success")
+            else:
+                return None
         except Exception as e:
-                        logger.exception(f"Process {rank} error occurred while processing task: {str(e)}")
-                        worker.dist_manager.barrier()
-                        if rank == 0:
-                            # Only rank 0 updates the result
-                            shared_data["result"] = {
+            logger.error(f"Rank {self.rank} inference failed: {str(e)}")
+            if self.world_size > 1:
+                self.dist_manager.barrier()
+            if self.rank == 0:
+                return {
                     "task_id": task_data.get("task_id", "unknown"),
                     "status": "failed",
                     "error": str(e),
                     "message": f"Inference failed: {str(e)}",
                 }
-                            result_event.set()
-                            logger.info(f"Task {task_data.get('task_id', 'unknown')} failed")
+            else:
+                return None

-    except KeyboardInterrupt:
-        logger.info(f"Process {rank} received KeyboardInterrupt, gracefully exiting")
-    except Exception as e:
-        logger.exception(f"Distributed inference service process {rank} startup failed: {str(e)}")
-        if rank == 0:
-            shared_data["result"] = {
-                "task_id": "startup",
-                "status": "startup_failed",
-                "error": str(e),
-                "message": f"Inference service startup failed: {str(e)}",
-            }
-            result_event.set()
-    finally:
+    async def worker_loop(self):
+        """Non-rank-0 workers: Listen for broadcast tasks"""
+        while True:
             try:
-        if worker:
-            worker.cleanup()
+                task_data = self.dist_manager.broadcast_task_data()
+                if task_data is None:
+                    logger.info(f"Rank {self.rank} received stop signal")
+                    break
+                await self.process_request(task_data)
             except Exception as e:
-        logger.debug(f"Error cleaning up worker for rank {rank}: {e}")
+                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
+                continue
+
+    def cleanup(self):
+        self.dist_manager.cleanup()


 class DistributedInferenceService:
     def __init__(self):
-        self.manager = None
-        self.shared_data = None
-        self.task_event = None
-        self.result_event = None
-        self.processes = []
+        self.worker = None
         self.is_running = False
         self.args = None

     def start_distributed_inference(self, args) -> bool:
         if hasattr(args, "lora_path") and args.lora_path:
             args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
             delattr(args, "lora_path")
             if hasattr(args, "lora_strength"):
                 delattr(args, "lora_strength")
         self.args = args
         if self.is_running:
             logger.warning("Distributed inference service is already running")
             return True

-        nproc_per_node = args.nproc_per_node
-        if nproc_per_node <= 0:
-            logger.error("nproc_per_node must be greater than 0")
-            return False

         try:
-            master_addr = server_config.master_addr
-            master_port = server_config.find_free_master_port()
-            logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")
-
-            # Create shared data structures
-            self.manager = mp.Manager()
-            self.shared_data = self.manager.dict()
-            self.task_event = self.manager.Event()
-            self.result_event = self.manager.Event()
-
-            # Initialize shared data
-            self.shared_data["current_task"] = None
-            self.shared_data["result"] = None
-            self.shared_data["stop"] = False
-
-            for rank in range(nproc_per_node):
-                p = mp.Process(
-                    target=_distributed_inference_worker,
-                    args=(
-                        rank,
-                        nproc_per_node,
-                        master_addr,
-                        master_port,
-                        args,
-                        self.shared_data,
-                        self.task_event,
-                        self.result_event,
-                    ),
-                    daemon=False,  # Changed to False for proper cleanup
-                )
-                p.start()
-                self.processes.append(p)
+            self.worker = TorchrunInferenceWorker()
+            if not self.worker.init(args):
+                raise RuntimeError("Worker initialization failed")

             self.is_running = True
-            logger.info(f"Distributed inference service started successfully with {nproc_per_node} processes")
+            logger.info(f"Rank {self.worker.rank} inference service started successfully")
             return True
         except Exception as e:
-            logger.exception(f"Error occurred while starting distributed inference service: {str(e)}")
+            logger.error(f"Error starting inference service: {str(e)}")
             self.stop_distributed_inference()
             return False

     def stop_distributed_inference(self):
-        assert self.task_event, "Task event is not initialized"
-        assert self.result_event, "Result event is not initialized"
         if not self.is_running:
             return
         try:
-            logger.info(f"Stopping {len(self.processes)} distributed inference service processes...")
-            if self.shared_data is not None:
-                self.shared_data["stop"] = True
-                self.task_event.set()
-
-            for p in self.processes:
-                try:
-                    p.join(timeout=10)
-                    if p.is_alive():
-                        logger.warning(f"Process {p.pid} did not end within the specified time, forcing termination...")
-                        p.terminate()
-                        p.join(timeout=5)
-                except Exception as e:
-                    logger.warning(f"Error terminating process {p.pid}: {e}")
-
-            logger.info("All distributed inference service processes have stopped")
+            if self.worker:
+                self.worker.cleanup()
+            logger.info("Inference service stopped")
         except Exception as e:
-            logger.error(f"Error occurred while stopping distributed inference service: {str(e)}")
+            logger.error(f"Error stopping inference service: {str(e)}")
         finally:
             # Clean up resources
-            self.processes = []
-            self.manager = None
-            self.shared_data = None
-            self.task_event = None
-            self.result_event = None
+            self.worker = None
             self.is_running = False

-    def submit_task(self, task_data: dict) -> bool:
-        assert self.task_event, "Task event is not initialized"
-        assert self.result_event, "Result event is not initialized"
-        if not self.is_running or not self.shared_data:
-            logger.error("Distributed inference service is not started")
-            return False
-        try:
-            self.result_event.clear()
-            self.shared_data["result"] = None
-            self.shared_data["current_task"] = task_data
-            self.task_event.set()  # Signal workers
-            return True
-        except Exception as e:
-            logger.error(f"Failed to submit task: {str(e)}")
-            return False
-
-    def wait_for_result(self, task_id: str, timeout: Optional[int] = None) -> Optional[dict]:
-        assert self.task_event, "Task event is not initialized"
-        assert self.result_event, "Result event is not initialized"
-        if timeout is None:
-            timeout = server_config.task_timeout
-        if not self.is_running or not self.shared_data:
-            return None
-        if self.result_event.wait(timeout=timeout):
-            result = self.shared_data.get("result")
-            if result and result.get("task_id") == task_id:
-                self.shared_data["current_task"] = None
-                self.task_event.clear()
-                return result
-        return None
-
-    def wait_for_result_with_stop(self, task_id: str, stop_event: threading.Event, timeout: Optional[int] = None) -> Optional[dict]:
-        if timeout is None:
-            timeout = server_config.task_timeout
-        if not self.is_running or not self.shared_data:
-            return None
-        assert self.task_event, "Task event is not initialized"
-        assert self.result_event, "Result event is not initialized"
-        start_time = time.time()
-        while time.time() - start_time < timeout:
-            if stop_event.is_set():
-                logger.info(f"Task {task_id} stop event triggered during wait")
-                self.shared_data["current_task"] = None
-                self.task_event.clear()
-                return None
-            if self.result_event.wait(timeout=0.5):
-                result = self.shared_data.get("result")
-                if result and result.get("task_id") == task_id:
-                    self.shared_data["current_task"] = None
-                    self.task_event.clear()
-                    return result
-        return None
+    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
+        if not self.is_running or not self.worker:
+            logger.error("Inference service is not started")
+            return None
+        if self.worker.rank != 0:
+            return None
+        try:
+            if self.worker.processing:
+                # If we want to support queueing, we can add the task to queue
+                # For now, we'll process sequentially
+                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")
+            self.worker.processing = True
+            result = await self.worker.process_request(task_data)
+            self.worker.processing = False
+            return result
+        except Exception as e:
+            self.worker.processing = False
+            logger.error(f"Failed to process task: {str(e)}")
+            return {
+                "task_id": task_data.get("task_id", "unknown"),
+                "status": "failed",
+                "error": str(e),
+                "message": f"Task processing failed: {str(e)}",
+            }

     def server_metadata(self):
         assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
-        return {"nproc_per_node": self.args.nproc_per_node, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
+        return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
+
+    async def run_worker_loop(self):
+        """Run the worker loop for non-rank-0 processes"""
+        if self.worker and self.worker.rank != 0:
+            await self.worker.worker_loop()


 class VideoGenerationService:
@@ -478,6 +386,7 @@ class VideoGenerationService:
         self.inference_service = inference_service

     async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
+        """Generate video using torchrun-based inference"""
         try:
             task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
             task_data["task_id"] = message.task_id
@@ -496,6 +405,8 @@ class VideoGenerationService:
                 else:
                     task_data["image_path"] = message.image_path
+                logger.info(f"Task {message.task_id} image path: {task_data['image_path']}")

             if "audio_path" in message.model_fields_set and message.audio_path:
                 if message.audio_path.startswith("http"):
                     audio_path = await self.file_service.download_audio(message.audio_path)
@@ -506,20 +417,19 @@ class VideoGenerationService:
                 else:
                     task_data["audio_path"] = message.audio_path
+                logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")

             actual_save_path = self.file_service.get_output_path(message.save_video_path)
             task_data["save_video_path"] = str(actual_save_path)
             task_data["video_path"] = message.save_video_path

-            if not self.inference_service.submit_task(task_data):
-                raise RuntimeError("Distributed inference service is not started")
-            result = self.inference_service.wait_for_result_with_stop(message.task_id, stop_event, timeout=300)
+            result = await self.inference_service.submit_task_async(task_data)

             if result is None:
                 if stop_event.is_set():
                     logger.info(f"Task {message.task_id} cancelled during processing")
                     return None
-                raise RuntimeError("Task processing timeout")
+                raise RuntimeError("Task processing failed")

             if result.get("status") == "success":
                 return TaskResponse(
```
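Taken together, main.py and service.py implement a pattern where rank 0 serves HTTP and broadcasts work while the other torchrun-launched ranks block in a receive loop. A simplified, self-contained sketch of that shape, using `torch.distributed.broadcast_object_list` in place of the project's `DistributedManager.broadcast_task_data`; all names and the demo task list are illustrative:

```python
# Run with: torchrun --nproc_per_node 2 pattern_sketch.py
import os

import torch
import torch.distributed as dist


def broadcast_task(task=None):
    """Every rank calls this together; rank 0 supplies the payload, the rest receive it."""
    box = [task]
    dist.broadcast_object_list(box, src=0)
    return box[0]


def run_pipeline(task):
    print(f"rank {dist.get_rank()} ran {task}")


def worker_loop():
    """Non-rank-0 ranks: wait for broadcast tasks until a None stop signal arrives."""
    while True:
        task = broadcast_task()
        if task is None:
            break
        run_pipeline(task)
        dist.barrier()


def serve(tasks):
    """Rank 0: in the real server these tasks come from FastAPI requests."""
    for task in tasks:
        broadcast_task(task)   # hand the task to every rank
        run_pipeline(task)     # rank 0 runs the same pipeline
        dist.barrier()         # wait until all ranks finish
    broadcast_task(None)       # shut the workers down


if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://")
    if dist.get_rank() == 0:
        serve(["demo-task"])
    else:
        worker_loop()
    dist.destroy_process_group()
```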
scripts/server/start_multi_servers.sh, view file @ 3c778aee

```diff
 #!/bin/bash

 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
+model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8

-export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1,2,3

 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh

 # Start multiple servers
-python -m lightx2v.api_multi_servers \
-    --num_gpus $num_gpus \
-    --start_port 8000 \
-    --model_cls wan2.1_distill \
+torchrun --nproc_per_node 4 -m lightx2v.server \
+    --model_cls seko_talk \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json
+    --config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
+    --port 8000
```
scripts/server/start_server.sh, view file @ 3c778aee

```diff
 #!/bin/bash

 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path=/path/to/Lightx2v
+model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8

 export CUDA_VISIBLE_DEVICES=0
@@ -11,12 +11,11 @@ source ${lightx2v_path}/scripts/base/base.sh

 # Start API server with distributed inference service
-python -m lightx2v.api_server \
-    --model_cls wan2.1_distill \
+python -m lightx2v.server \
+    --model_cls seko_talk \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json \
-    --port 8000 \
-    --nproc_per_node 1
+    --config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
+    --port 8000

 echo "Service stopped"
```