xuwx1 / LightX2V · Commits · 514ea716

Commit 514ea716, authored Jul 16, 2025 by helloyongyang
remove split server & fix some bugs
Parent: a23bef13
Showing 2 changed files with 27 additions and 251 deletions.
lightx2v/models/runners/async_wrapper.py   +0 -52
lightx2v/models/runners/default_runner.py  +27 -199
lightx2v/models/runners/async_wrapper.py  deleted 100644 → 0
-import asyncio
-from typing import Callable, Any, Optional
-from concurrent.futures import ThreadPoolExecutor
-
-
-class AsyncWrapper:
-    def __init__(self, runner, max_workers: Optional[int] = None):
-        self.runner = runner
-        self.executor = ThreadPoolExecutor(max_workers=max_workers)
-
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if self.executor:
-            self.executor.shutdown(wait=True)
-
-    async def run_in_executor(self, func: Callable, *args, **kwargs) -> Any:
-        loop = asyncio.get_event_loop()
-        return await loop.run_in_executor(self.executor, func, *args, **kwargs)
-
-    async def run_input_encoder(self):
-        if self.runner.config["mode"] == "split_server":
-            if self.runner.config["task"] == "i2v":
-                return await self.runner._run_input_encoder_server_i2v()
-            else:
-                return await self.runner._run_input_encoder_server_t2v()
-        else:
-            if self.runner.config["task"] == "i2v":
-                return await self.run_in_executor(self.runner._run_input_encoder_local_i2v)
-            else:
-                return await self.run_in_executor(self.runner._run_input_encoder_local_t2v)
-
-    async def run_dit(self, kwargs):
-        if self.runner.config["mode"] == "split_server":
-            return await self.runner._run_dit_server(kwargs)
-        else:
-            return await self.run_in_executor(self.runner._run_dit_local, kwargs)
-
-    async def run_vae_decoder(self, latents, generator):
-        if self.runner.config["mode"] == "split_server":
-            return await self.runner._run_vae_decoder_server(latents, generator)
-        else:
-            return await self.run_in_executor(self.runner._run_vae_decoder_local, latents, generator)
-
-    async def run_prompt_enhancer(self):
-        if self.runner.config["use_prompt_enhancer"]:
-            return await self.run_in_executor(self.runner.post_prompt_enhancer)
-        return None
-
-    async def save_video(self, images):
-        return await self.run_in_executor(self.runner.save_video, images)
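A note on the deleted wrapper: its `run_in_executor` forwarded `**kwargs` into `loop.run_in_executor`, but the asyncio API only accepts positional arguments (`run_in_executor(executor, func, *args)`), so any keyword argument would have raised a `TypeError`. That may be one of the bugs alluded to in the commit message, though the diff does not say so. The usual idiom is to bind arguments with `functools.partial`, and to use `asyncio.get_running_loop()` inside a coroutine. A minimal standalone sketch of that pattern, not taken from the repo:

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)

def blocking_work(x, *, scale=1):
    # Stand-in for a synchronous runner method.
    return x * scale

async def run_in_executor(func, *args, **kwargs):
    loop = asyncio.get_running_loop()  # preferred over get_event_loop() inside a coroutine
    # run_in_executor forwards positional args only; partial binds the kwargs first.
    return await loop.run_in_executor(executor, functools.partial(func, *args, **kwargs))

async def main():
    print(await run_in_executor(blocking_work, 3, scale=10))  # prints 30

asyncio.run(main())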
lightx2v/models/runners/default_runner.py
-import asyncio
 import gc
-import aiohttp
 import requests
 from requests.exceptions import RequestException
 import torch
@@ -13,7 +11,6 @@ from lightx2v.utils.generate_task_id import generate_task_id
 from lightx2v.utils.envs import *
 from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
 from loguru import logger
-from .async_wrapper import AsyncWrapper
 from .base_runner import BaseRunner
@@ -33,21 +30,14 @@ class DefaultRunner(BaseRunner):
     def init_modules(self):
         logger.info("Initializing runner modules...")
-        if self.config["mode"] == "split_server":
-            self.tensor_transporter = TensorTransporter()
-            self.image_transporter = ImageTransporter()
-            if not self.check_sub_servers("dit"):
-                raise ValueError("No dit server available")
-            if not self.check_sub_servers("text_encoders"):
-                raise ValueError("No text encoder server available")
-            if self.config["task"] == "i2v":
-                if not self.check_sub_servers("image_encoder"):
-                    raise ValueError("No image encoder server available")
-                if not self.check_sub_servers("vae_model"):
-                    raise ValueError("No vae server available")
-        else:
-            if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
-                self.load_model()
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.load_model()
+        self.run_dit = self._run_dit_local
+        self.run_vae_decoder = self._run_vae_decoder_local
+        if self.config["task"] == "i2v":
+            self.run_input_encoder = self._run_input_encoder_local_i2v
+        else:
+            self.run_input_encoder = self._run_input_encoder_local_t2v

     def set_init_device(self):
         if self.config["parallel_attn_type"]:
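After this change, `init_modules` picks the execution path once by binding the local implementations to instance attributes (`self.run_dit = self._run_dit_local`, and so on), so the rest of the pipeline calls `self.run_dit(...)` without re-checking the config on every step. A minimal sketch of that dispatch-by-binding pattern; the names here are illustrative, not from the repo:

class Runner:
    def __init__(self, task: str):
        # Bind the concrete implementation once; callers just use self.encode.
        if task == "i2v":
            self.encode = self._encode_i2v
        else:
            self.encode = self._encode_t2v

    def _encode_i2v(self):
        return "image+text features"

    def _encode_t2v(self):
        return "text features"

print(Runner("i2v").encode())  # image+text features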
@@ -123,14 +113,13 @@ class DefaultRunner(BaseRunner):
         return self.model.scheduler.latents, self.model.scheduler.generator

-    async def run_step(self, step_index=0):
-        async with AsyncWrapper(self) as wrapper:
-            self.init_scheduler()
-            self.inputs = await wrapper.run_input_encoder()
-            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-            self.model.scheduler.step_pre(step_index=step_index)
-            self.model.infer(self.inputs)
-            self.model.scheduler.step_post()
+    def run_step(self, step_index=0):
+        self.init_scheduler()
+        self.inputs = self.run_input_encoder()
+        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+        self.model.scheduler.step_pre(step_index=step_index)
+        self.model.infer(self.inputs)
+        self.model.scheduler.step_post()

     def end_run(self):
         self.model.scheduler.clear()
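`run_step` bundles one full denoising iteration: scheduler setup, input encoding, then the `step_pre` → `infer` → `step_post` cycle. Now that it is synchronous it is convenient for timing a single DiT step. A hedged usage sketch; the surrounding plumbing is illustrative, not from this diff:

import time
import torch

# Assumes `runner` is an initialized DefaultRunner (init_modules() already called).
start = time.perf_counter()
runner.run_step(step_index=0)  # one prepare + step_pre + infer + step_post cycle
torch.cuda.synchronize()       # wait for queued GPU work before reading the clock
print(f"single step: {time.perf_counter() - start:.3f}s")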
@@ -194,52 +183,6 @@ class DefaultRunner(BaseRunner):
         if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
             self.save_video_func(images)

-    async def post_task(self, task_type, urls, message, device="cuda", max_retries=3, timeout=30):
-        for attempt in range(max_retries):
-            for url in urls:
-                try:
-                    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
-                        try:
-                            async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
-                                if response.status != 200:
-                                    logger.warning(f"Service {url} returned status {response.status}")
-                                    continue
-                                status = await response.json()
-                        except asyncio.TimeoutError:
-                            logger.warning(f"Timeout checking status for {url}")
-                            continue
-                        except Exception as e:
-                            logger.warning(f"Error checking status for {url}: {e}")
-                            continue
-                        if status.get("service_status") == "idle":
-                            try:
-                                async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
-                                    if response.status == 200:
-                                        result = await response.json()
-                                        if result.get("kwargs") is not None:
-                                            for k, v in result["kwargs"].items():
-                                                setattr(self.config, k, v)
-                                        return self.tensor_transporter.load_tensor(result["output"], device)
-                                    else:
-                                        logger.warning(f"Task failed with status {response.status} for {url}")
-                            except asyncio.TimeoutError:
-                                logger.warning(f"Timeout posting task to {url}")
-                            except Exception as e:
-                                logger.error(f"Error posting task to {url}: {e}")
-                except aiohttp.ClientError as e:
-                    logger.warning(f"Client error for {url}: {e}")
-                except Exception as e:
-                    logger.error(f"Unexpected error for {url}: {e}")
-            if attempt < max_retries - 1:
-                wait_time = min(2 ** attempt, 10)
-                logger.info(f"Retrying in {wait_time} seconds... (attempt {attempt + 1}/{max_retries})")
-                await asyncio.sleep(wait_time)
-        raise RuntimeError(f"Failed to complete task {task_type} after {max_retries} attempts")
-
     def post_prompt_enhancer(self):
         while True:
             for url in self.config["sub_servers"]["prompt_enhancer"]:
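The deleted `post_task` polled each sub-server's `/generate/service_status` endpoint, posted the job to the first idle server, and retried the whole URL list with capped exponential backoff (`min(2 ** attempt, 10)` seconds between rounds). That retry idiom stands on its own; a minimal standalone sketch, assuming a generic awaitable `do_request` and not mirroring the repo's API:

import asyncio
import random

async def with_backoff(do_request, max_retries=3, cap=10):
    # Retry with capped exponential backoff: 1s, 2s, 4s, ... up to `cap` seconds.
    for attempt in range(max_retries):
        try:
            return await do_request()
        except Exception:
            if attempt == max_retries - 1:
                raise
            wait = min(2 ** attempt, cap)
            # A little jitter avoids synchronized retry storms across clients.
            await asyncio.sleep(wait + random.uniform(0, 0.5))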
@@ -256,138 +199,23 @@ class DefaultRunner(BaseRunner):
                 logger.info(f"Enhanced prompt: {enhanced_prompt}")
                 return enhanced_prompt

-    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
-        tasks = []
-        img_byte = self.image_transporter.prepare_image(img)
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="image_encoder",
-                    urls=self.config["sub_servers"]["image_encoder"],
-                    message={"task_id": generate_task_id(), "img": img_byte},
-                    device="cuda",
-                )
-            )
-        )
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="vae_model/encoder",
-                    urls=self.config["sub_servers"]["vae_model"],
-                    message={"task_id": generate_task_id(), "img": img_byte},
-                    device="cuda",
-                )
-            )
-        )
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="text_encoders",
-                    urls=self.config["sub_servers"]["text_encoders"],
-                    message={
-                        "task_id": generate_task_id(),
-                        "text": prompt,
-                        "img": img_byte,
-                        "n_prompt": n_prompt,
-                    },
-                    device="cuda",
-                )
-            )
-        )
-        results = await asyncio.gather(*tasks)
-        # clip_encoder, vae_encoder, text_encoders
-        return results[0], results[1], results[2]
-
-    async def post_encoders_t2v(self, prompt, n_prompt=None):
-        tasks = []
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="text_encoders",
-                    urls=self.config["sub_servers"]["text_encoders"],
-                    message={
-                        "task_id": generate_task_id(),
-                        "text": prompt,
-                        "img": None,
-                        "n_prompt": n_prompt,
-                    },
-                    device="cuda",
-                )
-            )
-        )
-        results = await asyncio.gather(*tasks)
-        # text_encoders
-        return results[0]
-
-    async def _run_input_encoder_server_i2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        n_prompt = self.config.get("negative_prompt", "")
-        img = Image.open(self.config["image_path"]).convert("RGB")
-        (
-            clip_encoder_out,
-            vae_encode_out,
-            text_encoder_output,
-        ) = await self.post_encoders_i2v(prompt, img, n_prompt)
-        torch.cuda.empty_cache()
-        gc.collect()
-        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
-
-    async def _run_input_encoder_server_t2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        n_prompt = self.config.get("negative_prompt", "")
-        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
-        torch.cuda.empty_cache()
-        gc.collect()
-        return {
-            "text_encoder_output": text_encoder_output,
-            "image_encoder_output": None,
-        }
-
-    async def _run_dit_server(self, kwargs):
-        if self.inputs.get("image_encoder_output", None) is not None:
-            self.inputs["image_encoder_output"].pop("img", None)
-        dit_output = await self.post_task(
-            task_type="dit",
-            urls=self.config["sub_servers"]["dit"],
-            message={
-                "task_id": generate_task_id(),
-                "inputs": self.tensor_transporter.prepare_tensor(self.inputs),
-                "kwargs": self.tensor_transporter.prepare_tensor(kwargs),
-            },
-            device="cuda",
-        )
-        return dit_output, None
-
-    async def _run_vae_decoder_server(self, latents, generator):
-        images = await self.post_task(
-            task_type="vae_model/decoder",
-            urls=self.config["sub_servers"]["vae_model"],
-            message={
-                "task_id": generate_task_id(),
-                "latents": self.tensor_transporter.prepare_tensor(latents),
-            },
-            device="cpu",
-        )
-        return images
-
-    async def run_pipeline(self, save_video=True):
-        async with AsyncWrapper(self) as wrapper:
-            if self.config["use_prompt_enhancer"]:
-                self.config["prompt_enhanced"] = await wrapper.run_prompt_enhancer()
-            self.inputs = await wrapper.run_input_encoder()
-            kwargs = self.set_target_shape()
-            latents, generator = await wrapper.run_dit(kwargs)
-            images = await wrapper.run_vae_decoder(latents, generator)
-            if save_video:
-                await wrapper.save_video(images)
-            del latents, generator
-            torch.cuda.empty_cache()
-            gc.collect()
-            return images
+    def run_pipeline(self, save_video=True):
+        if self.config["use_prompt_enhancer"]:
+            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+        self.inputs = self.run_input_encoder()
+        kwargs = self.set_target_shape()
+        latents, generator = self.run_dit(kwargs)
+        images = self.run_vae_decoder(latents, generator)
+        if save_video:
+            self.save_video(images)
+        del latents, generator
+        torch.cuda.empty_cache()
+        gc.collect()
+        return images
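With `AsyncWrapper` and the `_server` variants gone, `run_pipeline` is a plain synchronous method, so callers no longer need an event loop or `asyncio.run()`. A hedged usage sketch; the constructor arguments and config construction are assumptions, not taken from this diff:

# Assumes a config object as used throughout DefaultRunner above.
runner = DefaultRunner(config)
runner.init_modules()
images = runner.run_pipeline(save_video=True)  # plain blocking call, no await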