xuwx1 / LightX2V / Commits

Commit 92539ed8
Authored Jul 12, 2025 by gushiqiao
Parent commit: 8e941d39
Changes: 25

Update gradio and offload
Showing 5 changed files with 93 additions and 22 deletions (+93, -22):
lightx2v/models/input_encoders/hf/q_linear.py (+0, -2)
lightx2v/models/input_encoders/hf/t5/model.py (+64, -6)
lightx2v/models/networks/wan/weights/transformer_weights.py (+5, -0)
lightx2v/models/runners/default_runner.py (+11, -6)
lightx2v/models/runners/wan/wan_runner.py (+13, -8)
lightx2v/models/input_encoders/hf/q_linear.py

@@ -59,8 +59,6 @@ class QuantLinearFp8(nn.Module):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
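For context on the buffer layout above: the quantized linear stores its weight as float8_e4m3fn with one float32 scale per output row. A minimal dequantization sketch (my own illustration, not part of this commit; requires a PyTorch build with float8 support):

import torch

def dequantize_rowwise_fp8(weight_fp8: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    # weight_fp8: (out_features, in_features), dtype=torch.float8_e4m3fn
    # weight_scale: (out_features, 1), dtype=torch.float32
    # Upcast the FP8 values, then rescale each output row by its stored scale.
    return weight_fp8.to(torch.float32) * weight_scale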
lightx2v/models/input_encoders/hf/t5/model.py
@@ -3,6 +3,8 @@
import logging
import math
import os
from six import b
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -27,6 +29,14 @@ def fp16_clamp(x):
    return x


def optimize_memory_usage():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    import gc
    gc.collect()


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
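The new optimize_memory_usage helper returns cached blocks to the CUDA allocator and runs Python garbage collection; the later hunks call it right after a module is moved off the GPU. A self-contained usage sketch (illustrative; the Linear layer is a stand-in for a real submodule):

import gc
import torch
import torch.nn as nn

def optimize_memory_usage():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

layer = nn.Linear(1024, 1024)
if torch.cuda.is_available():
    layer = layer.cuda()
    _ = layer(torch.randn(8, 1024, device="cuda"))
    layer = layer.cpu()        # offload the weights
optimize_memory_usage()        # then release the cached GPU memory they occupied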
@@ -114,10 +124,14 @@ class T5Attention(nn.Module):
        # compute attention (T5 does not use scaling)
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        if hasattr(self, "cpu_offload") and self.cpu_offload:
            del attn_bias
        attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
        x = torch.einsum("bnij,bjnc->binc", attn, v)

        # output
        if hasattr(self, "cpu_offload") and self.cpu_offload:
            del attn
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
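As a reference for the two einsum contractions above, q/k/v are laid out as (batch, sequence, heads, head_dim) and the scores as (batch, heads, query, key). A shape-only sketch with made-up sizes:

import torch

b, i, j, n, c = 2, 5, 7, 4, 8                  # batch, query len, key len, heads, head dim
q = torch.randn(b, i, n, c)
k = torch.randn(b, j, n, c)
v = torch.randn(b, j, n, c)

attn = torch.einsum("binc,bjnc->bnij", q, k)   # -> (2, 4, 5, 7) attention scores
out = torch.einsum("bnij,bjnc->binc", attn.softmax(dim=-1), v)  # -> (2, 5, 4, 8)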
@@ -144,7 +158,14 @@ class T5FeedForward(nn.Module):
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        if hasattr(self, "cpu_offload") and self.cpu_offload:
            gate_out = self.gate(x)
            fc1_out = self.fc1(x)
            x = fc1_out * gate_out
            del gate_out, fc1_out
        else:
            x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
@@ -170,8 +191,19 @@ class T5SelfAttention(nn.Module):
    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        if hasattr(self, "cpu_offload") and self.cpu_offload:
            attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
            x = fp16_clamp(x + attn_out)
            del attn_out
            ffn_out = self.ffn(self.norm2(x))
            x = fp16_clamp(x + ffn_out)
            del ffn_out
        else:
            x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
            x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x
@@ -270,6 +302,12 @@ class T5Encoder(nn.Module):
        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
        if cpu_offload:
            for block in self.blocks:
                block.cpu_offload = cpu_offload
                block.attn.cpu_offload = cpu_offload
                block.ffn.cpu_offload = cpu_offload
        self.norm = T5LayerNorm(dim, dtype=dtype)

        # initialize weights
@@ -281,23 +319,32 @@
        x = self.token_embedding(ids)
        if self.cpu_offload:
            self.token_embedding = self.token_embedding.cpu()
            optimize_memory_usage()
        x = self.dropout(x)
        if self.cpu_offload and self.pos_embedding is not None:
            self.pos_embedding = self.pos_embedding.cuda()
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        if self.cpu_offload and self.pos_embedding is not None:
            self.pos_embedding = self.pos_embedding.cpu()
        for block in self.blocks:
        optimize_memory_usage()
        for i, block in enumerate(self.blocks):
            if self.cpu_offload:
                block = block.cuda()
            x = block(x, mask, pos_bias=e)
            if self.cpu_offload:
                block = block.cpu()
                del block
                optimize_memory_usage()
        if self.cpu_offload:
            self.norm = self.norm.cuda()
        x = self.norm(x)
        if self.cpu_offload:
            self.norm = self.norm.cpu()
            optimize_memory_usage()
        x = self.dropout(x)
        return x.to(torch.bfloat16)
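The forward pass above is the block-granularity offload pattern: every T5 block lives on the CPU and is moved to the GPU only for its own step, then moved back and the cache flushed. A stripped-down sketch of the same idea (illustrative; plain Linear layers stand in for the T5 blocks):

import gc
import torch
import torch.nn as nn

def optimize_memory_usage():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])
x = torch.randn(2, 16, 64)

if torch.cuda.is_available():
    x = x.cuda()
for i, block in enumerate(blocks):
    if torch.cuda.is_available():
        block = block.cuda()       # bring just this block onto the GPU
    x = block(x)
    if torch.cuda.is_available():
        block = block.cpu()        # push it back out before the next block arrives
        del block
        optimize_memory_usage()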
@@ -529,6 +576,10 @@ class T5EncoderModel:
    def to_cuda(self):
        self.model = self.model.to("cuda")

    def optimize_memory(self):
        """Optimize memory usage."""
        optimize_memory_usage()

    def infer(self, texts):
        if self.cpu_offload and self.offload_granularity == "model":
            self.to_cuda()
@@ -537,10 +588,17 @@
        ids = ids.cuda()
        mask = mask.cuda()
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        with torch.no_grad():
            context = self.model(ids, mask)
        if self.cpu_offload and self.offload_granularity == "model":
            self.to_cpu()
            optimize_memory_usage()
        del ids, mask
        if self.cpu_offload:
            optimize_memory_usage()
        return [u[:v] for u, v in zip(context, seq_lens)]
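Two things change in infer: the encoder forward is now wrapped in torch.no_grad so no autograd state is retained, and with "model" offload granularity the whole encoder is pulled onto the GPU before the call and pushed back afterwards. A schematic sketch of that flow (illustrative names, not the repo's exact code):

import torch

def run_encoder(model, ids, mask, cpu_offload: bool, offload_granularity: str):
    if cpu_offload and offload_granularity == "model":
        model = model.to("cuda")                 # load the whole encoder for this call
    with torch.no_grad():                        # inference only, keep no autograd graph
        context = model(ids, mask)
    if cpu_offload and offload_granularity == "model":
        model = model.to("cpu")                  # release it again afterwards
        torch.cuda.empty_cache()
    return context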
lightx2v/models/networks/wan/weights/transformer_weights.py
@@ -24,6 +24,11 @@ class WanTransformerWeights(WeightModule):
        self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
        self.add_module("blocks", self.blocks)

    def clear(self):
        for block in self.blocks:
            for phase in block.compute_phases:
                phase.clear()


class WanTransformerAttentionBlock(WeightModule):
    def __init__(self, block_index, task, mm_type, config):
lightx2v/models/runners/default_runner.py
@@ -49,7 +49,7 @@ class DefaultRunner:
            else:
                self.run_input_encoder = self.run_input_encoder_server_t2v
        else:
            if not self.config.get("lazy_load", False):
            if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
                self.load_model()
            self.run_dit = self.run_dit_local
            self.run_vae_decoder = self.run_vae_decoder_local
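This is the recurring change across the runner files: wherever lazy_load used to gate per-stage loading and unloading, unload_modules now triggers it as well. A small sketch of the gating logic in isolation (config here is a plain dict with illustrative values):

config = {"lazy_load": False, "unload_modules": True}

def defer_module_loading(config) -> bool:
    # Load everything up front only when neither flag is set.
    return config.get("lazy_load", False) or config.get("unload_modules", False)

if not defer_module_loading(config):
    print("load all models once at startup")
else:
    print("load each model right before its stage and free it afterwards")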
@@ -136,8 +136,13 @@ class DefaultRunner:
    def end_run(self):
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False):
            self.model.transformer_infer.weights_stream_mgr.clear()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            self.model.post_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()
@@ -163,7 +168,7 @@ class DefaultRunner:
    @ProfilingContext("Run DiT")
    async def run_dit_local(self, kwargs):
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
@@ -173,10 +178,10 @@ class DefaultRunner:
    @ProfilingContext("Run VAE Decoder")
    async def run_vae_decoder_local(self, latents, generator):
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
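The load / decode / unload sequence above (and in the encoder methods below) could also be expressed as a context manager; a hedged alternative sketch, not code from the repository:

import gc
from contextlib import contextmanager

import torch

@contextmanager
def transient_module(load_fn):
    # Load a module for one use, then drop it and reclaim GPU memory.
    module = load_fn()
    try:
        yield module
    finally:
        del module
        torch.cuda.empty_cache()
        gc.collect()

# Hypothetical usage mirroring run_vae_decoder_local:
#   with transient_module(self.load_vae_decoder) as vae_decoder:
#       images = vae_decoder.decode(latents, generator=generator, config=self.config)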
lightx2v/models/runners/wan/wan_runner.py
@@ -61,14 +61,19 @@ class WanRunner(DefaultRunner):
        return image_encoder

    def load_text_encoder(self):
        t5_offload = self.config.get("t5_cpu_offload", False)
        if t5_offload:
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device("cuda")
        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=self.init_device,
            device=t5_device,
            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
            cpu_offload=self.config.cpu_offload,
            cpu_offload=t5_offload,
            offload_granularity=self.config.get("t5_offload_granularity", "model"),
            t5_quantized=self.config.get("t5_quantized", False),
            t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
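Read together, the keys consumed above imply a configuration along these lines (an illustrative Python dict; the key names come from this hunk, the values are made up):

t5_config = {
    "text_len": 512,                       # illustrative value
    "t5_cpu_offload": True,                # keep the T5 encoder on the CPU between calls
    "t5_offload_granularity": "model",     # offload the whole encoder rather than per block
    "t5_quantized": False,
    "t5_quantized_ckpt": None,
}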
@@ -129,13 +134,13 @@ class WanRunner(DefaultRunner):
        self.model.set_scheduler(scheduler)

    def run_text_encoder(self, text, img):
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = {}
        n_prompt = self.config.get("negative_prompt", "")
        context = self.text_encoders[0].infer([text])
        context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
            torch.cuda.empty_cache()
            gc.collect()
@@ -144,11 +149,11 @@ class WanRunner(DefaultRunner):
        return text_encoder_output

    def run_image_encoder(self, img):
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
            gc.collect()
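The preprocessing line converts a PIL image to a tensor in [0, 1] and then rescales it to [-1, 1] before the visual encoder. A standalone check (illustrative; TF is assumed to be torchvision.transforms.functional, matching how the repo uses it):

import torchvision.transforms.functional as TF
from PIL import Image

img = Image.new("RGB", (8, 8), color=(255, 0, 128))   # dummy image
t = TF.to_tensor(img).sub_(0.5).div_(0.5)              # [0, 1] -> [-1, 1]
print(t.shape, t.min().item(), t.max().item())         # torch.Size([3, 8, 8]) -1.0 1.0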
@@ -179,7 +184,7 @@ class WanRunner(DefaultRunner):
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_encoder = self.load_vae_encoder()
        vae_encode_out = self.vae_encoder.encode([
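The three mask operations repeat the first temporal slice four times, fold the temporal axis into groups of four, and drop the batch dimension. A shape-only sketch, under the assumption that msk starts as (1, frames, lat_h, lat_w):

import torch

lat_h, lat_w, frames = 4, 6, 21                    # illustrative sizes
msk = torch.ones(1, frames, lat_h, lat_w)

msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
print(msk.shape)   # torch.Size([1, 24, 4, 6]) -- the first frame now appears four times
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
print(msk.shape)   # torch.Size([1, 6, 4, 4, 6])
msk = msk.transpose(1, 2)[0]
print(msk.shape)   # torch.Size([4, 6, 4, 6])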
@@ -193,7 +198,7 @@ class WanRunner(DefaultRunner):
            ],
            self.config,
        )[0]
        if self.config.get("lazy_load", False):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_encoder
            torch.cuda.empty_cache()
            gc.collect()