Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
ca696d83
"vscode:/vscode.git/clone" did not exist on "7dcf63e69cb580bc90213a29936f987df632b595"
Commit
ca696d83
authored
May 08, 2025
by
helloyongyang
Browse files
Fix enhancer bugs
parent
dbfa688b
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
33 additions
and
27 deletions
+33
-27
lightx2v/api_server.py
lightx2v/api_server.py
+4
-3
lightx2v/infer.py
lightx2v/infer.py
+2
-2
lightx2v/models/runners/default_runner.py
lightx2v/models/runners/default_runner.py
+7
-13
lightx2v/utils/prompt_enhancer.py
lightx2v/utils/prompt_enhancer.py
+15
-4
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+1
-0
scripts/post.py
scripts/post.py
+1
-1
scripts/post_enhancer.py
scripts/post_enhancer.py
+2
-3
scripts/run_wan_t2v_save_quant.sh
scripts/run_wan_t2v_save_quant.sh
+1
-1
No files found.
lightx2v/api_server.py
View file @
ca696d83
...
...
@@ -68,8 +68,8 @@ async def v1_local_video_generate(message: Message):
logger
.
info
(
f
"message:
{
message
}
"
)
await
asyncio
.
to_thread
(
runner
.
run_pipeline
)
response
=
{
"response"
:
"finished"
,
"save_video_path"
:
message
.
save_video_path
}
if
runner
.
has_prompt_enhancer
and
message
.
use_prompt_enhancer
:
response
[
"enhanced
_prompt
"
]
=
runner
.
config
[
"prompt"
]
if
message
.
use_prompt_enhancer
:
response
[
"
prompt_
enhanced"
]
=
runner
.
config
[
"prompt
_enhanced
"
]
return
response
...
...
@@ -80,11 +80,12 @@ async def v1_local_video_generate(message: Message):
if
__name__
==
"__main__"
:
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model_cls"
,
type
=
str
,
required
=
True
,
choices
=
[
"wan2.1"
,
"hunyuan"
,
"wan2.1_causvid"
],
default
=
"hunyuan"
)
parser
.
add_argument
(
"--model_cls"
,
type
=
str
,
required
=
True
,
choices
=
[
"wan2.1"
,
"hunyuan"
,
"wan2.1_causvid"
,
"wan2.1_skyreels_v2_df"
],
default
=
"hunyuan"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
choices
=
[
"t2v"
,
"i2v"
],
default
=
"t2v"
)
parser
.
add_argument
(
"--model_path"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--config_json"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--prompt_enhancer"
,
default
=
None
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
)
args
=
parser
.
parse_args
()
logger
.
info
(
f
"args:
{
args
}
"
)
...
...
lightx2v/infer.py
View file @
ca696d83
...
...
@@ -39,9 +39,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--task"
,
type
=
str
,
choices
=
[
"t2v"
,
"i2v"
],
default
=
"t2v"
)
parser
.
add_argument
(
"--model_path"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--config_json"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--enable_cfg"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--prompt"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--prompt_enhancer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--prompt"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--negative_prompt"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--image_path"
,
type
=
str
,
default
=
""
,
help
=
"The path to input image file or path for image-to-video (i2v) task"
)
parser
.
add_argument
(
"--save_video_path"
,
type
=
str
,
default
=
"./output_lightx2v.mp4"
,
help
=
"The path to save video path/file"
)
...
...
lightx2v/models/runners/default_runner.py
View file @
ca696d83
...
...
@@ -11,29 +11,22 @@ from loguru import logger
class
DefaultRunner
:
def
__init__
(
self
,
config
):
self
.
config
=
config
self
.
config
[
"user_prompt"
]
=
self
.
config
[
"prompt"
]
self
.
has_prompt_enhancer
=
self
.
config
.
prompt_enhancer
is
not
None
and
self
.
config
.
task
==
"t2v"
self
.
config
[
"use_prompt_enhancer"
]
=
self
.
has_prompt_enhancer
if
self
.
has_prompt_enhancer
:
if
self
.
config
.
prompt_enhancer
is
not
None
and
self
.
config
.
task
==
"t2v"
:
self
.
load_prompt_enhancer
()
self
.
model
,
self
.
text_encoders
,
self
.
vae_model
,
self
.
image_encoder
=
self
.
load_model
()
@
ProfilingContext
(
"Load prompt enhancer"
)
def
load_prompt_enhancer
(
self
):
gpu_count
=
torch
.
cuda
.
device_count
()
if
gpu_count
==
1
:
logger
.
info
(
"Only one GPU, use prompt enhancer cpu offload"
)
raise
NotImplementedError
(
"prompt enhancer cpu offload is not supported."
)
self
.
prompt_enhancer
=
PromptEnhancer
(
model_name
=
self
.
config
.
prompt_enhancer
,
device_map
=
"cuda:1"
)
self
.
config
[
"use_prompt_enhancer"
]
=
True
# Set use_prompt_enhancer to True now. (Default is False)
def
set_inputs
(
self
,
inputs
):
self
.
config
[
"user_prompt"
]
=
inputs
.
get
(
"prompt"
,
""
)
self
.
config
[
"prompt"
]
=
inputs
.
get
(
"prompt"
,
""
)
self
.
config
[
"use_prompt_enhancer"
]
=
inputs
.
get
(
"use_prompt_enhancer"
,
False
)
self
.
config
[
"use_prompt_enhancer"
]
=
inputs
.
get
(
"use_prompt_enhancer"
,
False
)
# Reset use_prompt_enhancer from clinet side.
self
.
config
[
"negative_prompt"
]
=
inputs
.
get
(
"negative_prompt"
,
""
)
self
.
config
[
"image_path"
]
=
inputs
.
get
(
"image_path"
,
""
)
self
.
config
[
"save_video_path"
]
=
inputs
.
get
(
"save_video_path"
,
""
)
...
...
@@ -44,7 +37,8 @@ class DefaultRunner:
with
ProfilingContext
(
"Run Img Encoder"
):
image_encoder_output
=
self
.
run_image_encoder
(
self
.
config
,
self
.
image_encoder
,
self
.
vae_model
)
with
ProfilingContext
(
"Run Text Encoder"
):
text_encoder_output
=
self
.
run_text_encoder
(
self
.
config
[
"prompt"
],
self
.
text_encoders
,
self
.
config
,
image_encoder_output
)
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
text_encoder_output
=
self
.
run_text_encoder
(
prompt
,
self
.
text_encoders
,
self
.
config
,
image_encoder_output
)
self
.
set_target_shape
()
self
.
inputs
=
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
image_encoder_output
}
...
...
@@ -93,8 +87,8 @@ class DefaultRunner:
save_videos_grid
(
images
,
self
.
config
.
save_video_path
,
fps
=
24
)
def
run_pipeline
(
self
):
if
self
.
has_prompt_enhancer
and
self
.
config
[
"use_prompt_enhancer"
]:
self
.
config
[
"prompt"
]
=
self
.
prompt_enhancer
(
self
.
config
[
"
user_
prompt"
])
if
self
.
config
[
"use_prompt_enhancer"
]:
self
.
config
[
"prompt
_enhanced
"
]
=
self
.
prompt_enhancer
(
self
.
config
[
"prompt"
])
self
.
init_scheduler
()
self
.
run_input_encoder
()
self
.
model
.
scheduler
.
prepare
(
self
.
inputs
[
"image_encoder_output"
])
...
...
lightx2v/utils/prompt_enhancer.py
View file @
ca696d83
import
argparse
import
torch
from
loguru
import
logger
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
lightx2v.utils.profiler
import
ProfilingContext4Debug
,
ProfilingContext
...
...
@@ -38,6 +39,7 @@ class PromptEnhancer:
self
.
model
=
self
.
model
.
to
(
device
)
@
ProfilingContext
(
"Run prompt enhancer"
)
@
torch
.
no_grad
()
def
__call__
(
self
,
prompt
):
prompt
=
prompt
.
strip
()
prompt
=
sys_prompt
.
format
(
prompt
)
...
...
@@ -46,11 +48,20 @@ class PromptEnhancer:
model_inputs
=
self
.
tokenizer
([
text
],
return_tensors
=
"pt"
).
to
(
self
.
model
.
device
)
generated_ids
=
self
.
model
.
generate
(
**
model_inputs
,
max_new_tokens
=
2048
,
max_new_tokens
=
8192
,
)
generated_ids
=
[
output_ids
[
len
(
input_ids
)
:]
for
input_ids
,
output_ids
in
zip
(
model_inputs
.
input_ids
,
generated_ids
)]
rewritten_prompt
=
self
.
tokenizer
.
batch_decode
(
generated_ids
,
skip_special_tokens
=
True
)[
0
]
logger
.
info
(
f
"Enhanced prompt:
{
rewritten_prompt
}
"
)
output_ids
=
generated_ids
[
0
][
len
(
model_inputs
.
input_ids
[
0
])
:].
tolist
()
think_id
=
self
.
tokenizer
.
encode
(
"</think>"
)
if
len
(
think_id
)
==
1
:
index
=
len
(
output_ids
)
-
output_ids
[::
-
1
].
index
(
think_id
[
0
])
else
:
index
=
0
thinking_content
=
self
.
tokenizer
.
decode
(
output_ids
[:
index
],
skip_special_tokens
=
True
).
strip
(
"
\n
"
)
logger
.
info
(
f
"[Enhanced] thinking content:
{
thinking_content
}
"
)
rewritten_prompt
=
self
.
tokenizer
.
decode
(
output_ids
[
index
:],
skip_special_tokens
=
True
).
strip
(
"
\n
"
)
logger
.
info
(
f
"[Enhanced] rewritten prompt:
{
rewritten_prompt
}
"
)
return
rewritten_prompt
...
...
lightx2v/utils/set_config.py
View file @
ca696d83
...
...
@@ -19,6 +19,7 @@ def get_default_config():
"lora_path"
:
None
,
"strength_model"
:
1.0
,
"mm_config"
:
{},
"use_prompt_enhancer"
:
False
,
}
return
default_config
...
...
scripts/post.py
View file @
ca696d83
...
...
@@ -8,7 +8,7 @@ message = {
"prompt"
:
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
,
"negative_prompt"
:
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
,
"image_path"
:
""
,
"save_video_path"
:
"./output_lightx2v_wan_t2v
_ap4
.mp4"
,
# It is best to set it to an absolute path.
"save_video_path"
:
"./output_lightx2v_wan_t2v.mp4"
,
# It is best to set it to an absolute path.
}
logger
.
info
(
f
"message:
{
message
}
"
)
...
...
scripts/post_enhancer.py
View file @
ca696d83
...
...
@@ -6,11 +6,10 @@ url = "http://localhost:8000/v1/local/video/generate"
message
=
{
"prompt"
:
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
,
"use_prompt_enhancer"
:
True
,
"negative_prompt"
:
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
,
"image_path"
:
""
,
"
num_fragments"
:
1
,
"
save_video_path"
:
"./output_lightx2v_wan_t2v_ap4.mp4"
,
# It is best to set it to an absolute path.
"
save_video_path"
:
"./output_lightx2v_wan_t2v_enhanced.mp4"
,
# It is best to set it to an absolute path.
"
use_prompt_enhancer"
:
True
,
}
logger
.
info
(
f
"message:
{
message
}
"
)
...
...
scripts/run_wan_t2v_save_quant.sh
View file @
ca696d83
...
...
@@ -6,7 +6,7 @@ model_path=
# check section
if
[
-z
"
${
CUDA_VISIBLE_DEVICES
}
"
]
;
then
cuda_devices
=
2
cuda_devices
=
0
echo
"Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value:
${
cuda_devices
}
, change at shell script or set env variable."
export
CUDA_VISIBLE_DEVICES
=
${
cuda_devices
}
fi
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment