xuwx1 / LightX2V · Commit 3c778aee

Gp/dev (#310)

Authored Sep 16, 2025 by PengGao; committed by GitHub on Sep 16, 2025
Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
Parent: 32fd1c52
Showing 11 changed files with 368 additions and 713 deletions (+368 −713):

    lightx2v/api_multi_servers.py           +0    −172
    lightx2v/server/README.md               +114  −86
    lightx2v/server/__main__.py             +34   −0
    lightx2v/server/api.py                  +7    −6
    lightx2v/server/config.py               +1    −22
    lightx2v/server/distributed_utils.py    +23   −46
    lightx2v/server/gpu_manager.py          +0    −116
    lightx2v/server/main.py                 +26   −11
    lightx2v/server/service.py              +149  −239
    scripts/server/start_multi_servers.sh   +8    −8
    scripts/server/start_server.sh          +6    −7
lightx2v/api_multi_servers.py · deleted (100644 → 0) · View file @ 32fd1c52

```python
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import requests
from loguru import logger


@dataclass
class ServerConfig:
    port: int
    gpu_id: int
    model_cls: str
    task: str
    model_path: str
    config_json: str


def get_node_ip() -> str:
    """Get the IP address of the current node"""
    try:
        # Create a UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Connect to an external address (no actual connection needed)
        s.connect(("8.8.8.8", 80))
        # Get local IP
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception as e:
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"


def is_port_in_use(port: int) -> bool:
    """Check if a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port"""
    port = start_port
    while port < start_port + 1000:  # Try up to 1000 ports
        if not is_port_in_use(port):
            return port
        port += 1
    return None


def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance"""
    try:
        # Set GPU
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)

        # Start server
        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )

        # Wait for server to start, up to 600 seconds
        node_ip = get_node_ip()
        service_url = f"http://{node_ip}:{config.port}/v1/service/status"

        # Check once per second, up to 600 times
        for _ in range(600):
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, f"http://{node_ip}:{config.port}"
            except (requests.RequestException, ConnectionError):
                pass
            time.sleep(1)

        # If timeout, terminate the process
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        process.terminate()
        return None
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()

    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port

    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue
        config = ServerConfig(
            port=port,
            gpu_id=gpu,
            model_cls=args.model_cls,
            task=args.task,
            model_path=args.model_path,
            config_json=args.config_json,
        )
        server_configs.append(config)
        current_port = port + 1

    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")

    # Print all server URLs
    logger.info("\nAll server URLs:")
    for url in urls:
        logger.info(url)

    # Print node information
    node_ip = get_node_ip()
    logger.info(f"\nCurrent node IP: {node_ip}")
    logger.info(f"Number of servers started: {len(urls)}")

    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        for process in processes:
            process.terminate()


if __name__ == "__main__":
    main()
```
lightx2v/server/README.md · View file @ 3c778aee

@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI

### System Architecture

The previous multiprocessing-based architecture diagram is removed:

```mermaid
graph TB
    subgraph "Client Layer"
        Client[HTTP Client]
    end

    subgraph "API Layer"
        FastAPI[FastAPI Application]
        ApiServer[ApiServer]
        Router1[Tasks Router<br/>/v1/tasks]
        Router2[Files Router<br/>/v1/files]
        Router3[Service Router<br/>/v1/service]
    end

    subgraph "Service Layer"
        TaskManager[TaskManager<br/>Thread-safe Task Queue]
        FileService[FileService<br/>File I/O & Downloads]
        VideoService[VideoGenerationService]
    end

    subgraph "Processing Layer"
        Thread[Processing Thread<br/>Sequential Task Loop]
    end

    subgraph "Distributed Inference Layer"
        DistService[DistributedInferenceService]
        SharedData[(Shared Data<br/>mp.Manager.dict)]
        TaskEvent[Task Event<br/>mp.Manager.Event]
        ResultEvent[Result Event<br/>mp.Manager.Event]

        subgraph "Worker Processes"
            W0[Worker 0<br/>Master/Rank 0]
            W1[Worker 1<br/>Rank 1]
            WN[Worker N<br/>Rank N]
        end
    end

    subgraph "Resource Management"
        GPUManager[GPUManager<br/>GPU Detection & Allocation]
        DistManager[DistributedManager<br/>PyTorch Distributed]
        Config[ServerConfig<br/>Configuration]
    end

    Client -->|HTTP Request| FastAPI
    FastAPI --> ApiServer
    ApiServer --> Router1
    ApiServer --> Router2
    ApiServer --> Router3

    Router1 -->|Create/Manage Tasks| TaskManager
    Router1 -->|Process Tasks| Thread
    Router2 -->|File Operations| FileService
    Router3 -->|Service Status| TaskManager

    Thread -->|Get Pending Tasks| TaskManager
    Thread -->|Generate Video| VideoService

    VideoService -->|Download Images| FileService
    VideoService -->|Submit Task| DistService

    DistService -->|Update| SharedData
    DistService -->|Signal| TaskEvent
    TaskEvent -->|Notify| W0
    W0 -->|Broadcast| W1
    W0 -->|Broadcast| WN

    W0 -->|Update Result| SharedData
    W0 -->|Signal| ResultEvent
    ResultEvent -->|Notify| DistService

    W0 -.->|Uses| GPUManager
    W1 -.->|Uses| GPUManager
    WN -.->|Uses| GPUManager

    W0 -.->|Setup| DistManager
    W1 -.->|Setup| DistManager
    WN -.->|Setup| DistManager

    DistService -.->|Reads| Config
    ApiServer -.->|Reads| Config
```

and replaced by the new torchrun-based diagram:

```mermaid
flowchart TB
    Client[Client] -->|Send API Request| Router[FastAPI Router]

    subgraph API Layer
        Router --> TaskRoutes[Task APIs]
        Router --> FileRoutes[File APIs]
        Router --> ServiceRoutes[Service Status APIs]

        TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
        TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
        TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
        TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
        TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
        TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]

        FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]

        ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
        ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
    end

    subgraph Task Management
        TaskManager[Task Manager]
        TaskQueue[Task Queue]
        TaskStatus[Task Status]
        TaskResult[Task Result]

        CreateTask --> TaskManager
        CreateTaskForm --> TaskManager
        TaskManager --> TaskQueue
        TaskManager --> TaskStatus
        TaskManager --> TaskResult
    end

    subgraph File Service
        FileService[File Service]
        DownloadImage[Download Image]
        DownloadAudio[Download Audio]
        SaveFile[Save File]
        GetOutputPath[Get Output Path]

        FileService --> DownloadImage
        FileService --> DownloadAudio
        FileService --> SaveFile
        FileService --> GetOutputPath
    end

    subgraph Processing Thread
        ProcessingThread[Processing Thread]
        NextTask[Get Next Task]
        ProcessTask[Process Single Task]

        ProcessingThread --> NextTask
        ProcessingThread --> ProcessTask
    end

    subgraph Video Generation Service
        VideoService[Video Service]
        GenerateVideo[Generate Video]

        VideoService --> GenerateVideo
    end

    subgraph Distributed Inference Service
        InferenceService[Distributed Inference Service]
        SubmitTask[Submit Task]
        Worker[Inference Worker Node]
        ProcessRequest[Process Request]
        RunPipeline[Run Inference Pipeline]

        InferenceService --> SubmitTask
        SubmitTask --> Worker
        Worker --> ProcessRequest
        ProcessRequest --> RunPipeline
    end

    %% ====== Connect Modules ======
    TaskQueue --> ProcessingThread
    ProcessTask --> VideoService
    GenerateVideo --> InferenceService
    GetTaskResult --> FileService
    DownloadFile --> FileService
    VideoService --> FileService
    InferenceService --> TaskManager
    TaskManager --> TaskStatus
```
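The routes named in the diagram map onto a straightforward HTTP client flow. Below is a minimal client sketch against those routes; the request and response field names (`prompt`, `image_url`, `task_id`, `status`) are assumptions for illustration, not taken from the actual schema.

```python
# Minimal client sketch for the routes named in the diagram above.
# Payload/response field names are assumptions, not the real schema.
import time

import requests

BASE = "http://localhost:8000"  # assumed host and port

# Create a task (POST /v1/tasks/)
resp = requests.post(
    f"{BASE}/v1/tasks/",
    json={"prompt": "a person talking", "image_url": "https://example.com/face.png"},
)
task_id = resp.json()["task_id"]  # assumed response field

# Poll status (GET /v1/tasks/{id}/status) until the task finishes
while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status").json()
    if str(status.get("status", "")).upper() in ("COMPLETED", "FAILED"):
        break
    time.sleep(2)

# Fetch the result (GET /v1/tasks/{id}/result)
result = requests.get(f"{BASE}/v1/tasks/{task_id}/result")
print(result.status_code, result.headers.get("content-type"))
```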
## Task Processing Flow
...
...
@@ -100,9 +106,9 @@ sequenceDiagram

```diff
     participant PT as Processing Thread
     participant VS as VideoService
     participant FS as FileService
-    participant DIS as Distributed<br/>Inference Service
-    participant W0 as Worker 0<br/>(Master)
-    participant W1 as Worker 1..N
+    participant DIS as DistributedInferenceService
+    participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
+    participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)

     C->>API: POST /v1/tasks<br/>(Create Task)
     API->>TM: create_task()
```

...
...

@@ -127,32 +133,54 @@ sequenceDiagram

```diff
     else Image is Base64
         VS->>FS: save_base64_image()
         FS-->>VS: image_path
+    else Image is Upload
+        VS->>FS: validate_file()
+        FS-->>VS: image_path
+    else Image is local path
+        VS->>VS: use existing path
     end

+    alt Audio is URL
+        VS->>FS: download_audio()
+        FS->>FS: HTTP download<br/>with retry
+        FS-->>VS: audio_path
+    else Audio is Base64
+        VS->>FS: save_base64_audio()
+        FS-->>VS: audio_path
+    else Audio is local path
+        VS->>VS: use existing path
+    end

-    VS->>DIS: submit_task(task_data)
-    DIS->>DIS: shared_data["current_task"] = task_data
-    DIS->>DIS: task_event.set()
+    VS->>DIS: submit_task_async(task_data)
+    DIS->>TIW0: process_request(task_data)

-    Note over W0,W1: Distributed Processing
-    W0->>W0: task_event.wait()
-    W0->>W0: Get task from shared_data
-    W0->>W1: broadcast_task_data()
+    Note over TIW0,TIW1: Torchrun-based Distributed Processing
+    TIW0->>TIW0: Check if processing
+    TIW0->>TIW0: Set processing = True

-    par Parallel Inference
-        W0->>W0: run_pipeline()
+    alt Multi-GPU Mode (world_size > 1)
+        TIW0->>TIW1: broadcast_task_data()<br/>(via DistributedManager)
+        Note over TIW1: worker_loop() listens for broadcasts
+        TIW1->>TIW1: Receive task_data
+    end

+    par Parallel Inference across all ranks
+        TIW0->>TIW0: runner.set_inputs(task_data)
+        TIW0->>TIW0: runner.run_pipeline()
     and
-        W1->>W1: run_pipeline()
+        Note over TIW1: If world_size > 1
+        TIW1->>TIW1: runner.set_inputs(task_data)
+        TIW1->>TIW1: runner.run_pipeline()
     end

+    Note over TIW0,TIW1: Synchronization
+    alt Multi-GPU Mode
+        TIW0->>TIW1: barrier() for sync
+        TIW1->>TIW0: barrier() response
+    end
-    W0->>W0: barrier() for sync
-    W0->>W0: shared_data["result"] = result
-    W0->>DIS: result_event.set()
+    TIW0->>TIW0: Set processing = False
+    TIW0->>DIS: Return result (only rank 0)
+    TIW1->>TIW1: Return None (non-rank 0)

-    DIS->>DIS: result_event.wait()
-    DIS->>VS: return result
+    DIS-->>VS: TaskResponse
     VS-->>PT: TaskResponse
     PT->>TM: complete_task()<br/>(status: COMPLETED)
```

...
...
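The flow above (rank 0 accepts the request, broadcasts the task data to the other ranks, every rank runs the pipeline, and a barrier synchronizes before rank 0 reports the result) follows the standard torchrun pattern. A standalone sketch of that pattern is shown below; it is not the project's service code, and `run_pipeline` is a placeholder. It would be launched with something like `torchrun --nproc_per_node 2 sketch.py`.

```python
# Standalone sketch of the broadcast / barrier pattern described above.
# `run_pipeline` is a placeholder; the real project drives its runner classes.
import os

import torch
import torch.distributed as dist


def run_pipeline(task):
    # Stand-in for the actual inference pipeline.
    return {"task_id": task["task_id"], "status": "completed"}


def main():
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)

    # Rank 0 owns the task; other ranks receive it via broadcast.
    task = [{"task_id": "demo", "prompt": "hello"}] if rank == 0 else [None]
    if world_size > 1:
        dist.broadcast_object_list(task, src=0)

    result = run_pipeline(task[0])  # every rank runs the same pipeline

    if world_size > 1:
        dist.barrier()  # synchronize before reporting
    if rank == 0:
        print(result)  # only rank 0 reports the result

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```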
lightx2v/api_server.py → lightx2v/server/__main__.py · 100755 → 100644 · View file @ 3c778aee

```diff
 #!/usr/bin/env python
 import argparse
-import sys
-from pathlib import Path
-
-sys.path.insert(0, str(Path(__file__).parent.parent))
-
-from lightx2v.server.main import run_server
+
+from .main import run_server


 def main():
-    parser = argparse.ArgumentParser(description="Run LightX2V inference server")
+    parser = argparse.ArgumentParser(description="LightX2V Server")

     # Model arguments
     parser.add_argument("--model_path", type=str, required=True, help="Path to model")
     parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
     parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
     parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
     parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")

     # Server arguments
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
     parser.add_argument("--port", type=int, default=8000, help="Server port")
+    parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")

-    args = parser.parse_args()
+    # Parse any additional arguments that might be passed
+    args, unknown = parser.parse_known_args()
+
+    # Add any unknown arguments as attributes to args
+    # This allows flexibility for model-specific arguments
+    for i in range(0, len(unknown), 2):
+        if unknown[i].startswith("--"):
+            key = unknown[i][2:]
+            if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
+                value = unknown[i + 1]
+                setattr(args, key, value)

     # Run the server
     run_server(args)
```
...
...
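The `parse_known_args` loop in the new `__main__.py` forwards any extra `--key value` pairs onto `args` as string attributes. A small isolated demo of the same pattern follows; the extra flags `--seed` and `--fps` are invented for illustration only.

```python
# Isolated demo of the parse_known_args pattern used in __main__.py.
# The flags --seed and --fps are hypothetical extras, not real server options.
import argparse

parser = argparse.ArgumentParser(description="LightX2V Server")
parser.add_argument("--model_path", type=str, required=True)

args, unknown = parser.parse_known_args(
    ["--model_path", "/models/x", "--seed", "42", "--fps", "16"]
)

# Attach unknown --key value pairs as attributes, mirroring __main__.py.
# Note the values stay strings; downstream code must convert types itself.
for i in range(0, len(unknown), 2):
    if unknown[i].startswith("--"):
        key = unknown[i][2:]
        if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
            setattr(args, key, unknown[i + 1])

print(args.model_path, args.seed, args.fps)  # /models/x 42 16
```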
lightx2v/server/api.py · View file @ 3c778aee

...
...

```diff
@@ -314,7 +314,6 @@ class ApiServer:
         return False

     def _ensure_processing_thread_running(self):
-        """Ensure the processing thread is running."""
         if self.processing_thread is None or not self.processing_thread.is_alive():
             self.stop_processing.clear()
             self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
...
@@ -322,9 +321,11 @@ class ApiServer:
             logger.info("Started task processing thread")

     def _task_processing_loop(self):
-        """Main loop that processes tasks from the queue one by one."""
         logger.info("Task processing loop started")
+        asyncio.set_event_loop(asyncio.new_event_loop())
+        loop = asyncio.get_event_loop()
         while not self.stop_processing.is_set():
             task_id = task_manager.get_next_pending_task()
...
@@ -335,12 +336,12 @@ class ApiServer:
             task_info = task_manager.get_task(task_id)
             if task_info and task_info.status == TaskStatus.PENDING:
                 logger.info(f"Processing task {task_id}")
-                self._process_single_task(task_info)
+                loop.run_until_complete(self._process_single_task(task_info))

+        loop.close()
         logger.info("Task processing loop stopped")

-    def _process_single_task(self, task_info: Any):
-        """Process a single task."""
+    async def _process_single_task(self, task_info: Any):
         assert self.video_service is not None, "Video service is not initialized"
         task_id = task_info.task_id
...
@@ -360,7 +361,7 @@ class ApiServer:
                 task_manager.fail_task(task_id, "Task cancelled")
                 return

-            result = asyncio.run(self.video_service.generate_video_with_stop_event(message, task_info.stop_event))
+            result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)

             if result:
                 task_manager.complete_task(task_id, result.save_video_path)
```
...
...
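The api.py change replaces a per-task `asyncio.run()` with a single long-lived event loop owned by the processing thread. Below is a minimal sketch of that thread-owned-loop pattern, independent of the `ApiServer` class; the `process` coroutine merely stands in for `generate_video_with_stop_event`.

```python
# Minimal sketch of a worker thread that owns its own asyncio event loop,
# mirroring the _task_processing_loop change above (not the actual class).
import asyncio
import queue
import threading
import time

tasks: "queue.Queue[str]" = queue.Queue()


async def process(task_id: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for the real async video generation
    return f"{task_id}: done"


def processing_loop(stop: threading.Event):
    # Create and bind one loop to this thread, reuse it for every task.
    asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()
    while not stop.is_set():
        try:
            task_id = tasks.get(timeout=0.5)
        except queue.Empty:
            continue
        print(loop.run_until_complete(process(task_id)))
    loop.close()


stop = threading.Event()
worker = threading.Thread(target=processing_loop, args=(stop,), daemon=True)
worker.start()
tasks.put("task-1")
time.sleep(1)
stop.set()
worker.join()
```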
lightx2v/server/config.py · View file @ 3c778aee

...
...

```diff
@@ -11,9 +11,6 @@ class ServerConfig:
     port: int = 8000
     max_queue_size: int = 10
-    master_addr: str = "127.0.0.1"
-    master_port_range: tuple = (29500, 29600)
     task_timeout: int = 300
     task_history_limit: int = 1000
...
@@ -42,31 +39,13 @@ class ServerConfig:
             except ValueError:
                 logger.warning(f"Invalid max queue size: {env_queue_size}")

-        if env_master_addr := os.environ.get("MASTER_ADDR"):
-            config.master_addr = env_master_addr
+        # MASTER_ADDR is now managed by torchrun, no need to set manually

         if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
             config.cache_dir = env_cache_dir

         return config

-    def find_free_master_port(self) -> str:
-        import socket
-
-        for port in range(self.master_port_range[0], self.master_port_range[1]):
-            try:
-                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                    s.bind((self.master_addr, port))
-                    logger.info(f"Found free port for master: {port}")
-                    return str(port)
-            except OSError:
-                continue
-        raise RuntimeError(
-            f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
-            f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
-        )
-
     def validate(self) -> bool:
         valid = True
```
...
...
lightx2v/server/distributed_utils.py · View file @ 3c778aee

...
...

```diff
@@ -6,8 +6,6 @@ import torch
 import torch.distributed as dist
 from loguru import logger

-from .gpu_manager import gpu_manager
-

 class DistributedManager:
     def __init__(self):
...
@@ -18,29 +16,35 @@ class DistributedManager:
     CHUNK_SIZE = 1024 * 1024

-    def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
+    def init_process_group(self) -> bool:
+        """Initialize process group using torchrun environment variables"""
         try:
-            os.environ["RANK"] = str(rank)
-            os.environ["WORLD_SIZE"] = str(world_size)
-            os.environ["MASTER_ADDR"] = master_addr
-            os.environ["MASTER_PORT"] = master_port
-
-            backend = "nccl" if torch.cuda.is_available() else "gloo"
-            dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
-            logger.info(f"Setup backend: {backend}")
-
-            self.device = gpu_manager.set_device_for_rank(rank, world_size)
+            # torchrun sets these environment variables automatically
+            self.rank = int(os.environ.get("LOCAL_RANK", 0))
+            self.world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+            if self.world_size > 1:
+                # torchrun handles backend, init_method, rank, and world_size
+                # We just need to call init_process_group without parameters
+                backend = "nccl" if torch.cuda.is_available() else "gloo"
+                dist.init_process_group(backend=backend, init_method="env://")
+                logger.info(f"Setup backend: {backend}")
+
+                # Set CUDA device for this rank
+                if torch.cuda.is_available():
+                    torch.cuda.set_device(self.rank)
+                    self.device = f"cuda:{self.rank}"
+                else:
+                    self.device = "cpu"
+            else:
+                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

             self.is_initialized = True
-            self.rank = rank
-            self.world_size = world_size
-            logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
+            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
             return True
         except Exception as e:
-            logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
+            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
             return False

     def cleanup(self):
...
@@ -143,30 +147,3 @@ class DistributedManager:
         task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
         task_data = pickle.loads(task_bytes)
         return task_data
-
-
-class DistributedWorker:
-    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
-        self.rank = rank
-        self.world_size = world_size
-        self.master_addr = master_addr
-        self.master_port = master_port
-        self.dist_manager = DistributedManager()
-
-    def init(self) -> bool:
-        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)
-
-    def cleanup(self):
-        self.dist_manager.cleanup()
-
-    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
-        self.dist_manager.barrier()
-        if self.dist_manager.is_rank_zero():
-            result = {"task_id": task_id, "status": status, **kwargs}
-            result_queue.put(result)
-            logger.info(f"Task {task_id} {status}")
-
-
-def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
-    return DistributedWorker(rank, world_size, master_addr, master_port)
```
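Under torchrun the rendezvous variables (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `MASTER_ADDR`, `MASTER_PORT`) are already exported, so initialization reduces to `init_method="env://"`. The following is a minimal torchrun-launched script, not project code, illustrating what the new `init_process_group` relies on.

```python
# Minimal torchrun-launched script showing the env:// initialization
# that the new DistributedManager relies on (illustrative sketch only).
import os

import torch
import torch.distributed as dist

local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

if world_size > 1:
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    # torchrun has already exported RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT.
    dist.init_process_group(backend=backend, init_method="env://")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

print(
    f"rank={os.environ.get('RANK')} local_rank={local_rank} world_size={world_size} "
    f"master={os.environ.get('MASTER_ADDR')}:{os.environ.get('MASTER_PORT')}"
)

if world_size > 1:
    dist.barrier()
    dist.destroy_process_group()
```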
lightx2v/server/gpu_manager.py · deleted (100644 → 0) · View file @ 32fd1c52

```python
import os
from typing import List, Optional, Tuple

import torch
from loguru import logger


class GPUManager:
    def __init__(self):
        self.available_gpus = self._detect_gpus()
        self.gpu_count = len(self.available_gpus)

    def _detect_gpus(self) -> List[int]:
        if not torch.cuda.is_available():
            logger.warning("No CUDA devices available, will use CPU")
            return []

        gpu_count = torch.cuda.device_count()
        cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")

        if cuda_visible:
            try:
                visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
                logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
                return list(range(len(visible_devices)))
            except ValueError:
                logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")

        available_gpus = list(range(gpu_count))
        logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
        return available_gpus

    def get_device_for_rank(self, rank: int, world_size: int) -> str:
        if not self.available_gpus:
            logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
            return "cpu"

        if self.gpu_count == 1:
            device = f"cuda:{self.available_gpus[0]}"
            logger.info(f"Rank {rank}: Using single GPU {device}")
            return device

        if self.gpu_count >= world_size:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
            return device
        else:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
            return device

    def set_device_for_rank(self, rank: int, world_size: int) -> str:
        device = self.get_device_for_rank(rank, world_size)
        if device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            torch.cuda.set_device(gpu_id)
            logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
        return device

    def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
        if not torch.cuda.is_available():
            return (0, 0)

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
        else:
            gpu_id = torch.cuda.current_device()

        try:
            used = torch.cuda.memory_allocated(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory
            return (used, total)
        except Exception as e:
            logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
            return (0, 0)

    def clear_cache(self, device: Optional[str] = None):
        if not torch.cuda.is_available():
            return

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        else:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        logger.info(f"GPU cache cleared for device: {device or 'current'}")

    @staticmethod
    def get_optimal_world_size(requested_world_size: int) -> int:
        if not torch.cuda.is_available():
            logger.warning("No GPUs available, using single process")
            return 1

        gpu_count = torch.cuda.device_count()

        if requested_world_size <= 0:
            optimal_size = gpu_count
            logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
        elif requested_world_size > gpu_count:
            logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
            optimal_size = requested_world_size
        else:
            optimal_size = requested_world_size

        return optimal_size


gpu_manager = GPUManager()
```
lightx2v/server/main.py · View file @ 3c778aee

```diff
+import os
 import sys
 from pathlib import Path
...
@@ -10,9 +11,14 @@ from .service import DistributedInferenceService

 def run_server(args):
+    """Run server with torchrun support"""
     inference_service = None

     try:
-        logger.info("Starting LightX2V server...")
+        # Get rank from environment (set by torchrun)
+        rank = int(os.environ.get("LOCAL_RANK", 0))
+        world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")

         if hasattr(args, "host") and args.host:
             server_config.host = args.host
...
@@ -22,28 +28,37 @@ def run_server(args):
         if not server_config.validate():
             raise RuntimeError("Invalid server configuration")

-        # Initialize inference service
         inference_service = DistributedInferenceService()
         if not inference_service.start_distributed_inference(args):
             raise RuntimeError("Failed to start distributed inference service")
-        logger.info("Inference service started successfully")
+        logger.info(f"Rank {rank}: Inference service started successfully")

-        cache_dir = Path(server_config.cache_dir)
-        cache_dir.mkdir(parents=True, exist_ok=True)
+        if rank == 0:
+            # Only rank 0 runs the FastAPI server
+            cache_dir = Path(server_config.cache_dir)
+            cache_dir.mkdir(parents=True, exist_ok=True)

-        api_server = ApiServer(max_queue_size=server_config.max_queue_size)
-        api_server.initialize_services(cache_dir, inference_service)
+            api_server = ApiServer(max_queue_size=server_config.max_queue_size)
+            api_server.initialize_services(cache_dir, inference_service)

-        app = api_server.get_app()
+            app = api_server.get_app()

-        logger.info(f"Starting server on {server_config.host}:{server_config.port}")
-        uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
+            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
+            uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
+        else:
+            # Non-rank-0 processes run the worker loop
+            logger.info(f"Rank {rank}: Starting worker loop")
+            import asyncio
+
+            asyncio.run(inference_service.run_worker_loop())

     except KeyboardInterrupt:
-        logger.info("Server interrupted by user")
+        logger.info(f"Server rank {rank} interrupted by user")
         if inference_service:
             inference_service.stop_distributed_inference()
     except Exception as e:
-        logger.error(f"Server failed: {e}")
+        logger.error(f"Server rank {rank} failed: {e}")
         if inference_service:
             inference_service.stop_distributed_inference()
         sys.exit(1)
```
lightx2v/server/service.py · View file @ 3c778aee

This diff is collapsed in the GitLab view and is not expanded here (+149 −239 per the summary above).
scripts/server/start_multi_servers.sh · View file @ 3c778aee

```diff
 #!/bin/bash

 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
+model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8

-export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export CUDA_VISIBLE_DEVICES=0,1,2,3

 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh

 # Start multiple servers
-python -m lightx2v.api_multi_servers \
-    --num_gpus $num_gpus \
-    --start_port 8000 \
-    --model_cls wan2.1_distill \
+torchrun --nproc_per_node 4 -m lightx2v.server \
+    --model_cls seko_talk \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json
+    --config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
+    --port 8000
```
scripts/server/start_server.sh · View file @ 3c778aee

```diff
 #!/bin/bash

 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path=/path/to/Lightx2v
+model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8

 export CUDA_VISIBLE_DEVICES=0
...
@@ -11,12 +11,11 @@ source ${lightx2v_path}/scripts/base/base.sh

 # Start API server with distributed inference service
-python -m lightx2v.api_server \
-    --model_cls wan2.1_distill \
+python -m lightx2v.server \
+    --model_cls seko_talk \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json \
-    --port 8000 \
-    --nproc_per_node 1
+    --config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
+    --port 8000

 echo "Service stopped"
```