xuwx1 / LightX2V · Commits · e08c4f90

Merge branch 'main' into audio_r2v

Authored by sandy on Jul 17, 2025; committed by GitHub on Jul 17, 2025
Parents: 12bfd120, 6d07a72e

Showing 20 changed files with 573 additions and 253 deletions (+573 −253)
lightx2v/models/networks/wan/causvid_model.py                     +17  −12
lightx2v/models/networks/wan/infer/pre_infer.py                    +1   −1
lightx2v/models/networks/wan/infer/transformer_infer.py            +6   −4
lightx2v/models/networks/wan/model.py                             +19   −4
lightx2v/models/networks/wan/weights/transformer_weights.py        +5   −0
lightx2v/models/runners/base_runner.py                           +165   −0
lightx2v/models/runners/default_runner.py                         +62 −142
lightx2v/models/runners/wan/wan_audio_runner.py                    +7   −4
lightx2v/models/runners/wan/wan_causvid_runner.py                 +11   −9
lightx2v/models/runners/wan/wan_distill_runner.py                  +6   −4
lightx2v/models/runners/wan/wan_runner.py                         +76  −29
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py   +66   −0
lightx2v/models/schedulers/wan/step_distill/scheduler.py          +26  −38
lightx2v/models/video_encoders/hf/wan/vae.py                       +2   −2
lightx2v/server/service.py                                         +6   −1
lightx2v/utils/async_io.py                                        +82   −0
lightx2v/utils/quant_utils.py                                      +6   −1
lightx2v/utils/set_config.py                                       +1   −2
lightx2v/utils/utils.py                                            +8   −0
lightx2v_kernel/CMakeLists.txt                                     +1   −0
lightx2v/models/networks/wan/causvid_model.py  +17 −12

@@ -12,6 +12,7 @@ from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
     WanTransformerInferCausVid,
 )
 from lightx2v.utils.envs import *
+from safetensors import safe_open


 class WanCausVidModel(WanModel):
@@ -28,18 +29,22 @@ class WanCausVidModel(WanModel):
         self.transformer_infer_class = WanTransformerInferCausVid

     def _load_ckpt(self, use_bf16, skip_bf16):
-        use_bfloat16 = GET_DTYPE() == "BF16"
-        ckpt_path = os.path.join(self.model_path, "causal_model.pt")
-        if not os.path.exists(ckpt_path):
-            return super()._load_ckpt(use_bf16, skip_bf16)
-        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
-        dtype = torch.bfloat16 if use_bfloat16 else None
-        for key, value in weight_dict.items():
-            weight_dict[key] = value.to(device=self.device, dtype=dtype)
-        return weight_dict
+        ckpt_folder = "causvid_models"
+        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
+        if os.path.exists(safetensors_path):
+            with safe_open(safetensors_path, framework="pt") as f:
+                weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return weight_dict
+        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
+        if os.path.exists(ckpt_path):
+            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+            weight_dict = {
+                key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device)
+                for key in weight_dict.keys()
+            }
+            return weight_dict
+        return super()._load_ckpt(use_bf16, skip_bf16)

     @torch.no_grad()
     def infer(self, inputs, kv_start, kv_end):
lightx2v/models/networks/wan/infer/pre_infer.py  +1 −1

@@ -64,7 +64,7 @@ class WanPreInfer:
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
             s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
-            cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(0.0, 8.0), target_range=1000.0, dtype=torch.float32).type_as(x)
+            cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 8.0), target_range=1000.0, dtype=torch.float32).type_as(x)
             cfg_embed = weights.cfg_cond_proj.apply(cfg_embed)
             embed = embed + cfg_embed
         if GET_DTYPE() != "BF16":
lightx2v/models/networks/wan/infer/transformer_infer.py  +6 −4

@@ -29,6 +29,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.mask_map = None
         if self.config["cpu_offload"]:
+            if torch.cuda.get_device_capability(0) == (9, 0):
+                assert self.config["self_attn_1_type"] != "sage_attn2"
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
             else:
@@ -104,7 +106,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

     def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights)
+        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
@@ -132,7 +134,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             if block_idx == self.blocks_num - 1:
                 self.weights_stream_mgr.pin_memory_buffer.pop_front()
-            self.weights_stream_mgr._async_prefetch_block(weights)
+            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
         if self.clean_cuda_cache:
             del grid_sizes, embed, embed0, seq_lens, freqs, context
@@ -189,7 +191,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

     def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights)
+        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(weights.blocks_num):
             for phase_idx in range(self.weights_stream_mgr.phases_num):
@@ -236,7 +238,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 self.weights_stream_mgr.swap_phases()
-            self.weights_stream_mgr._async_prefetch_block(weights)
+            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
         if self.clean_cuda_cache:
             del attn_out, y_out, y
lightx2v/models/networks/wan/model.py  +19 −4

 import os
-import sys
 import torch
 import glob
 import json
@@ -37,7 +36,11 @@ class WanModel:
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
-        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
+        if self.dit_quantized:
+            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
+            self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", os.path.join(model_path, dit_quant_scheme))
+        else:
+            self.dit_quantized_ckpt = None
         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
@@ -80,7 +83,12 @@ class WanModel:
         safetensors_files = glob.glob(safetensors_pattern)
         if not safetensors_files:
-            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
+            original_pattern = os.path.join(self.model_path, "original", "*.safetensors")
+            safetensors_files = glob.glob(original_pattern)
+            if not safetensors_files:
+                raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
         weight_dict = {}
         for file_path in safetensors_files:
             file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
@@ -138,7 +146,14 @@ class WanModel:
     def _init_weights(self, weight_dict=None):
         use_bf16 = GET_DTYPE() == "BF16"
         # Some layers run with float32 to achieve high accuracy
-        skip_bf16 = {"norm", "embedding", "modulation", "time", "img_emb.proj.0", "img_emb.proj.4"}
+        skip_bf16 = {
+            "norm",
+            "embedding",
+            "modulation",
+            "time",
+            "img_emb.proj.0",
+            "img_emb.proj.4",
+        }
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
                 self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
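Note on the new default for dit_quantized_ckpt above: the quant scheme is parsed out of mm_type (second dash-separated field) and used as a checkpoint subdirectory under model_path. A minimal sketch of that resolution; the mm_type value here is hypothetical, only the split("-")[1] behavior comes from the diff:

    import os

    model_path = "/models/wan"                 # hypothetical model root
    mm_type = "W-fp8-channel-sym"              # hypothetical mm_type string
    dit_quant_scheme = mm_type.split("-")[1]   # -> "fp8"
    print(os.path.join(model_path, dit_quant_scheme))  # -> /models/wan/fp8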
lightx2v/models/networks/wan/weights/transformer_weights.py  +5 −0

@@ -24,6 +24,11 @@ class WanTransformerWeights(WeightModule):
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
         self.add_module("blocks", self.blocks)

+    def clear(self):
+        for block in self.blocks:
+            for phase in block.compute_phases:
+                phase.clear()
+

 class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
lightx2v/models/runners/base_runner.py  new file (0 → 100644)  +165 −0

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, Optional, Union, List, Protocol

from lightx2v.utils.utils import save_videos_grid


class TransformerModel(Protocol):
    """Protocol for transformer models"""

    def set_scheduler(self, scheduler: Any) -> None: ...

    def scheduler(self) -> Any: ...


class TextEncoderModel(Protocol):
    """Protocol for text encoder models"""

    def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...


class ImageEncoderModel(Protocol):
    """Protocol for image encoder models"""

    def encode(self, image: Any) -> Any: ...


class VAEModel(Protocol):
    """Protocol for VAE models"""

    def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]: ...

    def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any: ...


class BaseRunner(ABC):
    """Abstract base class for all Runners

    Defines interface methods that all subclasses must implement
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @abstractmethod
    def load_transformer(self) -> TransformerModel:
        """Load transformer model

        Returns:
            Loaded transformer model instance
        """
        pass

    @abstractmethod
    def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
        """Load text encoder

        Returns:
            Text encoder instance or list of text encoder instances
        """
        pass

    @abstractmethod
    def load_image_encoder(self) -> Optional[ImageEncoderModel]:
        """Load image encoder

        Returns:
            Image encoder instance or None if not needed
        """
        pass

    @abstractmethod
    def load_vae(self) -> Tuple[VAEModel, VAEModel]:
        """Load VAE encoder and decoder

        Returns:
            Tuple[vae_encoder, vae_decoder]: VAE encoder and decoder instances
        """
        pass

    @abstractmethod
    def run_image_encoder(self, img: Any) -> Any:
        """Run image encoder

        Args:
            img: Input image

        Returns:
            Image encoding result
        """
        pass

    @abstractmethod
    def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
        """Run VAE encoder

        Args:
            img: Input image

        Returns:
            Tuple of VAE encoding result and additional parameters
        """
        pass

    @abstractmethod
    def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
        """Run text encoder

        Args:
            prompt: Input text prompt
            img: Optional input image (for some models)

        Returns:
            Text encoding result
        """
        pass

    @abstractmethod
    def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encode_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
        """Combine encoder outputs for i2v task

        Args:
            clip_encoder_out: CLIP encoder output
            vae_encode_out: VAE encoder output
            text_encoder_output: Text encoder output
            img: Original image

        Returns:
            Combined encoder output dictionary
        """
        pass

    @abstractmethod
    def init_scheduler(self) -> None:
        """Initialize scheduler"""
        pass

    def set_target_shape(self) -> Dict[str, Any]:
        """Set target shape

        Subclasses can override this method to provide specific implementation

        Returns:
            Dictionary containing target shape information
        """
        return {}

    def save_video_func(self, images: Any) -> None:
        """Save video implementation

        Subclasses can override this method to customize save logic

        Args:
            images: Image sequence to save
        """
        save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))

    def load_vae_decoder(self) -> VAEModel:
        """Load VAE decoder

        Default implementation: get decoder from load_vae method
        Subclasses can override this method to provide different loading logic

        Returns:
            VAE decoder instance
        """
        if not hasattr(self, "vae_decoder") or self.vae_decoder is None:
            _, self.vae_decoder = self.load_vae()
        return self.vae_decoder
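Note: the Protocol classes above are structural typing aids only; BaseRunner is the enforced contract. A minimal sketch of a concrete subclass; the class name and stub bodies are hypothetical, only the abstract method names come from the file:

    from lightx2v.models.runners.base_runner import BaseRunner


    class ToyRunner(BaseRunner):
        def load_transformer(self): raise NotImplementedError
        def load_text_encoder(self): raise NotImplementedError
        def load_image_encoder(self): return None  # image encoder is optional
        def load_vae(self): raise NotImplementedError  # -> (vae_encoder, vae_decoder)
        def run_image_encoder(self, img): raise NotImplementedError
        def run_vae_encoder(self, img): raise NotImplementedError
        def run_text_encoder(self, prompt, img=None): raise NotImplementedError
        def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img): raise NotImplementedError
        def init_scheduler(self): pass


    # instantiation only works once every @abstractmethod has an implementation
    runner = ToyRunner({"save_video_path": "./out.mp4", "fps": 8})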
lightx2v/models/runners/default_runner.py  +62 −142

-import asyncio
 import gc
-import aiohttp
 import requests
 from requests.exceptions import RequestException
 import torch
@@ -13,12 +11,14 @@ from lightx2v.utils.generate_task_id import generate_task_id
 from lightx2v.utils.envs import *
 from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
 from loguru import logger
+from .base_runner import BaseRunner


-class DefaultRunner:
+class DefaultRunner(BaseRunner):
     def __init__(self, config):
-        self.config = config
+        super().__init__(config)
         self.has_prompt_enhancer = False
+        self.progress_callback = None
         if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
             self.has_prompt_enhancer = True
             if not self.check_sub_servers("prompt_enhancer"):
@@ -30,33 +30,14 @@ class DefaultRunner:
     def init_modules(self):
         logger.info("Initializing runner modules...")
-        if self.config["mode"] == "split_server":
-            self.tensor_transporter = TensorTransporter()
-            self.image_transporter = ImageTransporter()
-            if not self.check_sub_servers("dit"):
-                raise ValueError("No dit server available")
-            if not self.check_sub_servers("text_encoders"):
-                raise ValueError("No text encoder server available")
-            if self.config["task"] == "i2v":
-                if not self.check_sub_servers("image_encoder"):
-                    raise ValueError("No image encoder server available")
-                if not self.check_sub_servers("vae_model"):
-                    raise ValueError("No vae server available")
-            self.run_dit = self.run_dit_server
-            self.run_vae_decoder = self.run_vae_decoder_server
-            if self.config["task"] == "i2v":
-                self.run_input_encoder = self.run_input_encoder_server_i2v
-            else:
-                self.run_input_encoder = self.run_input_encoder_server_t2v
-        else:
-            if not self.config.get("lazy_load", False):
-                self.load_model()
-            self.run_dit = self.run_dit_local
-            self.run_vae_decoder = self.run_vae_decoder_local
-            if self.config["task"] == "i2v":
-                self.run_input_encoder = self.run_input_encoder_local_i2v
-            else:
-                self.run_input_encoder = self.run_input_encoder_local_t2v
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.load_model()
+        self.run_dit = self._run_dit_local
+        self.run_vae_decoder = self._run_vae_decoder_local
+        if self.config["task"] == "i2v":
+            self.run_input_encoder = self._run_input_encoder_local_i2v
+        else:
+            self.run_input_encoder = self._run_input_encoder_local_t2v

     def set_init_device(self):
         if self.config["parallel_attn_type"]:
@@ -110,9 +91,13 @@ class DefaultRunner:
         # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
         # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

+    def set_progress_callback(self, callback):
+        self.progress_callback = callback
+
     def run(self):
-        for step_index in range(self.model.scheduler.infer_steps):
-            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+        total_steps = self.model.scheduler.infer_steps
+        for step_index in range(total_steps):
+            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

             with ProfilingContext4Debug("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)
@@ -123,11 +108,14 @@ class DefaultRunner:
             with ProfilingContext4Debug("step_post"):
                 self.model.scheduler.step_post()

+            if self.progress_callback:
+                self.progress_callback(step_index + 1, total_steps)
+
         return self.model.scheduler.latents, self.model.scheduler.generator

-    async def run_step(self, step_index=0):
+    def run_step(self, step_index=0):
         self.init_scheduler()
-        await self.run_input_encoder()
+        self.inputs = self.run_input_encoder()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
         self.model.scheduler.step_pre(step_index=step_index)
         self.model.infer(self.inputs)
@@ -136,14 +124,19 @@ class DefaultRunner:
     def end_run(self):
         self.model.scheduler.clear()
         del self.inputs, self.model.scheduler
-        if self.config.get("lazy_load", False):
-            self.model.transformer_infer.weights_stream_mgr.clear()
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
+                self.model.transformer_infer.weights_stream_mgr.clear()
+            if hasattr(self.model.transformer_weights, "clear"):
+                self.model.transformer_weights.clear()
+            self.model.pre_weight.clear()
+            self.model.post_weight.clear()
             del self.model
         torch.cuda.empty_cache()
         gc.collect()

     @ProfilingContext("Run Encoders")
-    async def run_input_encoder_local_i2v(self):
+    def _run_input_encoder_local_i2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         img = Image.open(self.config["image_path"]).convert("RGB")
         clip_encoder_out = self.run_image_encoder(img)
@@ -154,16 +147,19 @@ class DefaultRunner:
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

     @ProfilingContext("Run Encoders")
-    async def run_input_encoder_local_t2v(self):
+    def _run_input_encoder_local_t2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         text_encoder_output = self.run_text_encoder(prompt, None)
         torch.cuda.empty_cache()
         gc.collect()
-        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+        return {
+            "text_encoder_output": text_encoder_output,
+            "image_encoder_output": None,
+        }

     @ProfilingContext("Run DiT")
-    async def run_dit_local(self, kwargs):
-        if self.config.get("lazy_load", False):
+    def _run_dit_local(self, kwargs):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
         self.init_scheduler()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
@@ -172,11 +168,11 @@ class DefaultRunner:
         return latents, generator

     @ProfilingContext("Run VAE Decoder")
-    async def run_vae_decoder_local(self, latents, generator):
-        if self.config.get("lazy_load", False):
+    def _run_vae_decoder_local(self, latents, generator):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
         images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -187,115 +183,39 @@ class DefaultRunner:
         if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
             self.save_video_func(images)

-    async def post_task(self, task_type, urls, message, device="cuda"):
-        while True:
-            for url in urls:
-                async with aiohttp.ClientSession() as session:
-                    async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
-                        status = await response.json()
-                    if status["service_status"] == "idle":
-                        async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
-                            result = await response.json()
-                        if result["kwargs"] is not None:
-                            for k, v in result["kwargs"].items():
-                                setattr(self.config, k, v)
-                        return self.tensor_transporter.load_tensor(result["output"], device)
-            await asyncio.sleep(0.1)
-
     def post_prompt_enhancer(self):
         while True:
             for url in self.config["sub_servers"]["prompt_enhancer"]:
                 response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                 if response["service_status"] == "idle":
-                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
+                    response = requests.post(
+                        f"{url}/v1/local/prompt_enhancer/generate",
+                        json={
+                            "task_id": generate_task_id(),
+                            "prompt": self.config["prompt"],
+                        },
+                    )
                     enhanced_prompt = response.json()["output"]
                     logger.info(f"Enhanced prompt: {enhanced_prompt}")
                     return enhanced_prompt

-    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
-        tasks = []
-        img_byte = self.image_transporter.prepare_image(img)
-        tasks.append(asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")))
-        tasks.append(asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")))
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="text_encoders",
-                    urls=self.config["sub_servers"]["text_encoders"],
-                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
-                    device="cuda",
-                )
-            )
-        )
-        results = await asyncio.gather(*tasks)
-        # clip_encoder, vae_encoder, text_encoders
-        return results[0], results[1], results[2]
-
-    async def post_encoders_t2v(self, prompt, n_prompt=None):
-        tasks = []
-        tasks.append(
-            asyncio.create_task(
-                self.post_task(
-                    task_type="text_encoders",
-                    urls=self.config["sub_servers"]["text_encoders"],
-                    message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
-                    device="cuda",
-                )
-            )
-        )
-        results = await asyncio.gather(*tasks)
-        # text_encoders
-        return results[0]
-
-    async def run_input_encoder_server_i2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        n_prompt = self.config.get("negative_prompt", "")
-        img = Image.open(self.config["image_path"]).convert("RGB")
-        clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
-        torch.cuda.empty_cache()
-        gc.collect()
-        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
-
-    async def run_input_encoder_server_t2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        n_prompt = self.config.get("negative_prompt", "")
-        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
-        torch.cuda.empty_cache()
-        gc.collect()
-        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
-
-    async def run_dit_server(self, kwargs):
-        if self.inputs.get("image_encoder_output", None) is not None:
-            self.inputs["image_encoder_output"].pop("img", None)
-        dit_output = await self.post_task(
-            task_type="dit",
-            urls=self.config["sub_servers"]["dit"],
-            message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
-            device="cuda",
-        )
-        return dit_output, None
-
-    async def run_vae_decoder_server(self, latents, generator):
-        images = await self.post_task(
-            task_type="vae_model/decoder",
-            urls=self.config["sub_servers"]["vae_model"],
-            message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
-            device="cpu",
-        )
-        return images
-
-    async def run_pipeline(self):
+    def run_pipeline(self, save_video=True):
         if self.config["use_prompt_enhancer"]:
             self.config["prompt_enhanced"] = self.post_prompt_enhancer()
-        self.inputs = await self.run_input_encoder()
+        self.inputs = self.run_input_encoder()
         kwargs = self.set_target_shape()
-        latents, generator = await self.run_dit(kwargs)
-        images = await self.run_vae_decoder(latents, generator)
-        self.save_video(images)
-        del latents, generator, images
+        latents, generator = self.run_dit(kwargs)
+        images = self.run_vae_decoder(latents, generator)
+        if save_video:
+            self.save_video(images)
+        del latents, generator
         torch.cuda.empty_cache()
         gc.collect()
+        return images
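Note: run_pipeline is now synchronous, returns the decoded frames, and the per-step hook added in run() reports progress. A usage sketch; `runner` stands for any already-initialized DefaultRunner subclass (setup omitted):

    def on_progress(step, total):
        print(f"denoising step {step}/{total}")  # invoked after every scheduler step

    runner.set_progress_callback(on_progress)
    images = runner.run_pipeline(save_video=False)  # skip writing the mp4, keep the frames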
lightx2v/models/runners/wan/wan_audio_runner.py  +7 −4

@@ -329,12 +329,15 @@ class WanAudioRunner(WanRunner):
     def load_transformer(self):
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
-            lora_name = lora_wrapper.load_lora(self.config.lora_path)
-            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-            logger.info(f"Loaded LoRA: {lora_name}")
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
+                lora_name = lora_wrapper.load_lora(lora_path)
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         return base_model
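Note: this commit replaces the single lora_path / strength_model pair with a lora_configs list across all runners, one entry per LoRA with its own strength. A sketch of the new config shape; the paths are placeholders:

    config["lora_configs"] = [
        {"path": "/loras/style_a.safetensors", "strength": 1.0},
        {"path": "/loras/style_b.safetensors"},  # strength defaults to 1.0 via .get("strength", 1.0)
    ]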
lightx2v/models/runners/wan/wan_causvid_runner.py  +11 −9

@@ -24,24 +24,26 @@ import torch.distributed as dist
 class WanCausVidRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        self.num_frame_per_block = self.model.config.num_frame_per_block
-        self.num_frames = self.model.config.num_frames
-        self.frame_seq_length = self.model.config.frame_seq_length
-        self.infer_blocks = self.model.config.num_blocks
-        self.num_fragments = self.model.config.num_fragments
+        self.num_frame_per_block = self.config.num_frame_per_block
+        self.num_frames = self.config.num_frames
+        self.frame_seq_length = self.config.frame_seq_length
+        self.infer_blocks = self.config.num_blocks
+        self.num_fragments = self.config.num_fragments

     def load_transformer(self):
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             model = WanModel(
                 self.config.model_path,
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
             model = WanCausVidModel(self.config.model_path, self.config, self.init_device)
         return model
lightx2v/models/runners/wan/wan_distill_runner.py  +6 −4

@@ -24,17 +24,19 @@ class WanDistillRunner(WanRunner):
         super().__init__(config)

     def load_transformer(self):
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             model = WanModel(
                 self.config.model_path,
                 self.config,
                 self.init_device,
             )
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         else:
             model = WanDistillModel(self.config.model_path, self.config, self.init_device)
         return model
lightx2v/models/runners/wan/wan_runner.py  +76 −29

@@ -7,6 +7,9 @@ from PIL import Image
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
+    WanScheduler4ChangingResolution,
+)
 from lightx2v.models.schedulers.wan.feature_caching.scheduler import (
     WanSchedulerTeaCaching,
     WanSchedulerTaylorCaching,
@@ -35,18 +38,36 @@ class WanRunner(DefaultRunner):
             self.config,
             self.init_device,
         )
-        if self.config.lora_path:
+        if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(model)
-            for lora_path in self.config.lora_path:
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
-                logger.info(f"Loaded LoRA: {lora_name}")
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
         return model

     def load_image_encoder(self):
         image_encoder = None
         if self.config.task == "i2v":
+            # quant_config
+            clip_quantized = self.config.get("clip_quantized", False)
+            if clip_quantized:
+                clip_quant_scheme = self.config.get("clip_quant_scheme", None)
+                assert clip_quant_scheme is not None
+                clip_quantized_ckpt = self.config.get(
+                    "clip_quantized_ckpt",
+                    os.path.join(
+                        os.path.join(self.config.model_path, clip_quant_scheme),
+                        f"clip-{clip_quant_scheme}.pth",
+                    ),
+                )
+            else:
+                clip_quantized_ckpt = None
+                clip_quant_scheme = None
             image_encoder = CLIPModel(
                 dtype=torch.float16,
                 device=self.init_device,
@@ -54,25 +75,48 @@ class WanRunner(DefaultRunner):
                     self.config.model_path,
                     "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                 ),
-                clip_quantized=self.config.get("clip_quantized", False),
-                clip_quantized_ckpt=self.config.get("clip_quantized_ckpt", None),
-                quant_scheme=self.config.get("clip_quant_scheme", None),
+                clip_quantized=clip_quantized,
+                clip_quantized_ckpt=clip_quantized_ckpt,
+                quant_scheme=clip_quant_scheme,
             )
         return image_encoder

     def load_text_encoder(self):
+        # offload config
+        t5_offload = self.config.get("t5_cpu_offload", False)
+        if t5_offload:
+            t5_device = torch.device("cpu")
+        else:
+            t5_device = torch.device("cuda")
+        # quant_config
+        t5_quantized = self.config.get("t5_quantized", False)
+        if t5_quantized:
+            t5_quant_scheme = self.config.get("t5_quant_scheme", None)
+            assert t5_quant_scheme is not None
+            t5_quantized_ckpt = self.config.get(
+                "t5_quantized_ckpt",
+                os.path.join(
+                    os.path.join(self.config.model_path, t5_quant_scheme),
+                    f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth",
+                ),
+            )
+        else:
+            t5_quant_scheme = None
+            t5_quantized_ckpt = None
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
-            device=self.init_device,
+            device=t5_device,
             checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
             tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
             shard_fn=None,
-            cpu_offload=self.config.cpu_offload,
+            cpu_offload=t5_offload,
             offload_granularity=self.config.get("t5_offload_granularity", "model"),
-            t5_quantized=self.config.get("t5_quantized", False),
-            t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
-            quant_scheme=self.config.get("t5_quant_scheme", None),
+            t5_quantized=t5_quantized,
+            t5_quantized_ckpt=t5_quantized_ckpt,
+            quant_scheme=t5_quant_scheme,
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -114,28 +158,31 @@ class WanRunner(DefaultRunner):
         return vae_encoder, vae_decoder

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
-            scheduler = WanScheduler(self.config)
-        elif self.config.feature_caching == "Tea":
-            scheduler = WanSchedulerTeaCaching(self.config)
-        elif self.config.feature_caching == "TaylorSeer":
-            scheduler = WanSchedulerTaylorCaching(self.config)
-        elif self.config.feature_caching == "Ada":
-            scheduler = WanSchedulerAdaCaching(self.config)
-        elif self.config.feature_caching == "Custom":
-            scheduler = WanSchedulerCustomCaching(self.config)
-        else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+        if self.config.get("changing_resolution", False):
+            scheduler = WanScheduler4ChangingResolution(self.config)
+        else:
+            if self.config.feature_caching == "NoCaching":
+                scheduler = WanScheduler(self.config)
+            elif self.config.feature_caching == "Tea":
+                scheduler = WanSchedulerTeaCaching(self.config)
+            elif self.config.feature_caching == "TaylorSeer":
+                scheduler = WanSchedulerTaylorCaching(self.config)
+            elif self.config.feature_caching == "Ada":
+                scheduler = WanSchedulerAdaCaching(self.config)
+            elif self.config.feature_caching == "Custom":
+                scheduler = WanSchedulerCustomCaching(self.config)
+            else:
+                raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
         self.model.set_scheduler(scheduler)

     def run_text_encoder(self, text, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
         text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
         context = self.text_encoders[0].infer([text])
         context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.text_encoders[0]
             torch.cuda.empty_cache()
             gc.collect()
@@ -144,11 +191,11 @@ class WanRunner(DefaultRunner):
         return text_encoder_output

     def run_image_encoder(self, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
         clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -179,7 +226,7 @@ class WanRunner(DefaultRunner):
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_encoder = self.load_vae_encoder()
         vae_encode_out = self.vae_encoder.encode(
             [
@@ -193,7 +240,7 @@ class WanRunner(DefaultRunner):
             ],
             self.config,
         )[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
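Note: T5 placement and quantization are now driven by dedicated keys instead of the global cpu_offload flag. A sketch of the keys read by the new load_text_encoder, with illustrative values:

    config["t5_cpu_offload"] = True             # put T5 on cpu; False keeps it on cuda
    config["t5_offload_granularity"] = "model"  # default used by the diff
    config["t5_quantized"] = False              # if True, t5_quant_scheme must also be set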
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py  new file (0 → 100755)  +66 −0

import torch

from lightx2v.models.schedulers.wan.scheduler import WanScheduler


class WanScheduler4ChangingResolution(WanScheduler):
    def __init__(self, config):
        super().__init__(config)
        self.resolution_rate = config.get("resolution_rate", 0.75)
        self.changing_resolution_steps = config.get("changing_resolution_steps", config.infer_steps // 2)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = torch.randn(
            target_shape[0],
            target_shape[1],
            int(target_shape[2] * self.resolution_rate) // 2 * 2,
            int(target_shape[3] * self.resolution_rate) // 2 * 2,
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )
        self.noise_original_resolution = torch.randn(
            target_shape[0],
            target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )

    def step_post(self):
        if self.step_index == self.changing_resolution_steps:
            self.step_post_upsample()
        else:
            super().step_post()

    def step_post_upsample(self):
        # 1. denoised sample to clean noise
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma_t = self.sigmas[self.step_index]
        x0_pred = sample - sigma_t * model_output
        denoised_sample = x0_pred.to(sample.dtype)
        # 2. upsample clean noise to target shape
        denoised_sample_5d = denoised_sample.unsqueeze(0)  # (C,T,H,W) -> (1,C,T,H,W)
        clean_noise = torch.nn.functional.interpolate(
            denoised_sample_5d,
            size=(self.config.target_shape[1], self.config.target_shape[2], self.config.target_shape[3]),
            mode="trilinear",
        )
        clean_noise = clean_noise.squeeze(0)  # (1,C,T,H,W) -> (C,T,H,W)
        # 3. add noise to clean noise
        noisy_sample = self.add_noise(clean_noise, self.noise_original_resolution, self.timesteps[self.step_index + 1])
        # 4. update latents
        self.latents = noisy_sample
        # self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37]  # maybe not needed
        # 5. update timesteps using shift + 2 (more aggressive denoising)
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2)

    def add_noise(self, original_samples, noise, timesteps):
        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples
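Note: the scheduler above denoises at reduced resolution until changing_resolution_steps, recovers a clean estimate x0 = x_t - sigma_t * eps_hat, trilinearly upsamples it to the full latent shape, re-noises it with the pre-drawn full-resolution noise, and finishes with shift raised by 2. A sketch of the config that routes WanRunner.init_scheduler to this class; the values shown are the defaults from __init__:

    config["changing_resolution"] = True
    config["resolution_rate"] = 0.75                                  # low-res phase at 75% of target H/W
    config["changing_resolution_steps"] = config["infer_steps"] // 2  # switch point (default)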
lightx2v/models/schedulers/wan/step_distill/scheduler.py
View file @
e08c4f90
...
@@ -9,9 +9,13 @@ class WanStepDistillScheduler(WanScheduler):
...
@@ -9,9 +9,13 @@ class WanStepDistillScheduler(WanScheduler):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
denoising_step_list
=
config
.
denoising_step_list
self
.
denoising_step_list
=
config
.
denoising_step_list
self
.
infer_steps
=
self
.
config
.
infer_steps
self
.
infer_steps
=
len
(
self
.
denoising_step_list
)
self
.
sample_shift
=
self
.
config
.
sample_shift
self
.
sample_shift
=
self
.
config
.
sample_shift
self
.
num_train_timesteps
=
1000
self
.
sigma_max
=
1.0
self
.
sigma_min
=
0.0
def
prepare
(
self
,
image_encoder_output
):
def
prepare
(
self
,
image_encoder_output
):
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
)
self
.
generator
=
torch
.
Generator
(
device
=
self
.
device
)
self
.
generator
.
manual_seed
(
self
.
config
.
seed
)
self
.
generator
.
manual_seed
(
self
.
config
.
seed
)
...
@@ -23,46 +27,30 @@ class WanStepDistillScheduler(WanScheduler):
...
@@ -23,46 +27,30 @@ class WanStepDistillScheduler(WanScheduler):
elif
self
.
config
.
task
in
[
"i2v"
]:
elif
self
.
config
.
task
in
[
"i2v"
]:
self
.
seq_len
=
self
.
config
.
lat_h
*
self
.
config
.
lat_w
//
(
self
.
config
.
patch_size
[
1
]
*
self
.
config
.
patch_size
[
2
])
*
self
.
config
.
target_shape
[
1
]
self
.
seq_len
=
self
.
config
.
lat_h
*
self
.
config
.
lat_w
//
(
self
.
config
.
patch_size
[
1
]
*
self
.
config
.
patch_size
[
2
])
*
self
.
config
.
target_shape
[
1
]
alphas
=
np
.
linspace
(
1
,
1
/
self
.
num_train_timesteps
,
self
.
num_train_timesteps
)[::
-
1
].
copy
()
self
.
set_denoising_timesteps
(
device
=
self
.
device
)
sigmas
=
1.0
-
alphas
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
dtype
=
torch
.
float32
)
sigmas
=
self
.
shift
*
sigmas
/
(
1
+
(
self
.
shift
-
1
)
*
sigmas
)
self
.
sigmas
=
sigmas
self
.
timesteps
=
sigmas
*
self
.
num_train_timesteps
self
.
model_outputs
=
[
None
]
*
self
.
solver_order
self
.
timestep_list
=
[
None
]
*
self
.
solver_order
self
.
last_sample
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
self
.
sigma_min
=
self
.
sigmas
[
-
1
].
item
()
self
.
sigma_max
=
self
.
sigmas
[
0
].
item
()
if
len
(
self
.
denoising_step_list
)
==
self
.
infer_steps
:
# 如果denoising_step_list有效既使用
self
.
set_denoising_timesteps
(
device
=
self
.
device
)
else
:
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
device
,
shift
=
self
.
sample_shift
)
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_denoising_timesteps
(
self
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
s
elf
.
timesteps
=
torch
.
tensor
(
self
.
denoising_step_list
,
device
=
device
,
dtype
=
torch
.
int64
)
s
igma_start
=
self
.
sigma_min
+
(
self
.
sigma_max
-
self
.
sigma_min
)
self
.
sigmas
=
torch
.
cat
([
self
.
timesteps
/
self
.
num_train_timesteps
,
torch
.
tensor
([
0.0
],
device
=
device
)])
self
.
sigmas
=
torch
.
linspace
(
sigma_start
,
self
.
sigma_min
,
self
.
num_train_timesteps
+
1
)[:
-
1
]
self
.
sigmas
=
self
.
s
igmas
.
to
(
"cpu"
)
self
.
sigmas
=
self
.
s
ample_shift
*
self
.
sigmas
/
(
1
+
(
self
.
sample_shift
-
1
)
*
self
.
sigmas
)
self
.
infer_
steps
=
len
(
self
.
timesteps
)
self
.
time
steps
=
self
.
sigmas
*
self
.
num_train_
timesteps
self
.
model_outputs
=
[
self
.
denoising_step_index
=
[
self
.
num_train_timesteps
-
x
for
x
in
self
.
denoising_step_list
]
None
,
self
.
timesteps
=
self
.
timesteps
[
self
.
denoising_step_index
].
to
(
device
)
]
*
self
.
solver_order
self
.
sigmas
=
self
.
sigmas
[
self
.
denoising_step_index
].
to
(
"cpu"
)
self
.
lower_order_nums
=
0
self
.
last_sample
=
None
self
.
_begin_index
=
None
def
reset
(
self
):
def
reset
(
self
):
self
.
model_outputs
=
[
None
]
*
self
.
solver_order
self
.
timestep_list
=
[
None
]
*
self
.
solver_order
self
.
last_sample
=
None
self
.
noise_pred
=
None
self
.
this_order
=
None
self
.
lower_order_nums
=
0
self
.
prepare_latents
(
self
.
config
.
target_shape
,
dtype
=
torch
.
float32
)
self
.
prepare_latents
(
self
.
config
.
target_shape
,
dtype
=
torch
.
float32
)
+    def add_noise(self, original_samples, noise, sigma):
+        sample = (1 - sigma) * original_samples + sigma * noise
+        return sample.type_as(noise)
+
+    def step_post(self):
+        flow_pred = self.noise_pred.to(torch.float32)
+        sigma = self.sigmas[self.step_index].item()
+        noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
+        if self.step_index < self.infer_steps - 1:
+            sigma = self.sigmas[self.step_index + 1].item()
+            noisy_image_or_video = self.add_noise(noisy_image_or_video, torch.randn_like(noisy_image_or_video), self.sigmas[self.step_index + 1].item())
+        self.latents = noisy_image_or_video.to(self.latents.dtype)
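step_post implements a few-step rectified-flow update: the velocity prediction v yields a clean estimate x0 = x_t - sigma * v, which is then re-noised to the next sigma level. A toy loop under assumed shapes and sigma values (the model call is a stand-in, not the repository's runtime state):

    import torch

    sigmas = [0.9998, 0.76, 0.50, 0.25]    # assumed per-step sigma levels
    latents = torch.randn(1, 16, 4, 8, 8)  # assumed latent shape

    for i, sigma in enumerate(sigmas):
        flow_pred = torch.randn_like(latents)  # stand-in for the DiT's velocity output
        x0 = latents - sigma * flow_pred       # clean estimate from the flow prediction
        if i < len(sigmas) - 1:
            next_sigma = sigmas[i + 1]
            # re-noise the clean estimate to the next level: (1 - s) * x0 + s * noise
            latents = (1 - next_sigma) * x0 + next_sigma * torch.randn_like(x0)
        else:
            latents = x0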
lightx2v/models/video_encoders/hf/wan/vae.py View file @ e08c4f90
...
@@ -868,7 +868,7 @@ class WanVAE:
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        if args.cpu_offload:
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cuda()
         if self.use_tiling:
...
@@ -876,7 +876,7 @@ class WanVAE:
         else:
             out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
-        if args.cpu_offload:
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cpu()
         return out
...
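The hasattr guard prevents an AttributeError when the args namespace predates the cpu_offload flag. A small sketch of an equivalent, more compact pattern using getattr with a default (an alternative, not a repository patch):

    from types import SimpleNamespace

    def should_offload(args) -> bool:
        # getattr with a default collapses the hasattr check and the access
        return bool(getattr(args, "cpu_offload", False))

    print(should_offload(SimpleNamespace()))                  # False: attribute absent
    print(should_offload(SimpleNamespace(cpu_offload=True)))  # True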
lightx2v/server/service.py View file @ e08c4f90
...
@@ -90,7 +90,6 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, ar
     # Initialize configuration and model
     config = set_config(args)
-    config["mode"] = "server"
     logger.info(f"Rank {rank} config: {config}")
     runner = init_runner(config)
...
@@ -186,6 +185,12 @@ class DistributedInferenceService:
         self.is_running = False

     def start_distributed_inference(self, args) -> bool:
+        if hasattr(args, "lora_path") and args.lora_path:
+            args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
+            delattr(args, "lora_path")
+            if hasattr(args, "lora_strength"):
+                delattr(args, "lora_strength")
         self.args = args
         if self.is_running:
             logger.warning("Distributed inference service is already running")
...
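The added block is a back-compat shim: a legacy lora_path/lora_strength pair is folded into a single lora_configs entry before the service stores the args. A self-contained demo of the same logic (the file path and strength are hypothetical):

    from types import SimpleNamespace

    args = SimpleNamespace(lora_path="loras/style.safetensors", lora_strength=0.8)

    if hasattr(args, "lora_path") and args.lora_path:
        args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
        delattr(args, "lora_path")
        if hasattr(args, "lora_strength"):
            delattr(args, "lora_strength")

    print(args.lora_configs)  # [{'path': 'loras/style.safetensors', 'strength': 0.8}]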
lightx2v/utils/async_io.py 0 → 100644 View file @ e08c4f90
import aiofiles
import asyncio
from PIL import Image
import io
from typing import Union
from pathlib import Path
from loguru import logger


async def load_image_async(path: Union[str, Path]) -> Image.Image:
    try:
        async with aiofiles.open(path, "rb") as f:
            data = await f.read()
        return await asyncio.to_thread(lambda: Image.open(io.BytesIO(data)).convert("RGB"))
    except Exception as e:
        logger.error(f"Failed to load image from {path}: {e}")
        raise


async def save_video_async(video_path: Union[str, Path], video_data: bytes):
    try:
        video_path = Path(video_path)
        video_path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(video_path, "wb") as f:
            await f.write(video_data)
        logger.info(f"Video saved to {video_path}")
    except Exception as e:
        logger.error(f"Failed to save video to {video_path}: {e}")
        raise


async def read_text_async(path: Union[str, Path], encoding: str = "utf-8") -> str:
    try:
        async with aiofiles.open(path, "r", encoding=encoding) as f:
            return await f.read()
    except Exception as e:
        logger.error(f"Failed to read text from {path}: {e}")
        raise


async def write_text_async(path: Union[str, Path], content: str, encoding: str = "utf-8"):
    try:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(path, "w", encoding=encoding) as f:
            await f.write(content)
        logger.info(f"Text written to {path}")
    except Exception as e:
        logger.error(f"Failed to write text to {path}: {e}")
        raise


async def exists_async(path: Union[str, Path]) -> bool:
    return await asyncio.to_thread(lambda: Path(path).exists())


async def read_bytes_async(path: Union[str, Path]) -> bytes:
    try:
        async with aiofiles.open(path, "rb") as f:
            return await f.read()
    except Exception as e:
        logger.error(f"Failed to read bytes from {path}: {e}")
        raise


async def write_bytes_async(path: Union[str, Path], data: bytes):
    try:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(path, "wb") as f:
            await f.write(data)
        logger.debug(f"Bytes written to {path}")
    except Exception as e:
        logger.error(f"Failed to write bytes to {path}: {e}")
        raise
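These helpers keep blocking file I/O off the event loop: aiofiles handles the reads and writes, and the CPU-bound PIL decode is pushed to a worker thread via asyncio.to_thread. A possible usage sketch (the paths are illustrative; assumes the aiofiles dependency is installed):

    import asyncio
    from lightx2v.utils.async_io import load_image_async, write_text_async

    async def main():
        # Run an image load and a text write concurrently
        image, _ = await asyncio.gather(
            load_image_async("assets/img_lightx2v.png"),       # hypothetical path
            write_text_async("save_results/prompt.txt", "a cat"),
        )
        print(image.size)

    asyncio.run(main())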
lightx2v/utils/quant_utils.py View file @ e08c4f90
 import torch
-from qtorch.quant import float_quantize
 from loguru import logger
+
+try:
+    from qtorch.quant import float_quantize
+except Exception:
+    logger.warning("qtorch not found, please install qtorch (pip install qtorch).")
+    float_quantize = None


 class BaseQuantizer(object):
     def __init__(self, bit, symmetric, granularity, **kwargs):
...
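Moving the import into a try/except makes qtorch an optional dependency: the module still imports, and only code paths that actually quantize need the library. A sketch of how a caller can fail fast on the sentinel (the quantize_fp8 helper is illustrative, not from the repo):

    from loguru import logger

    try:
        from qtorch.quant import float_quantize
    except Exception:
        logger.warning("qtorch not found; float quantization will be unavailable")
        float_quantize = None

    def quantize_fp8(tensor):
        # Illustrative helper: raise a clear error only when quantization is used
        if float_quantize is None:
            raise ImportError("qtorch is required for FP8 quantization (pip install qtorch)")
        # e4m3: 4 exponent bits, 3 mantissa bits
        return float_quantize(tensor, exp=4, man=3, rounding="nearest")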
lightx2v/utils/set_config.py View file @ e08c4f90
...
@@ -17,8 +17,7 @@ def get_default_config():
         "teacache_thresh": 0.26,
         "use_ret_steps": False,
         "use_bfloat16": True,
-        "lora_path": None,
-        "strength_model": 1.0,
+        "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
         "mm_config": {},
         "use_prompt_enhancer": False,
     }
...
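The single lora_path/strength_model pair becomes one list-valued key, so several LoRAs can be stacked with independent strengths. A sketch of the expected shape (paths are hypothetical):

    config = {
        "lora_configs": [
            {"path": "loras/detail_boost.safetensors", "strength": 1.0},
            {"path": "loras/motion_fix.safetensors", "strength": 0.5},
        ]
    }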
lightx2v/utils/utils.py View file @ e08c4f90
...
@@ -58,6 +58,14 @@ def cache_video(
     value_range=(-1, 1),
     retry=5,
 ):
+    save_dir = os.path.dirname(save_file)
+    try:
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir, exist_ok=True)
+    except Exception as e:
+        logger.error(f"Failed to create directory: {save_dir}, error: {e}")
+        return None
+
     cache_file = save_file
     # save to cache
...
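The new lines make cache_video create the destination directory up front and return None instead of raising when that fails. A minimal standalone sketch of the same guard (ensure_parent_dir is a hypothetical name, not repository code):

    import os
    from loguru import logger

    def ensure_parent_dir(save_file: str) -> bool:
        save_dir = os.path.dirname(save_file)
        try:
            if save_dir and not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)
            return True
        except Exception as e:
            logger.error(f"Failed to create directory: {save_dir}, error: {e}")
            return False

    print(ensure_parent_dir("save_results/output_lightx2v_wan_t2v.mp4"))  # True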
lightx2v_kernel/CMakeLists.txt View file @ e08c4f90
...
@@ -94,6 +94,7 @@ set(SOURCES
     "csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
     "csrc/gemm/nvfp4_quant_kernels_sm120.cu"
     "csrc/gemm/mxfp8_quant_kernels_sm120.cu"
+    "csrc/gemm/mxfp6_quant_kernels_sm120.cu"
     "csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu"
     "csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu"
     "csrc/common_extension.cc"
...