LightX2V commit 38c40fa0

Merge pull request #108 from ModelTC/dev_offload

Update gradio and offload

Authored Jul 12, 2025 by gushiqiao; committed by GitHub on Jul 12, 2025.
Parents: 9518ff04, 6f9bbff6
Showing 6 of the commit's 26 changed files, with 96 additions and 22 deletions (+96 -22):
lightx2v/common/ops/tensor/tensor.py                          +5  -1
lightx2v/models/input_encoders/hf/q_linear.py                 +0  -1
lightx2v/models/input_encoders/hf/t5/model.py                 +62 -6
lightx2v/models/networks/wan/weights/transformer_weights.py   +5  -0
lightx2v/models/runners/default_runner.py                     +11 -6
lightx2v/models/runners/wan/wan_runner.py                     +13 -8
lightx2v/common/ops/tensor/tensor.py

@@ -22,7 +22,11 @@ class DefaultTensor:
         self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)

     def clear(self):
-        del self.tensor
+        attrs = ["tensor", "pinned_tensor"]
+        for attr in attrs:
+            if hasattr(self, attr):
+                delattr(self, attr)
+                setattr(self, attr, None)

     def _calculate_size(self):
         return self.tensor.numel() * self.tensor.element_size()
lightx2v/models/input_encoders/hf/q_linear.py

@@ -61,7 +61,6 @@ class QuantLinearFp8(nn.Module):
         self.out_features = out_features
         self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
         self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
         if bias:
             self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
         else:
lightx2v/models/input_encoders/hf/t5/model.py
@@ -27,6 +27,14 @@ def fp16_clamp(x):
     return x


+def optimize_memory_usage():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    import gc
+    gc.collect()
+
+
 def init_weights(m):
     if isinstance(m, T5LayerNorm):
         nn.init.ones_(m.weight)
@@ -114,10 +122,14 @@ class T5Attention(nn.Module):
         # compute attention (T5 does not use scaling)
         attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            del attn_bias
         attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
         x = torch.einsum("bnij,bjnc->binc", attn, v)

         # output
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            del attn
         x = x.reshape(b, -1, n * c)
         x = self.o(x)
         x = self.dropout(x)
@@ -144,7 +156,14 @@ class T5FeedForward(nn.Module):
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
-        x = self.fc1(x) * self.gate(x)
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            gate_out = self.gate(x)
+            fc1_out = self.fc1(x)
+            x = fc1_out * gate_out
+            del gate_out, fc1_out
+        else:
+            x = self.fc1(x) * self.gate(x)
         x = self.dropout(x)
         x = self.fc2(x)
         x = self.dropout(x)
@@ -170,8 +189,19 @@ class T5SelfAttention(nn.Module):
     def forward(self, x, mask=None, pos_bias=None):
         e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
-        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
-        x = fp16_clamp(x + self.ffn(self.norm2(x)))
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
+            x = fp16_clamp(x + attn_out)
+            del attn_out
+            ffn_out = self.ffn(self.norm2(x))
+            x = fp16_clamp(x + ffn_out)
+            del ffn_out
+        else:
+            x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+            x = fp16_clamp(x + self.ffn(self.norm2(x)))
         return x
@@ -270,6 +300,12 @@ class T5Encoder(nn.Module):
         self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
         self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
+        if cpu_offload:
+            for block in self.blocks:
+                block.cpu_offload = cpu_offload
+                block.attn.cpu_offload = cpu_offload
+                block.ffn.cpu_offload = cpu_offload
         self.norm = T5LayerNorm(dim, dtype=dtype)

         # initialize weights
@@ -281,23 +317,32 @@
         x = self.token_embedding(ids)
         if self.cpu_offload:
             self.token_embedding = self.token_embedding.cpu()
+            optimize_memory_usage()
         x = self.dropout(x)
         if self.cpu_offload and self.pos_embedding is not None:
             self.pos_embedding = self.pos_embedding.cuda()
         e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
         if self.cpu_offload and self.pos_embedding is not None:
             self.pos_embedding = self.pos_embedding.cpu()
-        for block in self.blocks:
+            optimize_memory_usage()
+        for i, block in enumerate(self.blocks):
+            if self.cpu_offload:
+                block = block.cuda()
             x = block(x, mask, pos_bias=e)
+            if self.cpu_offload:
+                block = block.cpu()
+                del block
+                optimize_memory_usage()
         if self.cpu_offload:
             self.norm = self.norm.cuda()
         x = self.norm(x)
         if self.cpu_offload:
             self.norm = self.norm.cpu()
+            optimize_memory_usage()
         x = self.dropout(x)
         return x.to(torch.bfloat16)
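The hunk above implements block-wise CPU offload for the T5 encoder: when `cpu_offload` is set, each transformer block is moved to CUDA only for its own forward pass and sent back to the CPU immediately afterwards, with `optimize_memory_usage()` releasing the freed cache. A minimal, self-contained sketch of that pattern (illustrative only, not LightX2V code; names are hypothetical):

```python
import torch
import torch.nn as nn


def forward_with_block_offload(blocks: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    """Run a stack of blocks while keeping at most one of them on the GPU."""
    for block in blocks:
        block.cuda()                # stage only the current block on the GPU
        x = block(x)                # activations stay on the GPU the whole time
        block.cpu()                 # evict the block before loading the next one
        torch.cuda.empty_cache()    # hand the freed memory back to the allocator
    return x


if torch.cuda.is_available():
    blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)]).cpu()
    out = forward_with_block_offload(blocks, torch.randn(2, 64, device="cuda"))
    print(out.shape)  # torch.Size([2, 64])
```

The trade-off is one host-to-device round trip per block in exchange for a peak VRAM footprint of roughly a single block plus activations.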
@@ -529,6 +574,10 @@ class T5EncoderModel:
     def to_cuda(self):
         self.model = self.model.to("cuda")

+    def optimize_memory(self):
+        """Optimize memory usage."""
+        optimize_memory_usage()
+
     def infer(self, texts):
         if self.cpu_offload and self.offload_granularity == "model":
             self.to_cuda()
@@ -537,10 +586,17 @@
             ids = ids.cuda()
             mask = mask.cuda()
         seq_lens = mask.gt(0).sum(dim=1).long()
-        context = self.model(ids, mask)
+        with torch.no_grad():
+            context = self.model(ids, mask)
         if self.cpu_offload and self.offload_granularity == "model":
             self.to_cpu()
+            optimize_memory_usage()
+        del ids, mask
+        if self.cpu_offload:
+            optimize_memory_usage()
         return [u[:v] for u, v in zip(context, seq_lens)]
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -24,6 +24,11 @@ class WanTransformerWeights(WeightModule):
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
         self.add_module("blocks", self.blocks)

+    def clear(self):
+        for block in self.blocks:
+            for phase in block.compute_phases:
+                phase.clear()
+

 class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
lightx2v/models/runners/default_runner.py
@@ -49,7 +49,7 @@ class DefaultRunner:
             else:
                 self.run_input_encoder = self.run_input_encoder_server_t2v
         else:
-            if not self.config.get("lazy_load", False):
+            if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
                 self.load_model()
             self.run_dit = self.run_dit_local
             self.run_vae_decoder = self.run_vae_decoder_local
@@ -136,8 +136,13 @@ class DefaultRunner:
     def end_run(self):
         self.model.scheduler.clear()
         del self.inputs, self.model.scheduler
-        if self.config.get("lazy_load", False):
-            self.model.transformer_infer.weights_stream_mgr.clear()
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
+                self.model.transformer_infer.weights_stream_mgr.clear()
+            if hasattr(self.model.transformer_weights, "clear"):
+                self.model.transformer_weights.clear()
+            self.model.pre_weight.clear()
+            self.model.post_weight.clear()
         del self.model
         torch.cuda.empty_cache()
         gc.collect()
@@ -163,7 +168,7 @@ class DefaultRunner:
     @ProfilingContext("Run DiT")
     async def run_dit_local(self, kwargs):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
         self.init_scheduler()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
@@ -173,10 +178,10 @@ class DefaultRunner:
     @ProfilingContext("Run VAE Decoder")
     async def run_vae_decoder_local(self, latents, generator):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
         images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
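Across `default_runner.py` the single-flag `lazy_load` checks become `lazy_load or unload_modules`, so submodules are also rebuilt per call and torn down afterwards when only `unload_modules` is enabled. A sketch of the resulting load, use, unload pattern as it appears in `run_vae_decoder_local` (paraphrased from the diff; the `runner` parameter is hypothetical):

```python
import gc
import torch


def decode_with_optional_unload(runner, latents, generator):
    """Rebuild the VAE decoder per call when lazy_load or unload_modules is set."""
    reload_each_call = runner.config.get("lazy_load", False) or runner.config.get("unload_modules", False)
    if reload_each_call:
        runner.vae_decoder = runner.load_vae_decoder()  # build the decoder on demand
    images = runner.vae_decoder.decode(latents, generator=generator, config=runner.config)
    if reload_each_call:
        del runner.vae_decoder                          # release the decoder's VRAM
        torch.cuda.empty_cache()
        gc.collect()
    return images
```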
lightx2v/models/runners/wan/wan_runner.py
@@ -61,14 +61,19 @@ class WanRunner(DefaultRunner):
         return image_encoder

     def load_text_encoder(self):
+        t5_offload = self.config.get("t5_cpu_offload", False)
+        if t5_offload:
+            t5_device = torch.device("cpu")
+        else:
+            t5_device = torch.device("cuda")
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
-            device=self.init_device,
+            device=t5_device,
             checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
             tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
             shard_fn=None,
-            cpu_offload=self.config.cpu_offload,
+            cpu_offload=t5_offload,
             offload_granularity=self.config.get("t5_offload_granularity", "model"),
             t5_quantized=self.config.get("t5_quantized", False),
             t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
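With this hunk the T5 text encoder gets its own placement switch: `t5_cpu_offload` decides whether the encoder is instantiated on CPU or CUDA and is forwarded as the `cpu_offload` flag of `T5EncoderModel`, independently of the runner's global `cpu_offload` setting. A hypothetical config fragment covering the keys this hunk reads (key names come from the diff; the values are examples only):

```python
# Assumed example values; only the key names are taken from the diff above.
t5_settings = {
    "t5_cpu_offload": True,             # keep the T5 encoder on the CPU between uses
    "t5_offload_granularity": "model",  # the granularity that T5EncoderModel.infer() checks
    "t5_quantized": False,
    "t5_quantized_ckpt": None,
}
```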
@@ -129,13 +134,13 @@ class WanRunner(DefaultRunner):
         self.model.set_scheduler(scheduler)

     def run_text_encoder(self, text, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
         text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
         context = self.text_encoders[0].infer([text])
         context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.text_encoders[0]
             torch.cuda.empty_cache()
             gc.collect()
@@ -144,11 +149,11 @@ class WanRunner(DefaultRunner):
         return text_encoder_output

     def run_image_encoder(self, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
         clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -179,7 +184,7 @@ class WanRunner(DefaultRunner):
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_encoder = self.load_vae_encoder()
         vae_encode_out = self.vae_encoder.encode(
             [
@@ -193,7 +198,7 @@ class WanRunner(DefaultRunner):
             ],
             self.config,
         )[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()