Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
e9e33065
Commit
e9e33065
authored
Jun 16, 2025
by
Zhuguanyu Wu
Committed by
GitHub
Jun 16, 2025
Browse files
Dev distill (#69)
* add step & cfg distillation wan model
parent
497ff9fe
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
18 additions
and
15 deletions
+18
-15
lightx2v/common/apis/dit.py
lightx2v/common/apis/dit.py
+1
-1
lightx2v/common/apis/image_encoder.py
lightx2v/common/apis/image_encoder.py
+1
-1
lightx2v/common/apis/text_encoder.py
lightx2v/common/apis/text_encoder.py
+1
-1
lightx2v/common/apis/vae.py
lightx2v/common/apis/vae.py
+1
-1
lightx2v/models/runners/cogvideox/cogvidex_runner.py
lightx2v/models/runners/cogvideox/cogvidex_runner.py
+4
-4
lightx2v/models/runners/default_runner.py
lightx2v/models/runners/default_runner.py
+4
-1
lightx2v/models/runners/wan/wan_causvid_runner.py
lightx2v/models/runners/wan/wan_causvid_runner.py
+2
-2
lightx2v/models/runners/wan/wan_distill_runner.py
lightx2v/models/runners/wan/wan_distill_runner.py
+2
-2
scripts/run_wan_t2v_distill.sh
scripts/run_wan_t2v_distill.sh
+2
-2
No files found.
lightx2v/common/apis/dit.py
View file @
e9e33065
...
...
@@ -53,7 +53,7 @@ class DiTRunner:
self
.
runner_cls
=
RUNNER_REGISTER
[
self
.
config
.
model_cls
]
self
.
runner
=
self
.
runner_cls
(
config
)
self
.
runner
.
model
=
self
.
runner
.
load_transformer
(
self
.
runner
.
get_init_device
()
)
self
.
runner
.
model
=
self
.
runner
.
load_transformer
()
def
_run_dit
(
self
,
inputs
,
kwargs
):
self
.
runner
.
config
.
update
(
tensor_transporter
.
load_tensor
(
kwargs
))
...
...
lightx2v/common/apis/image_encoder.py
View file @
e9e33065
...
...
@@ -51,7 +51,7 @@ class ImageEncoderRunner:
self
.
runner_cls
=
RUNNER_REGISTER
[
self
.
config
.
model_cls
]
self
.
runner
=
self
.
runner_cls
(
config
)
self
.
runner
.
image_encoder
=
self
.
runner
.
load_image_encoder
(
self
.
runner
.
get_init_device
()
)
self
.
runner
.
image_encoder
=
self
.
runner
.
load_image_encoder
()
def
_run_image_encoder
(
self
,
img
):
img
=
image_transporter
.
load_image
(
img
)
...
...
lightx2v/common/apis/text_encoder.py
View file @
e9e33065
...
...
@@ -53,7 +53,7 @@ class TextEncoderRunner:
self
.
runner_cls
=
RUNNER_REGISTER
[
self
.
config
.
model_cls
]
self
.
runner
=
self
.
runner_cls
(
config
)
self
.
runner
.
text_encoders
=
self
.
runner
.
load_text_encoder
(
self
.
runner
.
get_init_device
()
)
self
.
runner
.
text_encoders
=
self
.
runner
.
load_text_encoder
()
def
_run_text_encoder
(
self
,
text
,
img
,
n_prompt
):
if
img
is
not
None
:
...
...
lightx2v/common/apis/vae.py
View file @
e9e33065
...
...
@@ -56,7 +56,7 @@ class VAERunner:
self
.
runner_cls
=
RUNNER_REGISTER
[
self
.
config
.
model_cls
]
self
.
runner
=
self
.
runner_cls
(
config
)
self
.
runner
.
vae_encoder
,
self
.
runner
.
vae_decoder
=
self
.
runner
.
load_vae
(
self
.
runner
.
get_init_device
()
)
self
.
runner
.
vae_encoder
,
self
.
runner
.
vae_decoder
=
self
.
runner
.
load_vae
()
def
_run_vae_encoder
(
self
,
img
):
img
=
image_transporter
.
load_image
(
img
)
...
...
lightx2v/models/runners/cogvideox/cogvidex_runner.py
View file @
e9e33065
...
...
@@ -16,19 +16,19 @@ class CogvideoxRunner(DefaultRunner):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
def
load_transformer
(
self
,
init_device
):
def
load_transformer
(
self
):
model
=
CogvideoxModel
(
self
.
config
)
return
model
def
load_image_encoder
(
self
,
init_device
):
def
load_image_encoder
(
self
):
return
None
def
load_text_encoder
(
self
,
init_device
):
def
load_text_encoder
(
self
):
text_encoder
=
T5EncoderModel_v1_1_xxl
(
self
.
config
)
text_encoders
=
[
text_encoder
]
return
text_encoders
def
load_vae
(
self
,
init_device
):
def
load_vae
(
self
):
vae_model
=
CogvideoxVAE
(
self
.
config
)
return
vae_model
,
vae_model
...
...
lightx2v/models/runners/default_runner.py
View file @
e9e33065
...
...
@@ -24,10 +24,12 @@ class DefaultRunner:
if
not
self
.
check_sub_servers
(
"prompt_enhancer"
):
self
.
has_prompt_enhancer
=
False
logger
.
warning
(
"No prompt enhancer server available, disable prompt enhancer."
)
if
not
self
.
has_prompt_enhancer
:
self
.
config
[
"use_prompt_enhancer"
]
=
False
self
.
set_init_device
()
def
init_modules
(
self
):
logger
.
info
(
"Initializing runner modules..."
)
self
.
set_init_device
()
if
self
.
config
[
"mode"
]
==
"split_server"
:
self
.
tensor_transporter
=
TensorTransporter
()
self
.
image_transporter
=
ImageTransporter
()
...
...
@@ -93,6 +95,7 @@ class DefaultRunner:
def
set_inputs
(
self
,
inputs
):
self
.
config
[
"prompt"
]
=
inputs
.
get
(
"prompt"
,
""
)
self
.
config
[
"use_prompt_enhancer"
]
=
False
if
self
.
has_prompt_enhancer
:
self
.
config
[
"use_prompt_enhancer"
]
=
inputs
.
get
(
"use_prompt_enhancer"
,
False
)
# Reset use_prompt_enhancer from clinet side.
self
.
config
[
"negative_prompt"
]
=
inputs
.
get
(
"negative_prompt"
,
""
)
...
...
lightx2v/models/runners/wan/wan_causvid_runner.py
View file @
e9e33065
...
...
@@ -29,8 +29,8 @@ class WanCausVidRunner(WanRunner):
self
.
infer_blocks
=
self
.
model
.
config
.
num_blocks
self
.
num_fragments
=
self
.
model
.
config
.
num_fragments
def
load_transformer
(
self
,
init_device
):
return
WanCausVidModel
(
self
.
config
.
model_path
,
self
.
config
,
init_device
)
def
load_transformer
(
self
):
return
WanCausVidModel
(
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
)
def
set_inputs
(
self
,
inputs
):
super
().
set_inputs
(
inputs
)
...
...
lightx2v/models/runners/wan/wan_distill_runner.py
View file @
e9e33065
...
...
@@ -23,8 +23,8 @@ class WanDistillRunner(WanRunner):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
def
load_transformer
(
self
,
init_device
):
model
=
WanDistillModel
(
self
.
config
.
model_path
,
self
.
config
,
init_device
)
def
load_transformer
(
self
):
model
=
WanDistillModel
(
self
.
config
.
model_path
,
self
.
config
,
self
.
init_device
)
if
self
.
config
.
lora_path
:
lora_wrapper
=
WanLoraWrapper
(
model
)
lora_name
=
lora_wrapper
.
load_lora
(
self
.
config
.
lora_path
)
...
...
scripts/run_wan_t2v_distill.sh
View file @
e9e33065
#!/bin/bash
# set path and first
lightx2v_path
=
"/data/lightx2v-dev/"
model_path
=
"/data/lightx2v-dev/Wan2.1-T2V-14B/"
lightx2v_path
=
model_path
=
# check section
if
[
-z
"
${
CUDA_VISIBLE_DEVICES
}
"
]
;
then
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment