xuwx1 / LightX2V · Commits

Commit 3c778aee
Authored Sep 16, 2025 by PengGao; committed by GitHub on Sep 16, 2025

Gp/dev (#310)

Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>

Parent: 32fd1c52
Changes: 11
Showing 11 changed files with 368 additions and 713 deletions (+368, -713).
- lightx2v/api_multi_servers.py (+0, -172)
- lightx2v/server/README.md (+114, -86)
- lightx2v/server/__main__.py (+34, -0)
- lightx2v/server/api.py (+7, -6)
- lightx2v/server/config.py (+1, -22)
- lightx2v/server/distributed_utils.py (+23, -46)
- lightx2v/server/gpu_manager.py (+0, -116)
- lightx2v/server/main.py (+26, -11)
- lightx2v/server/service.py (+149, -239)
- scripts/server/start_multi_servers.sh (+8, -8)
- scripts/server/start_server.sh (+6, -7)
lightx2v/api_multi_servers.py (deleted, mode 100644 → 0)

```python
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import requests
from loguru import logger


@dataclass
class ServerConfig:
    port: int
    gpu_id: int
    model_cls: str
    task: str
    model_path: str
    config_json: str


def get_node_ip() -> str:
    """Get the IP address of the current node"""
    try:
        # Create a UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Connect to an external address (no actual connection needed)
        s.connect(("8.8.8.8", 80))
        # Get local IP
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception as e:
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"


def is_port_in_use(port: int) -> bool:
    """Check if a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port"""
    port = start_port
    while port < start_port + 1000:  # Try up to 1000 ports
        if not is_port_in_use(port):
            return port
        port += 1
    return None


def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance"""
    try:
        # Set GPU
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)

        # Start server
        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )

        # Wait for server to start, up to 600 seconds
        node_ip = get_node_ip()
        service_url = f"http://{node_ip}:{config.port}/v1/service/status"

        # Check once per second, up to 600 times
        for _ in range(600):
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, f"http://{node_ip}:{config.port}"
            except (requests.RequestException, ConnectionError):
                pass
            time.sleep(1)

        # If timeout, terminate the process
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        process.terminate()
        return None
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()

    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port

    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue
        config = ServerConfig(
            port=port,
            gpu_id=gpu,
            model_cls=args.model_cls,
            task=args.task,
            model_path=args.model_path,
            config_json=args.config_json,
        )
        server_configs.append(config)
        current_port = port + 1

    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")

    # Print all server URLs
    logger.info("\nAll server URLs:")
    for url in urls:
        logger.info(url)

    # Print node information
    node_ip = get_node_ip()
    logger.info(f"\nCurrent node IP: {node_ip}")
    logger.info(f"Number of servers started: {len(urls)}")

    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        for process in processes:
            process.terminate()


if __name__ == "__main__":
    main()
```
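The deleted launcher booted one api_server process per GPU and polled each instance's /v1/service/status endpoint until it answered. The same readiness check still applies to the single torchrun-launched server that replaces it; a minimal sketch using requests, where the base URL and timeout are assumptions rather than values from this commit:

```python
import time

import requests


def wait_until_ready(base_url: str = "http://localhost:8000", timeout_s: int = 600) -> bool:
    """Poll /v1/service/status once per second until the server reports healthy."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{base_url}/v1/service/status", timeout=1).status_code == 200:
                return True
        except requests.RequestException:
            pass
        time.sleep(1)
    return False


if __name__ == "__main__":
    print("server ready" if wait_until_ready() else "server did not come up in time")
```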
lightx2v/server/README.md

````diff
@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI
 ### System Architecture
 
-```mermaid
-graph TB
-    subgraph "Client Layer"
-        Client[HTTP Client]
-    end
-
-    subgraph "API Layer"
-        FastAPI[FastAPI Application]
-        ApiServer[ApiServer]
-        Router1[Tasks Router<br/>/v1/tasks]
-        Router2[Files Router<br/>/v1/files]
-        Router3[Service Router<br/>/v1/service]
-    end
-
-    subgraph "Service Layer"
-        TaskManager[TaskManager<br/>Thread-safe Task Queue]
-        FileService[FileService<br/>File I/O & Downloads]
-        VideoService[VideoGenerationService]
-    end
-
-    subgraph "Processing Layer"
-        Thread[Processing Thread<br/>Sequential Task Loop]
-    end
-
-    subgraph "Distributed Inference Layer"
-        DistService[DistributedInferenceService]
-        SharedData[(Shared Data<br/>mp.Manager.dict)]
-        TaskEvent[Task Event<br/>mp.Manager.Event]
-        ResultEvent[Result Event<br/>mp.Manager.Event]
-
-        subgraph "Worker Processes"
-            W0[Worker 0<br/>Master/Rank 0]
-            W1[Worker 1<br/>Rank 1]
-            WN[Worker N<br/>Rank N]
-        end
-    end
-
-    subgraph "Resource Management"
-        GPUManager[GPUManager<br/>GPU Detection & Allocation]
-        DistManager[DistributedManager<br/>PyTorch Distributed]
-        Config[ServerConfig<br/>Configuration]
-    end
-
-    Client -->|HTTP Request| FastAPI
-    FastAPI --> ApiServer
-    ApiServer --> Router1
-    ApiServer --> Router2
-    ApiServer --> Router3
-
-    Router1 -->|Create/Manage Tasks| TaskManager
-    Router1 -->|Process Tasks| Thread
-    Router2 -->|File Operations| FileService
-    Router3 -->|Service Status| TaskManager
-
-    Thread -->|Get Pending Tasks| TaskManager
-    Thread -->|Generate Video| VideoService
-    VideoService -->|Download Images| FileService
-    VideoService -->|Submit Task| DistService
-
-    DistService -->|Update| SharedData
-    DistService -->|Signal| TaskEvent
-    TaskEvent -->|Notify| W0
-    W0 -->|Broadcast| W1
-    W0 -->|Broadcast| WN
-    W0 -->|Update Result| SharedData
-    W0 -->|Signal| ResultEvent
-    ResultEvent -->|Notify| DistService
-
-    W0 -.->|Uses| GPUManager
-    W1 -.->|Uses| GPUManager
-    WN -.->|Uses| GPUManager
-    W0 -.->|Setup| DistManager
-    W1 -.->|Setup| DistManager
-    WN -.->|Setup| DistManager
-    DistService -.->|Reads| Config
-    ApiServer -.->|Reads| Config
-```
+```mermaid
+flowchart TB
+    Client[Client] -->|Send API Request| Router[FastAPI Router]
+
+    Router --> TaskRoutes[Task APIs]
+    Router --> FileRoutes[File APIs]
+    Router --> ServiceRoutes[Service Status APIs]
+
+    TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
+    TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
+    TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
+    TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
+    TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
+    TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]
+
+    FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]
+
+    ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
+    ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
+
+    subgraph Task Management
+        TaskManager[Task Manager]
+        TaskQueue[Task Queue]
+        TaskStatus[Task Status]
+        TaskResult[Task Result]
+        CreateTask --> TaskManager
+        CreateTaskForm --> TaskManager
+        TaskManager --> TaskQueue
+        TaskManager --> TaskStatus
+        TaskManager --> TaskResult
+    end
+
+    subgraph File Service
+        FileService[File Service]
+        DownloadImage[Download Image]
+        DownloadAudio[Download Audio]
+        SaveFile[Save File]
+        GetOutputPath[Get Output Path]
+        FileService --> DownloadImage
+        FileService --> DownloadAudio
+        FileService --> SaveFile
+        FileService --> GetOutputPath
+    end
+
+    subgraph Processing Thread
+        ProcessingThread[Processing Thread]
+        NextTask[Get Next Task]
+        ProcessTask[Process Single Task]
+        ProcessingThread --> NextTask
+        ProcessingThread --> ProcessTask
+    end
+
+    subgraph Video Generation Service
+        VideoService[Video Service]
+        GenerateVideo[Generate Video]
+        VideoService --> GenerateVideo
+    end
+
+    subgraph Distributed Inference Service
+        InferenceService[Distributed Inference Service]
+        SubmitTask[Submit Task]
+        Worker[Inference Worker Node]
+        ProcessRequest[Process Request]
+        RunPipeline[Run Inference Pipeline]
+        InferenceService --> SubmitTask
+        SubmitTask --> Worker
+        Worker --> ProcessRequest
+        ProcessRequest --> RunPipeline
+    end
+
+    %% ====== Connect Modules ======
+    TaskQueue --> ProcessingThread
+    ProcessTask --> VideoService
+    GenerateVideo --> InferenceService
+    GetTaskResult --> FileService
+    DownloadFile --> FileService
+    VideoService --> FileService
+    InferenceService --> TaskManager
+    TaskManager --> TaskStatus
+```
 
 ## Task Processing Flow
 
@@ -100,9 +106,9 @@ sequenceDiagram
     participant PT as Processing Thread
     participant VS as VideoService
     participant FS as FileService
-    participant DIS as Distributed<br/>Inference Service
-    participant W0 as Worker 0<br/>(Master)
-    participant W1 as Worker 1..N
+    participant DIS as DistributedInferenceService
+    participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
+    participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)
 
     C->>API: POST /v1/tasks<br/>(Create Task)
     API->>TM: create_task()
@@ -127,32 +133,54 @@ sequenceDiagram
     else Image is Base64
         VS->>FS: save_base64_image()
         FS-->>VS: image_path
-    else Image is Upload
-        VS->>FS: validate_file()
-        FS-->>VS: image_path
+    else Image is local path
+        VS->>VS: use existing path
     end
 
+    alt Audio is URL
+        VS->>FS: download_audio()
+        FS->>FS: HTTP download<br/>with retry
+        FS-->>VS: audio_path
+    else Audio is Base64
+        VS->>FS: save_base64_audio()
+        FS-->>VS: audio_path
+    else Audio is local path
+        VS->>VS: use existing path
+    end
+
-    VS->>DIS: submit_task(task_data)
-    DIS->>DIS: shared_data["current_task"] = task_data
-    DIS->>DIS: task_event.set()
+    VS->>DIS: submit_task_async(task_data)
+    DIS->>TIW0: process_request(task_data)
 
-    Note over W0,W1: Distributed Processing
-    W0->>W0: task_event.wait()
-    W0->>W0: Get task from shared_data
-    W0->>W1: broadcast_task_data()
+    Note over TIW0,TIW1: Torchrun-based Distributed Processing
+    TIW0->>TIW0: Check if processing
+    TIW0->>TIW0: Set processing = True
 
-    par Parallel Inference
-        W0->>W0: run_pipeline()
-    and
-        W1->>W1: run_pipeline()
-    end
+    alt Multi-GPU Mode (world_size > 1)
+        TIW0->>TIW1: broadcast_task_data()<br/>(via DistributedManager)
+        Note over TIW1: worker_loop() listens for broadcasts
+        TIW1->>TIW1: Receive task_data
+    end
+
+    par Parallel Inference across all ranks
+        TIW0->>TIW0: runner.set_inputs(task_data)
+        TIW0->>TIW0: runner.run_pipeline()
+    and
+        Note over TIW1: If world_size > 1
+        TIW1->>TIW1: runner.set_inputs(task_data)
+        TIW1->>TIW1: runner.run_pipeline()
+    end
+
+    Note over TIW0,TIW1: Synchronization
+    alt Multi-GPU Mode
+        TIW0->>TIW1: barrier() for sync
+        TIW1->>TIW0: barrier() response
+    end
 
-    W0->>W0: barrier() for sync
-    W0->>W0: shared_data["result"] = result
-    W0->>DIS: result_event.set()
-    DIS->>DIS: result_event.wait()
-    DIS-->>VS: TaskResponse
+    TIW0->>TIW0: Set processing = False
+    TIW0->>DIS: Return result (only rank 0)
+    TIW1->>TIW1: Return None (non-rank 0)
+    DIS->>VS: return result
 
     VS-->>PT: TaskResponse
     PT->>TM: complete_task()<br/>(status: COMPLETED)
````
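The updated README documents the task routes (POST /v1/tasks/, GET /v1/tasks/{task_id}/status, GET /v1/tasks/{task_id}/result). A hedged client sketch against those routes; the payload and response field names are illustrative assumptions, since the request schema is not part of this diff:

```python
import time

import requests

BASE = "http://localhost:8000"

# Payload fields are assumptions; the real schema lives in the task routes, which this diff does not show.
payload = {"prompt": "a cat playing piano", "image_url": "https://example.com/cat.png"}

task = requests.post(f"{BASE}/v1/tasks/", json=payload, timeout=10).json()
task_id = task["task_id"]  # assumed response field

while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status", timeout=5).json()
    # COMPLETED appears in the README flow; FAILED is an assumed terminal state.
    if status.get("status") in ("COMPLETED", "FAILED"):
        break
    time.sleep(2)

result = requests.get(f"{BASE}/v1/tasks/{task_id}/result", timeout=30)
print(result.status_code, result.headers.get("content-type"))
```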
lightx2v/api_server.py → lightx2v/server/__main__.py (renamed, mode 100755 → 100644)

```diff
-#!/usr/bin/env python
 import argparse
-import sys
-from pathlib import Path
-
-sys.path.insert(0, str(Path(__file__).parent.parent))
-
-from lightx2v.server.main import run_server
+
+from .main import run_server
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run LightX2V inference server")
+    parser = argparse.ArgumentParser(description="LightX2V Server")
+
+    # Model arguments
     parser.add_argument("--model_path", type=str, required=True, help="Path to model")
     parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
     parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
     parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
     parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")
+
+    # Server arguments
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
     parser.add_argument("--port", type=int, default=8000, help="Server port")
-    parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")
-    args = parser.parse_args()
+
+    # Parse any additional arguments that might be passed
+    args, unknown = parser.parse_known_args()
+
+    # Add any unknown arguments as attributes to args
+    # This allows flexibility for model-specific arguments
+    for i in range(0, len(unknown), 2):
+        if unknown[i].startswith("--"):
+            key = unknown[i][2:]
+            if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
+                value = unknown[i + 1]
+                setattr(args, key, value)
 
+    # Run the server
     run_server(args)
```
lightx2v/server/api.py

```diff
@@ -314,7 +314,6 @@ class ApiServer:
         return False
 
     def _ensure_processing_thread_running(self):
-        """Ensure the processing thread is running."""
         if self.processing_thread is None or not self.processing_thread.is_alive():
             self.stop_processing.clear()
             self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
@@ -322,9 +321,11 @@ class ApiServer:
         logger.info("Started task processing thread")
 
     def _task_processing_loop(self):
-        """Main loop that processes tasks from the queue one by one."""
         logger.info("Task processing loop started")
+        asyncio.set_event_loop(asyncio.new_event_loop())
+        loop = asyncio.get_event_loop()
+
         while not self.stop_processing.is_set():
             task_id = task_manager.get_next_pending_task()
@@ -335,12 +336,12 @@ class ApiServer:
             task_info = task_manager.get_task(task_id)
             if task_info and task_info.status == TaskStatus.PENDING:
                 logger.info(f"Processing task {task_id}")
-                self._process_single_task(task_info)
+                loop.run_until_complete(self._process_single_task(task_info))
 
+        loop.close()
         logger.info("Task processing loop stopped")
 
-    def _process_single_task(self, task_info: Any):
-        """Process a single task."""
+    async def _process_single_task(self, task_info: Any):
         assert self.video_service is not None, "Video service is not initialized"
         task_id = task_info.task_id
@@ -360,7 +361,7 @@ class ApiServer:
             task_manager.fail_task(task_id, "Task cancelled")
             return
 
-        result = asyncio.run(self.video_service.generate_video_with_stop_event(message, task_info.stop_event))
+        result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)
         if result:
             task_manager.complete_task(task_id, result.save_video_path)
```
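The api.py change gives the processing thread one long-lived event loop and awaits each task on it, instead of calling asyncio.run() per task. A standalone sketch of that pattern with generic names, not the ApiServer internals:

```python
import asyncio
import queue
import threading


async def handle(job: int) -> None:
    # Stand-in coroutine for the awaited video-generation call.
    await asyncio.sleep(0.1)
    print(f"finished job {job}")


def worker(jobs: queue.Queue) -> None:
    # One event loop for the lifetime of the thread, mirroring the new _task_processing_loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    while (job := jobs.get()) is not None:
        loop.run_until_complete(handle(job))
    loop.close()


job_queue: queue.Queue = queue.Queue()
t = threading.Thread(target=worker, args=(job_queue,), daemon=True)
t.start()
for i in range(3):
    job_queue.put(i)
job_queue.put(None)  # sentinel stops the thread
t.join()
```

Reusing a single loop avoids the per-call setup and teardown of asyncio.run() while keeping task processing strictly sequential.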
lightx2v/server/config.py

```diff
@@ -11,9 +11,6 @@ class ServerConfig:
     port: int = 8000
     max_queue_size: int = 10
-    master_addr: str = "127.0.0.1"
-    master_port_range: tuple = (29500, 29600)
     task_timeout: int = 300
     task_history_limit: int = 1000
@@ -42,31 +39,13 @@ class ServerConfig:
         except ValueError:
             logger.warning(f"Invalid max queue size: {env_queue_size}")
 
-        if env_master_addr := os.environ.get("MASTER_ADDR"):
-            config.master_addr = env_master_addr
+        # MASTER_ADDR is now managed by torchrun, no need to set manually
 
         if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
            config.cache_dir = env_cache_dir
 
         return config
 
-    def find_free_master_port(self) -> str:
-        import socket
-
-        for port in range(self.master_port_range[0], self.master_port_range[1]):
-            try:
-                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                    s.bind((self.master_addr, port))
-                    logger.info(f"Found free port for master: {port}")
-                    return str(port)
-            except OSError:
-                continue
-        raise RuntimeError(
-            f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
-            f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
-        )
-
     def validate(self) -> bool:
         valid = True
```
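With master_addr and master_port_range gone, ServerConfig keeps only environment-driven overrides such as LIGHTX2V_CACHE_DIR, applied with the walrus-operator pattern visible above. A small sketch of that idea; ExampleConfig and LIGHTX2V_PORT are made up for illustration, and only LIGHTX2V_CACHE_DIR appears in the diff:

```python
import os
from dataclasses import dataclass


@dataclass
class ExampleConfig:
    host: str = "0.0.0.0"
    port: int = 8000
    cache_dir: str = "/tmp/lightx2v_cache"  # placeholder default, not taken from the diff

    @classmethod
    def from_env(cls) -> "ExampleConfig":
        config = cls()
        # Override a field only when the corresponding variable is actually set.
        if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
            config.cache_dir = env_cache_dir
        if env_port := os.environ.get("LIGHTX2V_PORT"):  # hypothetical variable, for illustration
            config.port = int(env_port)
        return config


if __name__ == "__main__":
    print(ExampleConfig.from_env())
```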
lightx2v/server/distributed_utils.py

```diff
@@ -6,8 +6,6 @@ import torch
 import torch.distributed as dist
 from loguru import logger
 
-from .gpu_manager import gpu_manager
-
 
 class DistributedManager:
     def __init__(self):
@@ -18,29 +16,35 @@ class DistributedManager:
     CHUNK_SIZE = 1024 * 1024
 
-    def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
+    def init_process_group(self) -> bool:
+        """Initialize process group using torchrun environment variables"""
         try:
-            os.environ["RANK"] = str(rank)
-            os.environ["WORLD_SIZE"] = str(world_size)
-            os.environ["MASTER_ADDR"] = master_addr
-            os.environ["MASTER_PORT"] = master_port
-
-            backend = "nccl" if torch.cuda.is_available() else "gloo"
-            dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
-            logger.info(f"Setup backend: {backend}")
-
-            self.device = gpu_manager.set_device_for_rank(rank, world_size)
-
-            self.is_initialized = True
-            self.rank = rank
-            self.world_size = world_size
-            logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
+            # torchrun sets these environment variables automatically
+            self.rank = int(os.environ.get("LOCAL_RANK", 0))
+            self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+            if self.world_size > 1:
+                backend = "nccl" if torch.cuda.is_available() else "gloo"
+                # torchrun handles backend, init_method, rank, and world_size
+                # We just need to call init_process_group without parameters
+                dist.init_process_group(backend=backend, init_method="env://")
+                logger.info(f"Setup backend: {backend}")
+
+                # Set CUDA device for this rank
+                if torch.cuda.is_available():
+                    torch.cuda.set_device(self.rank)
+                    self.device = f"cuda:{self.rank}"
+                else:
+                    self.device = "cpu"
+            else:
+                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+            self.is_initialized = True
+            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
             return True
         except Exception as e:
-            logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
+            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
             return False
 
     def cleanup(self):
@@ -143,30 +147,3 @@ class DistributedManager:
         task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
         task_data = pickle.loads(task_bytes)
         return task_data
-
-
-class DistributedWorker:
-    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
-        self.rank = rank
-        self.world_size = world_size
-        self.master_addr = master_addr
-        self.master_port = master_port
-        self.dist_manager = DistributedManager()
-
-    def init(self) -> bool:
-        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)
-
-    def cleanup(self):
-        self.dist_manager.cleanup()
-
-    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
-        self.dist_manager.barrier()
-        if self.dist_manager.is_rank_zero():
-            result = {"task_id": task_id, "status": status, **kwargs}
-            result_queue.put(result)
-            logger.info(f"Task {task_id} {status}")
-
-
-def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
-    return DistributedWorker(rank, world_size, master_addr, master_port)
```
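The rewritten init_process_group leans on the RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT variables exported by torchrun and initializes with init_method="env://". A self-contained sketch of that initialization pattern, independent of the LightX2V classes:

```python
import os

import torch
import torch.distributed as dist


def init_from_torchrun() -> str:
    """Initialize the default process group from torchrun-provided environment variables."""
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        # RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are read from the environment.
        dist.init_process_group(backend=backend, init_method="env://")

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        return f"cuda:{local_rank}"
    return "cpu"


if __name__ == "__main__":
    device = init_from_torchrun()
    print(f"rank {os.environ.get('RANK', 0)} using {device}")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
```

Run under `torchrun --nproc_per_node <N> script.py`, each process picks up its own LOCAL_RANK; run directly with python, it degrades to a single-process setup.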
lightx2v/server/gpu_manager.py (deleted, mode 100644 → 0)

```python
import os
from typing import List, Optional, Tuple

import torch
from loguru import logger


class GPUManager:
    def __init__(self):
        self.available_gpus = self._detect_gpus()
        self.gpu_count = len(self.available_gpus)

    def _detect_gpus(self) -> List[int]:
        if not torch.cuda.is_available():
            logger.warning("No CUDA devices available, will use CPU")
            return []

        gpu_count = torch.cuda.device_count()
        cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")

        if cuda_visible:
            try:
                visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
                logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
                return list(range(len(visible_devices)))
            except ValueError:
                logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")

        available_gpus = list(range(gpu_count))
        logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
        return available_gpus

    def get_device_for_rank(self, rank: int, world_size: int) -> str:
        if not self.available_gpus:
            logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
            return "cpu"

        if self.gpu_count == 1:
            device = f"cuda:{self.available_gpus[0]}"
            logger.info(f"Rank {rank}: Using single GPU {device}")
            return device

        if self.gpu_count >= world_size:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
            return device
        else:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
            return device

    def set_device_for_rank(self, rank: int, world_size: int) -> str:
        device = self.get_device_for_rank(rank, world_size)
        if device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            torch.cuda.set_device(gpu_id)
            logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
        return device

    def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
        if not torch.cuda.is_available():
            return (0, 0)

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
        else:
            gpu_id = torch.cuda.current_device()

        try:
            used = torch.cuda.memory_allocated(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory
            return (used, total)
        except Exception as e:
            logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
            return (0, 0)

    def clear_cache(self, device: Optional[str] = None):
        if not torch.cuda.is_available():
            return

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        else:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        logger.info(f"GPU cache cleared for device: {device or 'current'}")

    @staticmethod
    def get_optimal_world_size(requested_world_size: int) -> int:
        if not torch.cuda.is_available():
            logger.warning("No GPUs available, using single process")
            return 1

        gpu_count = torch.cuda.device_count()

        if requested_world_size <= 0:
            optimal_size = gpu_count
            logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
        elif requested_world_size > gpu_count:
            logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
            optimal_size = requested_world_size
        else:
            optimal_size = requested_world_size

        return optimal_size


gpu_manager = GPUManager()
```
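With gpu_manager.py removed, per-process GPU assignment now comes from torchrun's LOCAL_RANK combined with the CUDA_VISIBLE_DEVICES value set in the launch scripts, rather than an explicit allocator. A minimal sketch of that replacement logic:

```python
import os

import torch


def select_device() -> torch.device:
    # torchrun gives each local process a distinct LOCAL_RANK; CUDA_VISIBLE_DEVICES
    # (exported in the launch scripts) controls which physical GPUs those indices map to.
    if not torch.cuda.is_available():
        return torch.device("cpu")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")


if __name__ == "__main__":
    print(select_device())
```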
lightx2v/server/main.py

```diff
+import os
 import sys
 from pathlib import Path
@@ -10,9 +11,14 @@ from .service import DistributedInferenceService
 
 def run_server(args):
+    """Run server with torchrun support"""
     inference_service = None
     try:
-        logger.info("Starting LightX2V server...")
+        # Get rank from environment (set by torchrun)
+        rank = int(os.environ.get("LOCAL_RANK", 0))
+        world_size = int(os.environ.get("WORLD_SIZE", 1))
+        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")
+
         if hasattr(args, "host") and args.host:
             server_config.host = args.host
@@ -22,28 +28,37 @@ def run_server(args):
         if not server_config.validate():
             raise RuntimeError("Invalid server configuration")
 
+        # Initialize inference service
         inference_service = DistributedInferenceService()
         if not inference_service.start_distributed_inference(args):
             raise RuntimeError("Failed to start distributed inference service")
-        logger.info("Inference service started successfully")
+        logger.info(f"Rank {rank}: Inference service started successfully")
 
-        cache_dir = Path(server_config.cache_dir)
-        cache_dir.mkdir(parents=True, exist_ok=True)
-        api_server = ApiServer(max_queue_size=server_config.max_queue_size)
-        api_server.initialize_services(cache_dir, inference_service)
-        app = api_server.get_app()
-
-        logger.info(f"Starting server on {server_config.host}:{server_config.port}")
-        uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
+        if rank == 0:
+            # Only rank 0 runs the FastAPI server
+            cache_dir = Path(server_config.cache_dir)
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            api_server = ApiServer(max_queue_size=server_config.max_queue_size)
+            api_server.initialize_services(cache_dir, inference_service)
+            app = api_server.get_app()
+
+            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
+            uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
+        else:
+            # Non-rank-0 processes run the worker loop
+            logger.info(f"Rank {rank}: Starting worker loop")
+            import asyncio
+
+            asyncio.run(inference_service.run_worker_loop())
     except KeyboardInterrupt:
-        logger.info("Server interrupted by user")
+        logger.info(f"Server rank {rank} interrupted by user")
         if inference_service:
             inference_service.stop_distributed_inference()
     except Exception as e:
-        logger.error(f"Server failed: {e}")
+        logger.error(f"Server rank {rank} failed: {e}")
         if inference_service:
             inference_service.stop_distributed_inference()
         sys.exit(1)
```
lightx2v/server/service.py (+149, -239): this diff is collapsed on the source page and is not shown here.
scripts/server/start_multi_servers.sh

```diff
 #!/bin/bash
 # set path and first
 lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
 model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8
 
-export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1,2,3
 
 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh
 
 # Start multiple servers
-python -m lightx2v.api_multi_servers \
-    --num_gpus $num_gpus \
-    --start_port 8000 \
-    --model_cls wan2.1_distill \
+torchrun --nproc_per_node 4 -m lightx2v.server \
+    --model_cls seko_talk \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json
+    --config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
+    --port 8000
```
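Callers that previously drove api_multi_servers.py from Python can wrap the new torchrun invocation with subprocess instead; a hedged sketch, with the model and config paths as placeholders:

```python
import subprocess

# Placeholders: point model_path and config_json at your own checkout.
cmd = [
    "torchrun",
    "--nproc_per_node", "4",
    "-m", "lightx2v.server",
    "--model_cls", "seko_talk",
    "--task", "i2v",
    "--model_path", "/path/to/model",
    "--config_json", "/path/to/config.json",
    "--port", "8000",
]

# One process group serves a single API endpoint on port 8000; the GPUs are ranks inside it.
proc = subprocess.Popen(cmd)
try:
    proc.wait()
except KeyboardInterrupt:
    proc.terminate()
```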
scripts/server/start_server.sh

```diff
 #!/bin/bash
 # set path and first
 lightx2v_path=/path/to/Lightx2v
 model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8
 
 export CUDA_VISIBLE_DEVICES=0
@@ -11,12 +11,11 @@ source ${lightx2v_path}/scripts/base/base.sh
 
 # Start API server with distributed inference service
-python -m lightx2v.api_server \
-    --model_cls wan2.1_distill \
+python -m lightx2v.server \
+    --model_cls seko_talk \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json \
-    --port 8000 \
-    --nproc_per_node 1
+    --config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
+    --port 8000
 
 echo "Service stopped"
```