xuwx1 / LightX2V · Commits

Commit 3c778aee — Gp/dev (#310)
Authored Sep 16, 2025 by PengGao; committed via GitHub on Sep 16, 2025.
Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
Parent: 32fd1c52
Changes: 11 changed files with 368 additions and 713 deletions (+368 −713)
lightx2v/api_multi_servers.py          +0    −172  (deleted)
lightx2v/server/README.md              +114  −86
lightx2v/server/__main__.py            +34   −0    (renamed from lightx2v/api_server.py)
lightx2v/server/api.py                 +7    −6
lightx2v/server/config.py              +1    −22
lightx2v/server/distributed_utils.py   +23   −46
lightx2v/server/gpu_manager.py         +0    −116  (deleted)
lightx2v/server/main.py                +26   −11
lightx2v/server/service.py             +149  −239
scripts/server/start_multi_servers.sh  +8    −8
scripts/server/start_server.sh         +6    −7
lightx2v/api_multi_servers.py — deleted (100644 → 0), view file @ 32fd1c52

This single-node multi-server launcher is removed; multi-GPU launches now go through torchrun (see scripts/server/start_multi_servers.sh below). The deleted file:

```python
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import requests
from loguru import logger


@dataclass
class ServerConfig:
    port: int
    gpu_id: int
    model_cls: str
    task: str
    model_path: str
    config_json: str


def get_node_ip() -> str:
    """Get the IP address of the current node"""
    try:
        # Create a UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Connect to an external address (no actual connection needed)
        s.connect(("8.8.8.8", 80))
        # Get local IP
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception as e:
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"


def is_port_in_use(port: int) -> bool:
    """Check if a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port"""
    port = start_port
    while port < start_port + 1000:  # Try up to 1000 ports
        if not is_port_in_use(port):
            return port
        port += 1
    return None


def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance"""
    try:
        # Set GPU
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)

        # Start server
        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )

        # Wait for server to start, up to 600 seconds
        node_ip = get_node_ip()
        service_url = f"http://{node_ip}:{config.port}/v1/service/status"

        # Check once per second, up to 600 times
        for _ in range(600):
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, f"http://{node_ip}:{config.port}"
            except (requests.RequestException, ConnectionError) as e:
                pass
            time.sleep(1)

        # If timeout, terminate the process
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        process.terminate()
        return None
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()

    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port

    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue
        config = ServerConfig(
            port=port,
            gpu_id=gpu,
            model_cls=args.model_cls,
            task=args.task,
            model_path=args.model_path,
            config_json=args.config_json,
        )
        server_configs.append(config)
        current_port = port + 1

    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")

    # Print all server URLs
    logger.info("\nAll server URLs:")
    for url in urls:
        logger.info(url)

    # Print node information
    node_ip = get_node_ip()
    logger.info(f"\nCurrent node IP: {node_ip}")
    logger.info(f"Number of servers started: {len(urls)}")

    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        for process in processes:
            process.terminate()


if __name__ == "__main__":
    main()
```
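The launcher above considers a child server healthy once GET /v1/service/status returns 200. The same probe works from any client; a minimal sketch, assuming a server listening on localhost:8000:

```python
import time

import requests


def wait_until_ready(base_url: str = "http://localhost:8000", attempts: int = 600) -> bool:
    """Poll /v1/service/status once per second, as the launcher above did for its children."""
    for _ in range(attempts):
        try:
            if requests.get(f"{base_url}/v1/service/status", timeout=1).status_code == 200:
                return True
        except requests.RequestException:
            pass
        time.sleep(1)
    return False


if __name__ == "__main__":
    print("server ready:", wait_until_ready())
```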
lightx2v/server/README.md — view file @ 3c778aee

@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI

### System Architecture

This hunk rewrites the mermaid architecture diagram. The previous diagram laid the system out as layers: Client → FastAPI Application → ApiServer → Tasks/Files/Service routers (/v1/tasks, /v1/files, /v1/service); a Service Layer with a thread-safe TaskManager, a FileService for file I/O and downloads, and the VideoGenerationService; a sequential Processing Thread; a Distributed Inference Layer built on mp.Manager shared data, task/result events, and spawned worker processes (Worker 0 as master through Worker N); and a Resource Management block (GPUManager for GPU detection and allocation, DistributedManager for PyTorch distributed setup, ServerConfig).

The updated diagram drops the multiprocessing plumbing and the GPUManager (both deleted in this commit) and instead maps the request path onto the concrete endpoints and modules: Client → FastAPI Router → Task APIs (POST /v1/tasks/ create, POST /v1/tasks/form, GET /v1/tasks/ list, GET /v1/tasks/{id}/status, GET /v1/tasks/{id}/result, DELETE /v1/tasks/{id} stop), File APIs (GET /v1/files/download/{path}), and Service Status APIs (GET /v1/service/status, GET /v1/service/metadata). Module blocks cover Task Management (Task Manager, Task Queue, Task Status, Task Result), the File Service (Download Image, Download Audio, Save File, Get Output Path), the Processing Thread (Get Next Task, Process Single Task), the Video Generation Service (Generate Video), and the Distributed Inference Service (Submit Task → Inference Worker Node → Process Request → Run Inference Pipeline), connected as Task Queue → Processing Thread → Video Service → Inference Service, with results flowing back through the File Service and Task Manager.

## Task Processing Flow

The hunks at @@ -100,9 +106,9 @@ and @@ -127,32 +133,54 @@ update the task-processing sequence diagram to match the new code:

- The inference participant is now DistributedInferenceService, and the worker participants become TorchrunInferenceWorker (Rank 0) and TorchrunInferenceWorker (Rank 1..N) instead of Worker 0 (Master) and Worker 1..N.
- The image branch gains an "Image is local path" case (the existing path is used as-is), and the audio branch is spelled out the same way: download_audio() with HTTP retry for URLs, save_base64_audio() for base64, or the existing local path.
- VideoService now calls submit_task_async(task_data), and the service invokes process_request(task_data) on the rank-0 worker directly instead of writing shared_data["current_task"] and setting task_event.
- In multi-GPU mode (world_size > 1), rank 0 broadcasts the task data via DistributedManager, the other ranks receive it in worker_loop(), and all ranks run runner.set_inputs(task_data) and runner.run_pipeline() in parallel before synchronizing with barrier().
- Only rank 0 returns a result (non-rank-0 workers return None); the worker clears its processing flag and the result flows back as a TaskResponse. The shared_data["result"]/result_event handshake is gone, and the flow still ends with complete_task() (status: COMPLETED).
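Taken together with the task routes in the diagram (POST /v1/tasks/, GET /v1/tasks/{id}/status, GET /v1/tasks/{id}/result), the client flow is create → poll → fetch. A rough sketch follows; the payload fields and response keys are illustrative guesses, since the actual TaskRequest/TaskResponse schema (lightx2v/server/schema.py) is not part of this diff:

```python
import time

import requests

BASE = "http://localhost:8000"

# Create a task. The field names below are placeholders; the real schema is
# TaskRequest in lightx2v/server/schema.py, which this commit does not touch.
task = requests.post(
    f"{BASE}/v1/tasks/",
    json={
        "prompt": "a cat surfing",                    # placeholder
        "image_path": "https://example.com/cat.png",  # URL, base64, or local path per the README
        "save_video_path": "cat.mp4",
    },
).json()
task_id = task["task_id"]  # assumed response key

# Poll the documented status endpoint until the task settles.
while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status").json()
    if status.get("status") in ("COMPLETED", "FAILED"):  # COMPLETED appears in the diagram; FAILED is assumed
        break
    time.sleep(2)

# Fetch the result (served back through the FileService).
result = requests.get(f"{BASE}/v1/tasks/{task_id}/result")
print(result.status_code)
```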
lightx2v/api_server.py → lightx2v/server/__main__.py — renamed, mode 100755 → 100644, view file @ 3c778aee

The old entry point's shebang, its sys.path.insert(0, str(Path(__file__).parent.parent)) bootstrap, and its argparse setup (--host defaulting to "127.0.0.1", plain parse_args()) are replaced. The module now imports run_server relatively, adds --config_json, --task, and --nproc_per_node, defaults the host to "0.0.0.0", and forwards unknown --key value flags as attributes on args:

```python
import argparse

from .main import run_server


def main():
    parser = argparse.ArgumentParser(description="LightX2V Server")

    # Model arguments
    parser.add_argument("--model_path", type=str, required=True, help="Path to model")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
    parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
    parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
    parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")

    # Server arguments
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")

    # Parse any additional arguments that might be passed
    args, unknown = parser.parse_known_args()

    # Add any unknown arguments as attributes to args
    # This allows flexibility for model-specific arguments
    for i in range(0, len(unknown), 2):
        if unknown[i].startswith("--"):
            key = unknown[i][2:]
            if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
                value = unknown[i + 1]
                setattr(args, key, value)

    # Run the server
    run_server(args)
...
```
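A side effect of the parse_known_args() passthrough is that any extra --key value pair becomes a plain string attribute on args (this is how --lora_path reaches the service layer, as seen in service.py below). A small illustration; the flag values here are made up:

```python
import argparse

parser = argparse.ArgumentParser(description="LightX2V Server")
parser.add_argument("--model_path", type=str, required=True)

# Same passthrough trick as __main__.py: unknown "--key value" pairs become attributes on args.
args, unknown = parser.parse_known_args(
    ["--model_path", "/models/wan", "--lora_path", "/loras/style.safetensors"]  # made-up values
)
for i in range(0, len(unknown), 2):
    if unknown[i].startswith("--") and i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
        setattr(args, unknown[i][2:], unknown[i + 1])

print(args.lora_path)  # "/loras/style.safetensors" -- note it is always a string
```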
lightx2v/server/api.py — view file @ 3c778aee

The task-processing thread now owns a dedicated asyncio event loop and _process_single_task becomes a coroutine, so video generation is awaited on that loop instead of spinning up a fresh loop per task with asyncio.run() (hunks around lines 314–365):

```python
            return False

    def _ensure_processing_thread_running(self):
        """Ensure the processing thread is running."""
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_processing.clear()
            self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
            ...
            logger.info("Started task processing thread")

    def _task_processing_loop(self):
        """Main loop that processes tasks from the queue one by one."""
        logger.info("Task processing loop started")
        asyncio.set_event_loop(asyncio.new_event_loop())  # new: one loop per processing thread
        loop = asyncio.get_event_loop()
        while not self.stop_processing.is_set():
            task_id = task_manager.get_next_pending_task()
            ...
            task_info = task_manager.get_task(task_id)
            if task_info and task_info.status == TaskStatus.PENDING:
                logger.info(f"Processing task {task_id}")
                loop.run_until_complete(self._process_single_task(task_info))  # was: self._process_single_task(task_info)
        loop.close()
        logger.info("Task processing loop stopped")

    async def _process_single_task(self, task_info: Any):  # was a plain def
        """Process a single task."""
        assert self.video_service is not None, "Video service is not initialized"
        task_id = task_info.task_id
        ...
                task_manager.fail_task(task_id, "Task cancelled")
                return

            result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)
            # was: result = asyncio.run(self.video_service.generate_video_with_stop_event(...))
            if result:
                task_manager.complete_task(task_id, result.save_video_path)
            ...
```
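The pattern introduced here — one long-lived event loop owned by the processing thread, with each task driven through loop.run_until_complete() — is worth seeing in isolation. A self-contained sketch (names invented for the example, not from the repo):

```python
import asyncio
import queue
import threading

jobs = queue.Queue()


async def handle(job):
    """Stand-in for the async _process_single_task coroutine."""
    await asyncio.sleep(0.1)
    print("done", job)


def worker():
    # One long-lived loop per thread, mirroring _task_processing_loop:
    # create it once, reuse it for every task, close it when the thread exits.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    while True:
        job = jobs.get()
        if job is None:  # sentinel plays the role of stop_processing
            break
        loop.run_until_complete(handle(job))
    loop.close()


t = threading.Thread(target=worker, daemon=True)
t.start()
for i in range(3):
    jobs.put(i)
jobs.put(None)
t.join()
```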
lightx2v/server/config.py — view file @ 3c778aee

ServerConfig drops the master_addr (default "127.0.0.1") and master_port_range (default (29500, 29600)) fields, the MASTER_ADDR environment override (now just a comment noting that MASTER_ADDR is managed by torchrun), and the find_free_master_port() helper; torchrun owns the rendezvous address and port. The config keeps port, max_queue_size, task_timeout, task_history_limit, the LIGHTX2V_CACHE_DIR override, and validate(). The removed helper was:

```python
    def find_free_master_port(self) -> str:
        import socket

        for port in range(self.master_port_range[0], self.master_port_range[1]):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind((self.master_addr, port))
                    logger.info(f"Found free port for master: {port}")
                    return str(port)
            except OSError:
                continue
        raise RuntimeError(
            f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
            f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
        )
```
lightx2v/server/distributed_utils.py — view file @ 3c778aee

DistributedManager no longer imports the (deleted) gpu_manager, and init_process_group() loses its rank/world_size/master_addr/master_port parameters: rank and world size come from the LOCAL_RANK and WORLD_SIZE variables that torchrun exports, the process group is created with init_method="env://" only when world_size > 1, and each rank pins its own CUDA device. The new method:

```python
    def init_process_group(self) -> bool:
        """Initialize process group using torchrun environment variables"""
        try:
            # torchrun sets these environment variables automatically
            self.rank = int(os.environ.get("LOCAL_RANK", 0))
            self.world_size = int(os.environ.get("WORLD_SIZE", 1))

            if self.world_size > 1:
                # torchrun handles backend, init_method, rank, and world_size
                # We just need to call init_process_group without parameters
                backend = "nccl" if torch.cuda.is_available() else "gloo"
                dist.init_process_group(backend=backend, init_method="env://")
                logger.info(f"Setup backend: {backend}")

                # Set CUDA device for this rank
                if torch.cuda.is_available():
                    torch.cuda.set_device(self.rank)
                    self.device = f"cuda:{self.rank}"
                else:
                    self.device = "cpu"
            else:
                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

            self.is_initialized = True
            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
            return False
```

The previous version exported RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT itself, initialized the group with init_method=f"tcp://{master_addr}:{master_port}", and delegated device selection to gpu_manager.set_device_for_rank(rank, world_size). The DistributedWorker wrapper and create_distributed_worker() factory at the end of the file are removed (the service now uses DistributedManager directly):

```python
class DistributedWorker:
    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.dist_manager = DistributedManager()

    def init(self) -> bool:
        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)

    def cleanup(self):
        self.dist_manager.cleanup()

    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
        self.dist_manager.barrier()
        if self.dist_manager.is_rank_zero():
            result = {"task_id": task_id, "status": status, **kwargs}
            result_queue.put(result)
            logger.info(f"Task {task_id} {status}")


def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
    return DistributedWorker(rank, world_size, master_addr, master_port)
```
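With init_method="env://", torch.distributed reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment that torchrun exports, which is what lets the explicit master address/port plumbing above disappear. A minimal standalone sketch of that contract (launch it with `torchrun --nproc_per_node 2 demo.py`; the file name is arbitrary):

```python
import os

import torch
import torch.distributed as dist

# Under torchrun these are already populated; the defaults only matter when the
# script is run directly as a single process.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

if world_size > 1:
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    # env:// tells torch.distributed to read MASTER_ADDR / MASTER_PORT / RANK /
    # WORLD_SIZE from the environment instead of an explicit tcp:// address.
    dist.init_process_group(backend=backend, init_method="env://")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

print(f"rank {local_rank} of {world_size} ready")

if world_size > 1:
    dist.destroy_process_group()
```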
lightx2v/server/gpu_manager.py — deleted (100644 → 0), view file @ 32fd1c52

```python
import os
from typing import List, Optional, Tuple

import torch
from loguru import logger


class GPUManager:
    def __init__(self):
        self.available_gpus = self._detect_gpus()
        self.gpu_count = len(self.available_gpus)

    def _detect_gpus(self) -> List[int]:
        if not torch.cuda.is_available():
            logger.warning("No CUDA devices available, will use CPU")
            return []
        gpu_count = torch.cuda.device_count()
        cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if cuda_visible:
            try:
                visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
                logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
                return list(range(len(visible_devices)))
            except ValueError:
                logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")
        available_gpus = list(range(gpu_count))
        logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
        return available_gpus

    def get_device_for_rank(self, rank: int, world_size: int) -> str:
        if not self.available_gpus:
            logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
            return "cpu"
        if self.gpu_count == 1:
            device = f"cuda:{self.available_gpus[0]}"
            logger.info(f"Rank {rank}: Using single GPU {device}")
            return device
        if self.gpu_count >= world_size:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
            return device
        else:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
            return device

    def set_device_for_rank(self, rank: int, world_size: int) -> str:
        device = self.get_device_for_rank(rank, world_size)
        if device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            torch.cuda.set_device(gpu_id)
            logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
        return device

    def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
        if not torch.cuda.is_available():
            return (0, 0)
        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
        else:
            gpu_id = torch.cuda.current_device()
        try:
            used = torch.cuda.memory_allocated(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory
            return (used, total)
        except Exception as e:
            logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
            return (0, 0)

    def clear_cache(self, device: Optional[str] = None):
        if not torch.cuda.is_available():
            return
        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        else:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        logger.info(f"GPU cache cleared for device: {device or 'current'}")

    @staticmethod
    def get_optimal_world_size(requested_world_size: int) -> int:
        if not torch.cuda.is_available():
            logger.warning("No GPUs available, using single process")
            return 1
        gpu_count = torch.cuda.device_count()
        if requested_world_size <= 0:
            optimal_size = gpu_count
            logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
        elif requested_world_size > gpu_count:
            logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
            optimal_size = requested_world_size
        else:
            optimal_size = requested_world_size
        return optimal_size


gpu_manager = GPUManager()
```
lightx2v/server/main.py — view file @ 3c778aee

run_server() is now rank-aware: it reads LOCAL_RANK and WORLD_SIZE from the torchrun environment, starts the inference service on every rank, but only rank 0 creates the cache directory and serves FastAPI with uvicorn; the other ranks sit in the inference worker loop. Log and error messages carry the rank.

```python
import os
import sys
from pathlib import Path
...

def run_server(args):
    """Run server with torchrun support"""
    inference_service = None
    try:
        # Get rank from environment (set by torchrun)
        rank = int(os.environ.get("LOCAL_RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")

        if hasattr(args, "host") and args.host:
            server_config.host = args.host
        ...
        if not server_config.validate():
            raise RuntimeError("Invalid server configuration")

        # Initialize inference service
        inference_service = DistributedInferenceService()
        if not inference_service.start_distributed_inference(args):
            raise RuntimeError("Failed to start distributed inference service")
        logger.info(f"Rank {rank}: Inference service started successfully")

        if rank == 0:
            # Only rank 0 runs the FastAPI server
            cache_dir = Path(server_config.cache_dir)
            cache_dir.mkdir(parents=True, exist_ok=True)
            ...
            app = api_server.get_app()
            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
            uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
        else:
            # Non-rank-0 processes run the worker loop
            logger.info(f"Rank {rank}: Starting worker loop")
            import asyncio

            asyncio.run(inference_service.run_worker_loop())
    except KeyboardInterrupt:
        logger.info(f"Server rank {rank} interrupted by user")
        if inference_service:
            inference_service.stop_distributed_inference()
    except Exception as e:
        logger.error(f"Server rank {rank} failed: {e}")
        if inference_service:
            inference_service.stop_distributed_inference()
        sys.exit(1)
```
lightx2v/server/service.py — view file @ 3c778aee

This file carries most of the refactor (hunk @@ -196,280 +193,191 @@). The multiprocessing machinery — torch.multiprocessing with mp.set_start_method("spawn", force=True), one spawned process per rank, mp.Manager shared_data, task/result events, and the module-level _distributed_inference_worker() loop — is replaced by an in-process TorchrunInferenceWorker that each torchrun rank constructs directly. Imports change accordingly: threading, time, and torch.multiprocessing go away, while json, os, torch, Dict/Any typing, and DistributedManager (instead of create_distributed_worker) come in.

The new worker wraps model initialization and a single request, keeping the actual inference synchronous inside an async method so the NCCL/CUDA context stays intact:

```python
class TorchrunInferenceWorker:
    """Worker class for torchrun-based distributed inference"""

    def __init__(self):
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None
        self.dist_manager = DistributedManager()
        self.request_queue = asyncio.Queue() if self.rank == 0 else None
        self.processing = False  # Track if currently processing a request

    def init(self, args) -> bool:
        """Initialize the worker with model and distributed setup"""
        try:
            # Initialize distributed process group using torchrun env vars
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                # Single GPU mode
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False

            # Initialize model
            config = set_config(args)
            if self.rank == 0:
                logger.info(f"Config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")
            return True
        except Exception as e:
            logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single inference request

        Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
        The async wrapper allows FastAPI to handle other requests while this runs.
        """
        try:
            # Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)

            # Run inference directly - torchrun handles the parallelization
            # Using asyncio.to_thread would be risky with NCCL operations
            # Instead, we rely on FastAPI's async handling and queue management
            self.runner.set_inputs(task_data)
            self.runner.run_pipeline()

            # Small yield to allow other async operations if needed
            await asyncio.sleep(0)

            # Synchronize all ranks
            if self.world_size > 1:
                self.dist_manager.barrier()

            # Only rank 0 returns the result
            if self.rank == 0:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    "save_video_path": task_data.get("video_path", task_data["save_video_path"]),  # Return original path for API
                    "message": "Inference completed",
                }
            logger.info(f"Task {task_data['task_id']} success")
            return None
        except Exception as e:
            logger.error(f"Rank {self.rank} inference failed: {str(e)}")
            if self.world_size > 1:
                self.dist_manager.barrier()
            if self.rank == 0:
                return {
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": str(e),
                    "message": f"Inference failed: {str(e)}",
                }
            logger.info(f"Task {task_data.get('task_id', 'unknown')} failed")
            return None

    async def worker_loop(self):
        """Non-rank-0 workers: Listen for broadcast tasks"""
        while True:
            try:
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break
                await self.process_request(task_data)
            except Exception as e:
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                continue

    def cleanup(self):
        self.dist_manager.cleanup()
```

DistributedInferenceService now just owns one TorchrunInferenceWorker per process. start_distributed_inference() converts a legacy --lora_path/--lora_strength pair into args.lora_configs, builds and initializes the worker, and flips is_running; stop_distributed_inference() reduces to worker.cleanup(); the old submit_task()/wait_for_result()/wait_for_result_with_stop() trio collapses into a single awaitable call; server_metadata() reports the worker's world_size; and run_worker_loop() is new for the non-rank-0 processes:

```python
class DistributedInferenceService:
    def __init__(self):
        self.worker = None
        self.is_running = False
        self.args = None

    def start_distributed_inference(self, args) -> bool:
        if hasattr(args, "lora_path") and args.lora_path:
            args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
            delattr(args, "lora_path")
            if hasattr(args, "lora_strength"):
                delattr(args, "lora_strength")
        self.args = args
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True
        try:
            self.worker = TorchrunInferenceWorker()
            if not self.worker.init(args):
                raise RuntimeError("Worker initialization failed")
            self.is_running = True
            logger.info(f"Rank {self.worker.rank} inference service started successfully")
            return True
        except Exception as e:
            logger.error(f"Error starting inference service: {str(e)}")
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
        if not self.is_running:
            return
        try:
            if self.worker:
                self.worker.cleanup()
            logger.info("Inference service stopped")
        except Exception as e:
            logger.error(f"Error stopping inference service: {str(e)}")
        finally:
            self.worker = None
            self.is_running = False

    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
        if not self.is_running or not self.worker:
            logger.error("Inference service is not started")
            return None
        try:
            if self.worker.processing:
                # If we want to support queueing, we can add the task to queue
                # For now, we'll process sequentially
                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")
            self.worker.processing = True
            result = await self.worker.process_request(task_data)
            self.worker.processing = False
            return result
        except Exception as e:
            self.worker.processing = False
            logger.error(f"Failed to process task: {str(e)}")
            return {
                "task_id": task_data.get("task_id", "unknown"),
                "status": "failed",
                "error": str(e),
                "message": f"Task processing failed: {str(e)}",
            }

    def server_metadata(self):
        assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
        return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}

    async def run_worker_loop(self):
        """Run the worker loop for non-rank-0 processes"""
        if self.worker and self.worker.rank != 0:
            await self.worker.worker_loop()
```

VideoGenerationService.generate_video_with_stop_event() (still async) now awaits submit_task_async() and inspects the returned dict directly; the submit_task()/wait_for_result_with_stop(..., timeout=300) pair and its "Task processing timeout" error are gone, and a missing result raises "Task processing failed" unless the stop event fired. The image/audio path handling (URL download, base64, or existing local path) and the save-path bookkeeping remain as context around the change:

```python
            actual_save_path = self.file_service.get_output_path(message.save_video_path)
            task_data["save_video_path"] = str(actual_save_path)
            task_data["video_path"] = message.save_video_path

            result = await self.inference_service.submit_task_async(task_data)

            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")

            if result.get("status") == "success":
                return TaskResponse(
                    ...
```
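Rank 0 hands each task to the other ranks through DistributedManager.broadcast_task_data(), which pickles the payload and broadcasts it in chunks. The same rank-0-to-workers handoff can be expressed with stock torch.distributed.broadcast_object_list; the sketch below is a simplification for illustration, not the repo's helper (it pins the backend to gloo so it runs on CPU, whereas the service picks nccl when CUDA is available). Launch with `torchrun --nproc_per_node 2 broadcast_demo.py`:

```python
import torch.distributed as dist


def main():
    # gloo keeps the demo CPU-only; the real service chooses nccl when CUDA is available.
    dist.init_process_group(backend="gloo", init_method="env://")
    rank = dist.get_rank()

    payload = [None]
    if rank == 0:
        payload[0] = {"task_id": "demo", "prompt": "a cat surfing"}  # made-up task

    # broadcast_object_list pickles on rank 0 and unpickles everywhere else,
    # much like the chunked pickle broadcast in distributed_utils.py.
    dist.broadcast_object_list(payload, src=0)
    task = payload[0]

    if task is not None:  # the service uses a None broadcast as its stop signal
        print(f"rank {rank} got task {task['task_id']}")

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```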
scripts/server/start_multi_servers.sh — view file @ 3c778aee

The multi-GPU launch switches from the deleted api_multi_servers launcher (one api_server process per GPU, ports allocated from --start_port) to a single torchrun invocation, and the example moves from the wan2.1_distill config to a seko_talk distributed config on 4 GPUs:

```bash
#!/bin/bash

# set path and first
lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0,1,2,3

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

# Start multiple servers
torchrun --nproc_per_node 4 -m lightx2v.server \
    --model_cls seko_talk \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
    --port 8000
```

This replaces the previous invocation:

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

python -m lightx2v.api_multi_servers \
    --num_gpus $num_gpus \
    --start_port 8000 \
    --model_cls wan2.1_distill \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json
```
scripts/server/start_server.sh — view file @ 3c778aee

```bash
#!/bin/bash

# set path and first
lightx2v_path=/path/to/Lightx2v
model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0

...

# Start API server with distributed inference service
python -m lightx2v.server \
    --model_cls seko_talk \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
    --port 8000 \
    --nproc_per_node 1

echo "Service stopped"
```

The hunk at @@ -11,12 +11,11 @@ replaces the old invocation, which ran `python -m lightx2v.api_server` with `--model_cls wan2.1_distill`, `--config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json`, and no `--nproc_per_node` flag.