Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
486e6279
Commit
486e6279
authored
May 22, 2025
by
root
Browse files
Add support for running lightx2v on 8 GB GPUs
parent
e74270f5
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1034 additions
and
154 deletions
+1034
-154
configs/offload/wan_i2v_block.json
configs/offload/wan_i2v_block.json
+17
-0
configs/offload/wan_i2v_phase.json
configs/offload/wan_i2v_phase.json
+21
-0
configs/offload/wan_t2v_block.json
configs/offload/wan_t2v_block.json
+18
-0
configs/offload/wan_t2v_phase.json
configs/offload/wan_t2v_phase.json
+21
-0
lightx2v/common/modules/weight_module.py
lightx2v/common/modules/weight_module.py
+36
-12
lightx2v/common/offload/manager.py
lightx2v/common/offload/manager.py
+18
-3
lightx2v/common/ops/conv/conv3d.py
lightx2v/common/ops/conv/conv3d.py
+6
-6
lightx2v/models/input_encoders/hf/t5/model.py
lightx2v/models/input_encoders/hf/t5/model.py
+41
-13
lightx2v/models/networks/hunyuan/infer/transformer_infer.py
lightx2v/models/networks/hunyuan/infer/transformer_infer.py
+3
-3
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py
...2v/models/networks/wan/infer/causvid/transformer_infer.py
+0
-1
lightx2v/models/networks/wan/infer/pre_infer.py
lightx2v/models/networks/wan/infer/pre_infer.py
+0
-1
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+119
-50
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+2
-7
lightx2v/models/networks/wan/weights/pre_weights.py
lightx2v/models/networks/wan/weights/pre_weights.py
+0
-1
lightx2v/models/networks/wan/weights/transformer_weights.py
lightx2v/models/networks/wan/weights/transformer_weights.py
+190
-45
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+34
-11
lightx2v/models/video_encoders/hf/tae.py
lightx2v/models/video_encoders/hf/tae.py
+332
-0
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+144
-1
lightx2v/models/video_encoders/hf/wan/vae_tiny.py
lightx2v/models/video_encoders/hf/wan/vae_tiny.py
+27
-0
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+5
-0
No files found.
configs/offload/wan_i2v_block.json
0 → 100755
View file @
486e6279
{
"infer_steps"
:
40
,
"target_video_length"
:
81
,
"target_height"
:
480
,
"target_width"
:
832
,
"attention_type"
:
"flash_attn3"
,
"seed"
:
42
,
"sample_guide_scale"
:
5
,
"sample_shift"
:
5
,
"enable_cfg"
:
true
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"block"
,
"mm_config"
:
{
"mm_type"
:
"W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
,
"weight_auto_quant"
:
true
}
}
configs/offload/wan_i2v_phase.json
0 → 100755
View file @
486e6279
{
"infer_steps"
:
40
,
"target_video_length"
:
81
,
"target_height"
:
480
,
"target_width"
:
832
,
"attention_type"
:
"sage_attn2"
,
"seed"
:
42
,
"sample_guide_scale"
:
5
,
"sample_shift"
:
5
,
"enable_cfg"
:
true
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"phase"
,
"mm_config"
:
{
"mm_type"
:
"W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
,
"weight_auto_quant"
:
true
},
"use_tiling_vae"
:
true
,
"tiny_vae"
:
true
,
"tiny_vae_path"
:
"/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth"
,
"text_encoder_offload_granularity"
:
"block"
}
configs/offload/wan_t2v_block.json
0 → 100755
View file @
486e6279
{
"infer_steps"
:
50
,
"target_video_length"
:
81
,
"text_len"
:
512
,
"target_height"
:
480
,
"target_width"
:
832
,
"attention_type"
:
"sage_attn2"
,
"seed"
:
42
,
"sample_guide_scale"
:
6
,
"sample_shift"
:
8
,
"enable_cfg"
:
true
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"block"
,
"mm_config"
:
{
"mm_type"
:
"W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
,
"weight_auto_quant"
:
true
}
}
configs/offload/wan_t2v_phase.json
0 → 100755
View file @
486e6279
{
"infer_steps"
:
50
,
"target_video_length"
:
81
,
"text_len"
:
512
,
"target_height"
:
480
,
"target_width"
:
832
,
"attention_type"
:
"sage_attn2"
,
"seed"
:
42
,
"sample_guide_scale"
:
6
,
"sample_shift"
:
8
,
"enable_cfg"
:
true
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"phase"
,
"mm_config"
:
{
"mm_type"
:
"W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
,
"weight_auto_quant"
:
true
},
"tiny_vae"
:
true
,
"tiny_vae_path"
:
"/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth"
,
"text_encoder_offload_granularity"
:
"block"
}
lightx2v/common/modules/weight_module.py
View file @
486e6279
...
@@ -53,8 +53,14 @@ class WeightModule:
...
@@ -53,8 +53,14 @@ class WeightModule:
self
.
_parameters
[
name
].
to_cpu
()
self
.
_parameters
[
name
].
to_cpu
()
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
for
module
in
self
.
_modules
.
values
():
for
module
in
self
.
_modules
.
values
():
if
module
is
not
None
and
hasattr
(
module
,
"to_cpu"
):
if
isinstance
(
module
,
WeightModuleList
):
module
.
to_cpu
()
for
i
in
range
(
len
(
module
)):
for
m
in
module
[
i
].
_modules
.
values
():
if
m
is
not
None
and
hasattr
(
m
,
"to_cpu"
):
m
.
to_cpu
()
else
:
if
module
is
not
None
and
hasattr
(
module
,
"to_cpu"
):
module
.
to_cpu
()
def
to_cuda
(
self
):
def
to_cuda
(
self
):
for
name
,
param
in
self
.
_parameters
.
items
():
for
name
,
param
in
self
.
_parameters
.
items
():
...
@@ -65,10 +71,16 @@ class WeightModule:
...
@@ -65,10 +71,16 @@ class WeightModule:
self
.
_parameters
[
name
].
to_cuda
()
self
.
_parameters
[
name
].
to_cuda
()
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
for
module
in
self
.
_modules
.
values
():
for
module
in
self
.
_modules
.
values
():
if
module
is
not
None
and
hasattr
(
module
,
"to_cuda"
):
if
isinstance
(
module
,
WeightModuleList
):
module
.
to_cuda
()
for
i
in
range
(
len
(
module
)):
for
m
in
module
[
i
].
_modules
.
values
():
def
to_cpu_sync
(
self
):
if
m
is
not
None
and
hasattr
(
m
,
"to_cuda"
):
m
.
to_cuda
()
else
:
if
module
is
not
None
and
hasattr
(
module
,
"to_cuda"
):
module
.
to_cuda
()
def
to_cpu_async
(
self
):
for
name
,
param
in
self
.
_parameters
.
items
():
for
name
,
param
in
self
.
_parameters
.
items
():
if
param
is
not
None
:
if
param
is
not
None
:
if
hasattr
(
param
,
"cpu"
):
if
hasattr
(
param
,
"cpu"
):
...
@@ -78,10 +90,16 @@ class WeightModule:
...
@@ -78,10 +90,16 @@ class WeightModule:
self
.
_parameters
[
name
].
to_cpu
(
non_blocking
=
True
)
self
.
_parameters
[
name
].
to_cpu
(
non_blocking
=
True
)
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
for
module
in
self
.
_modules
.
values
():
for
module
in
self
.
_modules
.
values
():
if
module
is
not
None
and
hasattr
(
module
,
"to_cpu"
):
if
isinstance
(
module
,
WeightModuleList
):
module
.
to_cpu
(
non_blocking
=
True
)
for
i
in
range
(
len
(
module
)):
for
m
in
module
[
i
].
_modules
.
values
():
def
to_cuda_sync
(
self
):
if
m
is
not
None
and
hasattr
(
m
,
"to_cpu"
):
m
.
to_cpu
(
non_blocking
=
True
)
else
:
if
module
is
not
None
and
hasattr
(
module
,
"to_cpu"
):
module
.
to_cpu
(
non_blocking
=
True
)
def
to_cuda_async
(
self
):
for
name
,
param
in
self
.
_parameters
.
items
():
for
name
,
param
in
self
.
_parameters
.
items
():
if
param
is
not
None
:
if
param
is
not
None
:
if
hasattr
(
param
,
"cuda"
):
if
hasattr
(
param
,
"cuda"
):
...
@@ -90,8 +108,14 @@ class WeightModule:
...
@@ -90,8 +108,14 @@ class WeightModule:
self
.
_parameters
[
name
].
to_cuda
(
non_blocking
=
True
)
self
.
_parameters
[
name
].
to_cuda
(
non_blocking
=
True
)
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
setattr
(
self
,
name
,
self
.
_parameters
[
name
])
for
module
in
self
.
_modules
.
values
():
for
module
in
self
.
_modules
.
values
():
if
module
is
not
None
and
hasattr
(
module
,
"to_cuda"
):
if
isinstance
(
module
,
WeightModuleList
):
module
.
to_cuda
(
non_blocking
=
True
)
for
i
in
range
(
len
(
module
)):
for
m
in
module
[
i
].
_modules
.
values
():
if
m
is
not
None
and
hasattr
(
m
,
"to_cuda"
):
m
.
to_cuda
(
non_blocking
=
True
)
else
:
if
module
is
not
None
and
hasattr
(
module
,
"to_cuda"
):
module
.
to_cuda
(
non_blocking
=
True
)
class
WeightModuleList
(
WeightModule
):
class
WeightModuleList
(
WeightModule
):
...
...
lightx2v/common/offload/manager.py
View file @
486e6279
import
torch
import
torch
class
WeightStreamManager
(
object
):
class
Weight
Async
StreamManager
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
active_weights
=
[
None
for
_
in
range
(
2
)]
self
.
active_weights
=
[
None
for
_
in
range
(
2
)]
self
.
active_weights
=
[
None
for
_
in
range
(
2
)]
self
.
compute_stream
=
torch
.
cuda
.
Stream
(
priority
=-
1
)
self
.
compute_stream
=
torch
.
cuda
.
Stream
(
priority
=-
1
)
self
.
load_stream
=
torch
.
cuda
.
Stream
(
priority
=
0
)
self
.
load_stream
=
torch
.
cuda
.
Stream
(
priority
=
0
)
...
@@ -10,9 +11,9 @@ class WeightStreamManager(object):
...
@@ -10,9 +11,9 @@ class WeightStreamManager(object):
def
prefetch_weights
(
self
,
block_idx
,
blocks_weights
):
def
prefetch_weights
(
self
,
block_idx
,
blocks_weights
):
with
torch
.
cuda
.
stream
(
self
.
load_stream
):
with
torch
.
cuda
.
stream
(
self
.
load_stream
):
if
self
.
active_weights
[
1
]
is
not
None
:
if
self
.
active_weights
[
1
]
is
not
None
:
self
.
active_weights
[
1
].
to_cpu_sync
()
self
.
active_weights
[
1
].
to_cpu_
a
sync
()
new_weights
=
blocks_weights
[
block_idx
]
new_weights
=
blocks_weights
[
block_idx
]
new_weights
.
to_cuda_sync
()
new_weights
.
to_cuda_
a
sync
()
self
.
active_weights
[
1
]
=
new_weights
self
.
active_weights
[
1
]
=
new_weights
def
swap_weights
(
self
):
def
swap_weights
(
self
):
...
@@ -23,3 +24,17 @@ class WeightStreamManager(object):
...
@@ -23,3 +24,17 @@ class WeightStreamManager(object):
self
.
active_weights
[
1
],
self
.
active_weights
[
1
],
self
.
active_weights
[
0
],
self
.
active_weights
[
0
],
)
)
def
prefetch_phase
(
self
,
block_idx
,
phase_idx
,
blocks
):
with
torch
.
cuda
.
stream
(
self
.
load_stream
):
if
self
.
active_weights
[
1
]
is
not
None
:
_
,
old_phase
=
self
.
active_weights
[
1
]
old_phase
.
to_cpu_async
()
new_phase
=
blocks
[
block_idx
].
compute_phases
[
phase_idx
]
new_phase
.
to_cuda_async
()
self
.
active_weights
[
1
]
=
(
phase_idx
,
new_phase
)
def
swap_phases
(
self
):
self
.
compute_stream
.
synchronize
()
self
.
load_stream
.
synchronize
()
self
.
active_weights
[
0
],
self
.
active_weights
[
1
]
=
self
.
active_weights
[
1
],
self
.
active_weights
[
0
]
lightx2v/common/ops/conv/conv3d.py
View file @
486e6279
...
@@ -39,15 +39,15 @@ class Conv3dWeight(Conv3dWeightTemplate):
...
@@ -39,15 +39,15 @@ class Conv3dWeight(Conv3dWeightTemplate):
input_tensor
=
torch
.
nn
.
functional
.
conv3d
(
input_tensor
,
weight
=
self
.
weight
,
bias
=
self
.
bias
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
,
groups
=
self
.
groups
)
input_tensor
=
torch
.
nn
.
functional
.
conv3d
(
input_tensor
,
weight
=
self
.
weight
,
bias
=
self
.
bias
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
,
groups
=
self
.
groups
)
return
input_tensor
return
input_tensor
def
to_cpu
(
self
):
def
to_cpu
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
weight
.
cpu
(
)
self
.
weight
=
self
.
weight
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
cpu
(
)
self
.
bias
=
self
.
bias
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
def
to_cuda
(
self
):
def
to_cuda
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
weight
.
cuda
()
self
.
weight
=
self
.
weight
.
cuda
(
non_blocking
=
non_blocking
)
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
cuda
()
self
.
bias
=
self
.
bias
.
cuda
(
non_blocking
=
non_blocking
)
def
state_dict
(
self
,
destination
=
None
):
def
state_dict
(
self
,
destination
=
None
):
if
destination
is
None
:
if
destination
is
None
:
...
...
lightx2v/models/input_encoders/hf/t5/model.py
View file @
486e6279
...
@@ -256,8 +256,10 @@ class T5Encoder(nn.Module):
...
@@ -256,8 +256,10 @@ class T5Encoder(nn.Module):
num_buckets
,
num_buckets
,
shared_pos
=
True
,
shared_pos
=
True
,
dropout
=
0.1
,
dropout
=
0.1
,
cpu_offload
=
False
,
):
):
super
(
T5Encoder
,
self
).
__init__
()
super
(
T5Encoder
,
self
).
__init__
()
self
.
cpu_offload
=
cpu_offload
self
.
dim
=
dim
self
.
dim
=
dim
self
.
dim_attn
=
dim_attn
self
.
dim_attn
=
dim_attn
self
.
dim_ffn
=
dim_ffn
self
.
dim_ffn
=
dim_ffn
...
@@ -277,12 +279,28 @@ class T5Encoder(nn.Module):
...
@@ -277,12 +279,28 @@ class T5Encoder(nn.Module):
self
.
apply
(
init_weights
)
self
.
apply
(
init_weights
)
def
forward
(
self
,
ids
,
mask
=
None
):
def
forward
(
self
,
ids
,
mask
=
None
):
if
self
.
cpu_offload
:
self
.
token_embedding
=
self
.
token_embedding
.
cuda
()
x
=
self
.
token_embedding
(
ids
)
x
=
self
.
token_embedding
(
ids
)
if
self
.
cpu_offload
:
self
.
token_embedding
=
self
.
token_embedding
.
cpu
()
x
=
self
.
dropout
(
x
)
x
=
self
.
dropout
(
x
)
if
self
.
cpu_offload
and
self
.
pos_embedding
is
not
None
:
self
.
pos_embedding
=
self
.
pos_embedding
.
cuda
()
e
=
self
.
pos_embedding
(
x
.
size
(
1
),
x
.
size
(
1
))
if
self
.
shared_pos
else
None
e
=
self
.
pos_embedding
(
x
.
size
(
1
),
x
.
size
(
1
))
if
self
.
shared_pos
else
None
if
self
.
cpu_offload
and
self
.
pos_embedding
is
not
None
:
self
.
pos_embedding
=
self
.
pos_embedding
.
cpu
()
for
block
in
self
.
blocks
:
for
block
in
self
.
blocks
:
if
self
.
cpu_offload
:
block
=
block
.
cuda
()
x
=
block
(
x
,
mask
,
pos_bias
=
e
)
x
=
block
(
x
,
mask
,
pos_bias
=
e
)
if
self
.
cpu_offload
:
block
=
block
.
cpu
()
if
self
.
cpu_offload
:
self
.
norm
=
self
.
norm
.
cuda
()
x
=
self
.
norm
(
x
)
x
=
self
.
norm
(
x
)
if
self
.
cpu_offload
:
self
.
norm
=
self
.
norm
.
cpu
()
x
=
self
.
dropout
(
x
)
x
=
self
.
dropout
(
x
)
return
x
return
x
...
@@ -432,15 +450,7 @@ def _t5(
...
@@ -432,15 +450,7 @@ def _t5(
# set device
# set device
model
=
model
.
to
(
dtype
=
dtype
,
device
=
device
)
model
=
model
.
to
(
dtype
=
dtype
,
device
=
device
)
return
model
# init tokenizer
if
return_tokenizer
:
from
.tokenizers
import
HuggingfaceTokenizer
tokenizer
=
HuggingfaceTokenizer
(
f
"google/
{
name
}
"
,
**
tokenizer_kwargs
)
return
model
,
tokenizer
else
:
return
model
def
umt5_xxl
(
**
kwargs
):
def
umt5_xxl
(
**
kwargs
):
...
@@ -470,15 +480,33 @@ class T5EncoderModel:
...
@@ -470,15 +480,33 @@ class T5EncoderModel:
checkpoint_path
=
None
,
checkpoint_path
=
None
,
tokenizer_path
=
None
,
tokenizer_path
=
None
,
shard_fn
=
None
,
shard_fn
=
None
,
cpu_offload
=
False
,
offload_granularity
=
"model"
,
):
):
self
.
text_len
=
text_len
self
.
text_len
=
text_len
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
device
=
device
self
.
checkpoint_path
=
checkpoint_path
self
.
checkpoint_path
=
checkpoint_path
self
.
tokenizer_path
=
tokenizer_path
self
.
tokenizer_path
=
tokenizer_path
self
.
offload_granularity
=
offload_granularity
# sync cpu offload
self
.
cpu_offload
=
cpu_offload
if
self
.
cpu_offload
:
assert
self
.
offload_granularity
in
[
"block"
,
"model"
]
# init model
# init model
model
=
umt5_xxl
(
encoder_only
=
True
,
return_tokenizer
=
False
,
dtype
=
dtype
,
device
=
device
).
eval
().
requires_grad_
(
False
)
model
=
(
umt5_xxl
(
encoder_only
=
True
,
return_tokenizer
=
False
,
dtype
=
dtype
,
device
=
device
,
cpu_offload
=
cpu_offload
if
self
.
offload_granularity
==
"block"
else
False
,
)
.
eval
()
.
requires_grad_
(
False
)
)
logging
.
info
(
f
"loading
{
checkpoint_path
}
"
)
logging
.
info
(
f
"loading
{
checkpoint_path
}
"
)
model
.
load_state_dict
(
torch
.
load
(
checkpoint_path
,
map_location
=
"cpu"
,
weights_only
=
True
))
model
.
load_state_dict
(
torch
.
load
(
checkpoint_path
,
map_location
=
"cpu"
,
weights_only
=
True
))
self
.
model
=
model
self
.
model
=
model
...
@@ -495,8 +523,8 @@ class T5EncoderModel:
...
@@ -495,8 +523,8 @@ class T5EncoderModel:
def
to_cuda
(
self
):
def
to_cuda
(
self
):
self
.
model
=
self
.
model
.
to
(
"cuda"
)
self
.
model
=
self
.
model
.
to
(
"cuda"
)
def
infer
(
self
,
texts
,
config
):
def
infer
(
self
,
texts
):
if
config
.
cpu_offload
:
if
self
.
cpu_offload
and
self
.
offload_granularity
==
"model"
:
self
.
to_cuda
()
self
.
to_cuda
()
ids
,
mask
=
self
.
tokenizer
(
texts
,
return_mask
=
True
,
add_special_tokens
=
True
)
ids
,
mask
=
self
.
tokenizer
(
texts
,
return_mask
=
True
,
add_special_tokens
=
True
)
...
@@ -505,7 +533,7 @@ class T5EncoderModel:
...
@@ -505,7 +533,7 @@ class T5EncoderModel:
seq_lens
=
mask
.
gt
(
0
).
sum
(
dim
=
1
).
long
()
seq_lens
=
mask
.
gt
(
0
).
sum
(
dim
=
1
).
long
()
context
=
self
.
model
(
ids
,
mask
)
context
=
self
.
model
(
ids
,
mask
)
if
config
.
cpu_offload
:
if
self
.
cpu_offload
and
self
.
offload_granularity
==
"model"
:
self
.
to_cpu
()
self
.
to_cpu
()
return
[
u
[:
v
]
for
u
,
v
in
zip
(
context
,
seq_lens
)]
return
[
u
[:
v
]
for
u
,
v
in
zip
(
context
,
seq_lens
)]
...
...
lightx2v/models/networks/hunyuan/infer/transformer_infer.py
View file @
486e6279
import
torch
import
torch
from
einops
import
rearrange
from
einops
import
rearrange
from
.utils_bf16
import
apply_rotary_emb
from
.utils_bf16
import
apply_rotary_emb
from
lightx2v.common.offload.manager
import
WeightStreamManager
from
lightx2v.common.offload.manager
import
Weight
Async
StreamManager
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
...
@@ -16,8 +16,8 @@ class HunyuanTransformerInfer:
...
@@ -16,8 +16,8 @@ class HunyuanTransformerInfer:
self
.
mlp_hidden_dim
=
12288
self
.
mlp_hidden_dim
=
12288
self
.
parallel_attention
=
None
self
.
parallel_attention
=
None
if
self
.
config
[
"cpu_offload"
]:
if
self
.
config
[
"cpu_offload"
]:
self
.
double_weights_stream_mgr
=
WeightStreamManager
()
self
.
double_weights_stream_mgr
=
Weight
Async
StreamManager
()
self
.
single_weights_stream_mgr
=
WeightStreamManager
()
self
.
single_weights_stream_mgr
=
Weight
Async
StreamManager
()
self
.
infer_func
=
self
.
_infer_with_offload
self
.
infer_func
=
self
.
_infer_with_offload
else
:
else
:
self
.
infer_func
=
self
.
_infer_without_offload
self
.
infer_func
=
self
.
_infer_without_offload
...
...
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py
View file @
486e6279
import
torch
import
torch
import
math
import
math
from
..utils
import
compute_freqs
,
compute_freqs_causvid
,
compute_freqs_dist
,
apply_rotary_emb
from
..utils
import
compute_freqs
,
compute_freqs_causvid
,
compute_freqs_dist
,
apply_rotary_emb
from
lightx2v.common.offload.manager
import
WeightStreamManager
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
..transformer_infer
import
WanTransformerInfer
from
..transformer_infer
import
WanTransformerInfer
...
...
lightx2v/models/networks/wan/infer/pre_infer.py
View file @
486e6279
import
torch
import
torch
import
math
import
math
from
.utils
import
rope_params
,
sinusoidal_embedding_1d
from
.utils
import
rope_params
,
sinusoidal_embedding_1d
import
torch.cuda.amp
as
amp
class
WanPreInfer
:
class
WanPreInfer
:
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
486e6279
import
torch
import
torch
from
.utils
import
compute_freqs
,
compute_freqs_dist
,
apply_rotary_emb
from
.utils
import
compute_freqs
,
compute_freqs_dist
,
apply_rotary_emb
from
lightx2v.common.offload.manager
import
WeightStreamManager
from
lightx2v.common.offload.manager
import
Weight
Async
StreamManager
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
...
@@ -15,30 +15,32 @@ class WanTransformerInfer:
...
@@ -15,30 +15,32 @@ class WanTransformerInfer:
self
.
window_size
=
config
.
get
(
"window_size"
,
(
-
1
,
-
1
))
self
.
window_size
=
config
.
get
(
"window_size"
,
(
-
1
,
-
1
))
self
.
parallel_attention
=
None
self
.
parallel_attention
=
None
if
self
.
config
[
"cpu_offload"
]:
if
self
.
config
[
"cpu_offload"
]:
self
.
weights_stream_mgr
=
WeightStreamManager
()
offload_granularity
=
self
.
config
.
get
(
"offload_granularity"
,
"block"
)
self
.
infer_func
=
self
.
_infer_with_offload
self
.
weights_stream_mgr
=
WeightAsyncStreamManager
()
if
offload_granularity
==
"block"
:
self
.
infer_func
=
self
.
_infer_with_offload
elif
offload_granularity
==
"phase"
:
self
.
infer_func
=
self
.
_infer_with_phases_offload
else
:
else
:
self
.
infer_func
=
self
.
_infer_without_offload
self
.
infer_func
=
self
.
_infer_without_offload
def
set_scheduler
(
self
,
scheduler
):
def
set_scheduler
(
self
,
scheduler
):
self
.
scheduler
=
scheduler
self
.
scheduler
=
scheduler
def
_calculate_q_k_len
(
self
,
q
,
k
,
k_lens
):
def
_calculate_q_k_len
(
self
,
q
,
k_lens
):
lq
,
nq
,
c1
=
q
.
size
()
lk
,
nk
,
c1_k
=
k
.
size
()
# Handle query and key lengths (use `q_lens` and `k_lens` or set them to Lq and Lk if None)
# Handle query and key lengths (use `q_lens` and `k_lens` or set them to Lq and Lk if None)
q_lens
=
torch
.
tensor
([
l
q
],
dtype
=
torch
.
int32
,
device
=
q
.
device
)
q_lens
=
torch
.
tensor
([
q
.
size
(
0
)
],
dtype
=
torch
.
int32
,
device
=
q
.
device
)
# We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
# We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
cu_seqlens_q
=
torch
.
cat
([
q_lens
.
new_zeros
([
1
]),
q_lens
]).
cumsum
(
0
,
dtype
=
torch
.
int32
)
cu_seqlens_q
=
torch
.
cat
([
q_lens
.
new_zeros
([
1
]),
q_lens
]).
cumsum
(
0
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
cat
([
k_lens
.
new_zeros
([
1
]),
k_lens
]).
cumsum
(
0
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
cat
([
k_lens
.
new_zeros
([
1
]),
k_lens
]).
cumsum
(
0
,
dtype
=
torch
.
int32
)
return
cu_seqlens_q
,
cu_seqlens_k
,
lq
,
lk
return
cu_seqlens_q
,
cu_seqlens_k
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
infer
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
return
self
.
infer_func
(
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
)
return
self
.
infer_func
(
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
)
def
_infer_with_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
_infer_with_offload
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
for
block_idx
in
range
(
self
.
blocks_num
):
for
block_idx
in
range
(
self
.
blocks_num
):
if
block_idx
==
0
:
if
block_idx
==
0
:
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
weights
.
blocks
[
0
]
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
weights
.
blocks
[
0
]
...
@@ -48,7 +50,6 @@ class WanTransformerInfer:
...
@@ -48,7 +50,6 @@ class WanTransformerInfer:
x
=
self
.
infer_block
(
x
=
self
.
infer_block
(
self
.
weights_stream_mgr
.
active_weights
[
0
],
self
.
weights_stream_mgr
.
active_weights
[
0
],
grid_sizes
,
grid_sizes
,
embed
,
x
,
x
,
embed0
,
embed0
,
seq_lens
,
seq_lens
,
...
@@ -62,12 +63,62 @@ class WanTransformerInfer:
...
@@ -62,12 +63,62 @@ class WanTransformerInfer:
return
x
return
x
def
_infer_without_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
_infer_with_phases_offload
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
for
block_idx
in
range
(
weights
.
blocks_num
):
weights
.
blocks
[
block_idx
].
modulation
.
to_cuda
()
if
embed0
.
dim
()
==
3
:
modulation
=
weights
.
blocks
[
block_idx
].
modulation
.
tensor
.
unsqueeze
(
2
)
current_embed0
=
(
modulation
+
embed0
).
chunk
(
6
,
dim
=
1
)
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
[
ei
.
squeeze
(
1
)
for
ei
in
current_embed0
]
elif
embed0
.
dim
()
==
2
:
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
(
weights
.
blocks
[
block_idx
].
modulation
.
tensor
+
embed0
).
chunk
(
6
,
dim
=
1
)
for
phase_idx
in
range
(
3
):
if
block_idx
==
0
and
phase_idx
==
0
:
phase
=
weights
.
blocks
[
block_idx
].
compute_phases
[
phase_idx
]
phase
.
to_cuda
()
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
(
phase_idx
,
phase
)
with
torch
.
cuda
.
stream
(
self
.
weights_stream_mgr
.
compute_stream
):
cur_phase_idx
,
cur_phase
=
self
.
weights_stream_mgr
.
active_weights
[
0
]
if
cur_phase_idx
==
0
:
x
=
self
.
_infer_self_attn
(
cur_phase
,
x
,
shift_msa
,
scale_msa
,
gate_msa
,
grid_sizes
,
freqs
,
seq_lens
,
)
elif
cur_phase_idx
==
1
:
x
=
self
.
_infer_cross_attn
(
cur_phase
,
x
,
context
)
elif
cur_phase_idx
==
2
:
x
=
self
.
_infer_ffn
(
cur_phase
,
x
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
)
is_last_phase
=
block_idx
==
weights
.
blocks_num
-
1
and
phase_idx
==
2
if
not
is_last_phase
:
next_block_idx
=
block_idx
+
1
if
cur_phase_idx
==
2
else
block_idx
next_phase_idx
=
(
cur_phase_idx
+
1
)
%
3
self
.
weights_stream_mgr
.
prefetch_phase
(
next_block_idx
,
next_phase_idx
,
weights
.
blocks
)
self
.
weights_stream_mgr
.
swap_phases
()
weights
.
blocks
[
block_idx
].
modulation
.
to_cpu
()
torch
.
cuda
.
empty_cache
()
return
x
def
_infer_without_offload
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
for
block_idx
in
range
(
self
.
blocks_num
):
for
block_idx
in
range
(
self
.
blocks_num
):
x
=
self
.
infer_block
(
x
=
self
.
infer_block
(
weights
.
blocks
[
block_idx
],
weights
.
blocks
[
block_idx
],
grid_sizes
,
grid_sizes
,
embed
,
x
,
x
,
embed0
,
embed0
,
seq_lens
,
seq_lens
,
...
@@ -76,21 +127,13 @@ class WanTransformerInfer:
...
@@ -76,21 +127,13 @@ class WanTransformerInfer:
)
)
return
x
return
x
def
infer_block
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
_infer_self_attn
(
self
,
weights
,
x
,
shift_msa
,
scale_msa
,
gate_msa
,
grid_sizes
,
freqs
,
seq_lens
):
if
embed0
.
dim
()
==
3
:
modulation
=
weights
.
modulation
.
tensor
.
unsqueeze
(
2
)
# 1, 6, 1, dim
embed0
=
embed0
.
unsqueeze
(
0
)
#
embed0
=
(
modulation
+
embed0
).
chunk
(
6
,
dim
=
1
)
embed0
=
[
ei
.
squeeze
(
1
)
for
ei
in
embed0
]
elif
embed0
.
dim
()
==
2
:
embed0
=
(
weights
.
modulation
.
tensor
+
embed0
).
chunk
(
6
,
dim
=
1
)
if
hasattr
(
weights
,
"smooth_norm1_weight"
):
if
hasattr
(
weights
,
"smooth_norm1_weight"
):
norm1_weight
=
(
1
+
embed0
[
1
]
)
*
weights
.
smooth_norm1_weight
.
tensor
norm1_weight
=
(
1
+
scale_msa
)
*
weights
.
smooth_norm1_weight
.
tensor
norm1_bias
=
embed0
[
0
]
*
weights
.
smooth_norm1_bias
.
tensor
norm1_bias
=
shift_msa
*
weights
.
smooth_norm1_bias
.
tensor
else
:
else
:
norm1_weight
=
1
+
embed0
[
1
]
norm1_weight
=
1
+
scale_msa
norm1_bias
=
embed0
[
0
]
norm1_bias
=
shift_msa
norm1_out
=
torch
.
nn
.
functional
.
layer_norm
(
x
,
(
x
.
shape
[
1
],),
None
,
None
,
1e-6
)
norm1_out
=
torch
.
nn
.
functional
.
layer_norm
(
x
,
(
x
.
shape
[
1
],),
None
,
None
,
1e-6
)
norm1_out
=
(
norm1_out
*
norm1_weight
+
norm1_bias
).
squeeze
(
0
)
norm1_out
=
(
norm1_out
*
norm1_weight
+
norm1_bias
).
squeeze
(
0
)
...
@@ -108,7 +151,7 @@ class WanTransformerInfer:
...
@@ -108,7 +151,7 @@ class WanTransformerInfer:
q
=
apply_rotary_emb
(
q
,
freqs_i
)
q
=
apply_rotary_emb
(
q
,
freqs_i
)
k
=
apply_rotary_emb
(
k
,
freqs_i
)
k
=
apply_rotary_emb
(
k
,
freqs_i
)
cu_seqlens_q
,
cu_seqlens_k
,
lq
,
lk
=
self
.
_calculate_q_k_len
(
q
,
k
,
k_lens
=
seq_lens
)
cu_seqlens_q
,
cu_seqlens_k
=
self
.
_calculate_q_k_len
(
q
,
k_lens
=
seq_lens
)
if
not
self
.
parallel_attention
:
if
not
self
.
parallel_attention
:
attn_out
=
weights
.
self_attn_1
.
apply
(
attn_out
=
weights
.
self_attn_1
.
apply
(
...
@@ -117,8 +160,8 @@ class WanTransformerInfer:
...
@@ -117,8 +160,8 @@ class WanTransformerInfer:
v
=
v
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
l
q
,
max_seqlen_q
=
q
.
size
(
0
)
,
max_seqlen_kv
=
l
k
,
max_seqlen_kv
=
k
.
size
(
0
)
,
model_cls
=
self
.
config
[
"model_cls"
],
model_cls
=
self
.
config
[
"model_cls"
],
)
)
else
:
else
:
...
@@ -129,25 +172,30 @@ class WanTransformerInfer:
...
@@ -129,25 +172,30 @@ class WanTransformerInfer:
v
=
v
,
v
=
v
,
img_qkv_len
=
q
.
shape
[
0
],
img_qkv_len
=
q
.
shape
[
0
],
cu_seqlens_qkv
=
cu_seqlens_q
,
cu_seqlens_qkv
=
cu_seqlens_q
,
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
)
y
=
weights
.
self_attn_o
.
apply
(
attn_out
)
x
=
x
+
y
*
embed0
[
2
].
squeeze
(
0
)
y
=
weights
.
self_attn_o
.
apply
(
attn_out
)
x
.
add_
(
y
*
gate_msa
.
squeeze
(
0
))
return
x
def
_infer_cross_attn
(
self
,
weights
,
x
,
context
):
norm3_out
=
weights
.
norm3
.
apply
(
x
)
norm3_out
=
weights
.
norm3
.
apply
(
x
)
if
self
.
task
==
"i2v"
:
if
self
.
task
==
"i2v"
:
context_img
=
context
[:
257
]
context_img
=
context
[:
257
]
context
=
context
[
257
:]
context
=
context
[
257
:]
else
:
context_img
=
None
n
,
d
=
self
.
num_heads
,
self
.
head_dim
n
,
d
=
self
.
num_heads
,
self
.
head_dim
q
=
weights
.
cross_attn_norm_q
.
apply
(
weights
.
cross_attn_q
.
apply
(
norm3_out
)).
view
(
-
1
,
n
,
d
)
q
=
weights
.
cross_attn_norm_q
.
apply
(
weights
.
cross_attn_q
.
apply
(
norm3_out
)).
view
(
-
1
,
n
,
d
)
k
=
weights
.
cross_attn_norm_k
.
apply
(
weights
.
cross_attn_k
.
apply
(
context
)).
view
(
-
1
,
n
,
d
)
k
=
weights
.
cross_attn_norm_k
.
apply
(
weights
.
cross_attn_k
.
apply
(
context
)).
view
(
-
1
,
n
,
d
)
v
=
weights
.
cross_attn_v
.
apply
(
context
).
view
(
-
1
,
n
,
d
)
v
=
weights
.
cross_attn_v
.
apply
(
context
).
view
(
-
1
,
n
,
d
)
cu_seqlens_q
,
cu_seqlens_k
,
lq
,
lk
=
self
.
_calculate_q_k_len
(
q
,
k
,
k_lens
=
torch
.
tensor
([
k
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
k
.
device
))
cu_seqlens_q
,
cu_seqlens_k
=
self
.
_calculate_q_k_len
(
q
,
k_lens
=
torch
.
tensor
([
k
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
k
.
device
),
)
attn_out
=
weights
.
cross_attn_1
.
apply
(
attn_out
=
weights
.
cross_attn_1
.
apply
(
q
=
q
,
q
=
q
,
...
@@ -155,18 +203,17 @@ class WanTransformerInfer:
...
@@ -155,18 +203,17 @@ class WanTransformerInfer:
v
=
v
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
l
q
,
max_seqlen_q
=
q
.
size
(
0
)
,
max_seqlen_kv
=
l
k
,
max_seqlen_kv
=
k
.
size
(
0
)
,
model_cls
=
self
.
config
[
"model_cls"
],
model_cls
=
self
.
config
[
"model_cls"
],
)
)
if
self
.
task
==
"i2v"
:
if
self
.
task
==
"i2v"
and
context_img
is
not
None
:
k_img
=
weights
.
cross_attn_norm_k_img
.
apply
(
weights
.
cross_attn_k_img
.
apply
(
context_img
)).
view
(
-
1
,
n
,
d
)
k_img
=
weights
.
cross_attn_norm_k_img
.
apply
(
weights
.
cross_attn_k_img
.
apply
(
context_img
)).
view
(
-
1
,
n
,
d
)
v_img
=
weights
.
cross_attn_v_img
.
apply
(
context_img
).
view
(
-
1
,
n
,
d
)
v_img
=
weights
.
cross_attn_v_img
.
apply
(
context_img
).
view
(
-
1
,
n
,
d
)
cu_seqlens_q
,
cu_seqlens_k
,
lq
,
lk
=
self
.
_calculate_q_k_len
(
cu_seqlens_q
,
cu_seqlens_k
=
self
.
_calculate_q_k_len
(
q
,
q
,
k_img
,
k_lens
=
torch
.
tensor
([
k_img
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
k
.
device
),
k_lens
=
torch
.
tensor
([
k_img
.
size
(
0
)],
dtype
=
torch
.
int32
,
device
=
k
.
device
),
)
)
...
@@ -176,28 +223,50 @@ class WanTransformerInfer:
...
@@ -176,28 +223,50 @@ class WanTransformerInfer:
v
=
v_img
,
v
=
v_img
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_k
,
cu_seqlens_kv
=
cu_seqlens_k
,
max_seqlen_q
=
l
q
,
max_seqlen_q
=
q
.
size
(
0
)
,
max_seqlen_kv
=
l
k
,
max_seqlen_kv
=
k
_img
.
size
(
0
)
,
model_cls
=
self
.
config
[
"model_cls"
],
model_cls
=
self
.
config
[
"model_cls"
],
)
)
attn_out
=
attn_out
+
img_attn_out
attn_out
=
attn_out
+
img_attn_out
attn_out
=
weights
.
cross_attn_o
.
apply
(
attn_out
)
attn_out
=
weights
.
cross_attn_o
.
apply
(
attn_out
)
x
.
add_
(
attn_out
)
return
x
x
=
x
+
attn_out
def
_infer_ffn
(
self
,
weights
,
x
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
):
if
hasattr
(
weights
,
"smooth_norm2_weight"
):
if
hasattr
(
weights
,
"smooth_norm2_weight"
):
norm2_weight
=
(
1
+
embed0
[
4
]
.
squeeze
(
0
))
*
weights
.
smooth_norm2_weight
.
tensor
norm2_weight
=
(
1
+
c_scale_msa
.
squeeze
(
0
))
*
weights
.
smooth_norm2_weight
.
tensor
norm2_bias
=
embed0
[
3
]
.
squeeze
(
0
)
*
weights
.
smooth_norm2_bias
.
tensor
norm2_bias
=
c_shift_msa
.
squeeze
(
0
)
*
weights
.
smooth_norm2_bias
.
tensor
else
:
else
:
norm2_weight
=
1
+
embed0
[
4
]
.
squeeze
(
0
)
norm2_weight
=
1
+
c_scale_msa
.
squeeze
(
0
)
norm2_bias
=
embed0
[
3
]
.
squeeze
(
0
)
norm2_bias
=
c_shift_msa
.
squeeze
(
0
)
norm2_out
=
torch
.
nn
.
functional
.
layer_norm
(
x
,
(
x
.
shape
[
1
],),
None
,
None
,
1e-6
)
norm2_out
=
torch
.
nn
.
functional
.
layer_norm
(
x
,
(
x
.
shape
[
1
],),
None
,
None
,
1e-6
)
y
=
weights
.
ffn_0
.
apply
(
norm2_out
*
norm2_weight
+
norm2_bias
)
y
=
weights
.
ffn_0
.
apply
(
norm2_out
*
norm2_weight
+
norm2_bias
)
y
=
torch
.
nn
.
functional
.
gelu
(
y
,
approximate
=
"tanh"
)
y
=
torch
.
nn
.
functional
.
gelu
(
y
,
approximate
=
"tanh"
)
y
=
weights
.
ffn_2
.
apply
(
y
)
y
=
weights
.
ffn_2
.
apply
(
y
)
x
=
x
+
y
*
embed0
[
5
].
squeeze
(
0
)
x
.
add_
(
y
*
c_gate_msa
.
squeeze
(
0
))
return
x
def
infer_block
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
if
embed0
.
dim
()
==
3
:
modulation
=
weights
.
modulation
.
tensor
.
unsqueeze
(
2
)
embed0
=
(
modulation
+
embed0
).
chunk
(
6
,
dim
=
1
)
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
[
ei
.
squeeze
(
1
)
for
ei
in
embed0
]
elif
embed0
.
dim
()
==
2
:
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
(
weights
.
modulation
.
tensor
+
embed0
).
chunk
(
6
,
dim
=
1
)
x
=
self
.
_infer_self_attn
(
weights
.
compute_phases
[
1
],
x
,
shift_msa
,
scale_msa
,
gate_msa
,
grid_sizes
,
freqs
,
seq_lens
,
)
x
=
self
.
_infer_cross_attn
(
weights
.
compute_phases
[
2
],
x
,
context
)
x
=
self
.
_infer_ffn
(
weights
.
compute_phases
[
3
],
x
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
)
return
x
return
x
lightx2v/models/networks/wan/model.py
View file @
486e6279
...
@@ -52,11 +52,6 @@ class WanModel:
...
@@ -52,11 +52,6 @@ class WanModel:
else
:
else
:
raise
Exception
(
f
"Unsuppotred parallel_attn_type"
)
raise
Exception
(
f
"Unsuppotred parallel_attn_type"
)
if
self
.
config
[
"cpu_offload"
]:
self
.
to_cpu
()
else
:
self
.
to_cuda
()
def
_init_infer_class
(
self
):
def
_init_infer_class
(
self
):
self
.
pre_infer_class
=
WanPreInfer
self
.
pre_infer_class
=
WanPreInfer
self
.
post_infer_class
=
WanPostInfer
self
.
post_infer_class
=
WanPostInfer
...
@@ -188,7 +183,7 @@ class WanModel:
...
@@ -188,7 +183,7 @@ class WanModel:
self
.
post_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
*
pre_infer_out
)
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
if
self
.
config
[
"feature_caching"
]
==
"Tea"
:
if
self
.
config
[
"feature_caching"
]
==
"Tea"
:
...
@@ -199,7 +194,7 @@ class WanModel:
...
@@ -199,7 +194,7 @@ class WanModel:
if
self
.
config
[
"enable_cfg"
]:
if
self
.
config
[
"enable_cfg"
]:
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
False
)
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
False
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
*
pre_infer_out
)
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
if
self
.
config
[
"feature_caching"
]
==
"Tea"
:
if
self
.
config
[
"feature_caching"
]
==
"Tea"
:
...
...
lightx2v/models/networks/wan/weights/pre_weights.py
View file @
486e6279
import
torch
from
lightx2v.utils.registry_factory
import
MM_WEIGHT_REGISTER
,
LN_WEIGHT_REGISTER
,
CONV3D_WEIGHT_REGISTER
from
lightx2v.utils.registry_factory
import
MM_WEIGHT_REGISTER
,
LN_WEIGHT_REGISTER
,
CONV3D_WEIGHT_REGISTER
from
lightx2v.common.modules.weight_module
import
WeightModule
from
lightx2v.common.modules.weight_module
import
WeightModule
...
...
lightx2v/models/networks/wan/weights/transformer_weights.py
View file @
486e6279
import
torch
import
torch
from
lightx2v.utils.registry_factory
import
MM_WEIGHT_REGISTER
,
LN_WEIGHT_REGISTER
,
RMS_WEIGHT_REGISTER
,
TENSOR_REGISTER
,
ATTN_WEIGHT_REGISTER
from
lightx2v.utils.registry_factory
import
(
MM_WEIGHT_REGISTER
,
LN_WEIGHT_REGISTER
,
RMS_WEIGHT_REGISTER
,
TENSOR_REGISTER
,
ATTN_WEIGHT_REGISTER
,
)
from
lightx2v.common.modules.weight_module
import
WeightModule
,
WeightModuleList
from
lightx2v.common.modules.weight_module
import
WeightModule
,
WeightModuleList
...
@@ -26,57 +32,196 @@ class WanTransformerAttentionBlock(WeightModule):
...
@@ -26,57 +32,196 @@ class WanTransformerAttentionBlock(WeightModule):
self
.
config
=
config
self
.
config
=
config
self
.
quant_method
=
config
[
"mm_config"
].
get
(
"quant_method"
,
None
)
self
.
quant_method
=
config
[
"mm_config"
].
get
(
"quant_method"
,
None
)
self
.
sparge
=
config
.
get
(
"sparge"
,
False
)
self
.
sparge
=
config
.
get
(
"sparge"
,
False
)
self
.
register_parameter
(
"modulation"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.modulation"
),
)
self
.
compute_phases
=
WeightModuleList
(
[
WanSelfAttention
(
block_index
,
task
,
mm_type
,
config
),
WanCrossAttention
(
block_index
,
task
,
mm_type
,
config
),
WanFFN
(
block_index
,
task
,
mm_type
,
config
),
]
)
self
.
add_module
(
"self_attn_q"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.q.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.q.bias"
))
self
.
add_module
(
"compute_phases"
,
self
.
compute_phases
)
self
.
add_module
(
"self_attn_k"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.k.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.k.bias"
))
self
.
add_module
(
"self_attn_v"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.v.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.v.bias"
))
self
.
add_module
(
"self_attn_o"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.o.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.o.bias"
))
self
.
add_module
(
"self_attn_norm_q"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.self_attn.norm_q.weight"
))
self
.
add_module
(
"self_attn_norm_k"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.self_attn.norm_k.weight"
))
self
.
add_module
(
"norm3"
,
LN_WEIGHT_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.norm3.weight"
,
f
"blocks.
{
self
.
block_index
}
.norm3.bias"
,
eps
=
1e-6
))
self
.
add_module
(
"cross_attn_q"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.q.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.q.bias"
))
self
.
add_module
(
"cross_attn_k"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.k.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.k.bias"
))
self
.
add_module
(
"cross_attn_v"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.v.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.v.bias"
))
self
.
add_module
(
"cross_attn_o"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.o.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.o.bias"
))
self
.
add_module
(
"cross_attn_norm_q"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.norm_q.weight"
))
self
.
add_module
(
"cross_attn_norm_k"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.norm_k.weight"
))
self
.
add_module
(
"ffn_0"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.ffn.0.weight"
,
f
"blocks.
{
self
.
block_index
}
.ffn.0.bias"
))
self
.
add_module
(
"ffn_2"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.ffn.2.weight"
,
f
"blocks.
{
self
.
block_index
}
.ffn.2.bias"
))
# attention weights section
if
self
.
sparge
:
assert
self
.
config
[
"sparge_ckpt"
],
"sparge_ckpt must be set when sparge is True"
self
.
add_module
(
"self_attn_1"
,
ATTN_WEIGHT_REGISTER
[
"Sparge"
](
f
"blocks.
{
self
.
block_index
}
"
))
self
.
add_module
(
"cross_attn_1"
,
ATTN_WEIGHT_REGISTER
[
self
.
config
[
"attention_type"
]]())
else
:
self
.
add_module
(
"self_attn_1"
,
ATTN_WEIGHT_REGISTER
[
self
.
config
[
"attention_type"
]]())
self
.
add_module
(
"cross_attn_1"
,
ATTN_WEIGHT_REGISTER
[
self
.
config
[
"attention_type"
]]())
if
self
.
task
==
"i2v"
:
# i2v
self
.
add_module
(
"cross_attn_k_img"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.k_img.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.k_img.bias"
))
self
.
add_module
(
"cross_attn_v_img"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.v_img.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.v_img.bias"
))
self
.
add_module
(
"cross_attn_norm_k_img"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.norm_k_img.weight"
))
# attention weights
self
.
add_module
(
"cross_attn_2"
,
ATTN_WEIGHT_REGISTER
[
self
.
config
[
"attention_type"
]]())
# load attn weights
class
WanSelfAttention
(
WeightModule
):
def
__init__
(
self
,
block_index
,
task
,
mm_type
,
config
):
super
().
__init__
()
self
.
block_index
=
block_index
self
.
mm_type
=
mm_type
self
.
task
=
task
self
.
config
=
config
self
.
quant_method
=
config
[
"mm_config"
].
get
(
"quant_method"
,
None
)
self
.
sparge
=
config
.
get
(
"sparge"
,
False
)
self
.
add_module
(
"self_attn_q"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.q.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.q.bias"
,
),
)
self
.
add_module
(
"self_attn_k"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.k.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.k.bias"
,
),
)
self
.
add_module
(
"self_attn_v"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.v.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.v.bias"
,
),
)
self
.
add_module
(
"self_attn_o"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.self_attn.o.weight"
,
f
"blocks.
{
self
.
block_index
}
.self_attn.o.bias"
,
),
)
self
.
add_module
(
"self_attn_norm_q"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.self_attn.norm_q.weight"
),
)
self
.
add_module
(
"self_attn_norm_k"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.self_attn.norm_k.weight"
),
)
if
self
.
sparge
:
if
self
.
sparge
:
assert
self
.
config
[
"sparge_ckpt"
],
"sparge_ckpt must be set when sparge is True"
assert
self
.
config
[
"sparge_ckpt"
],
"sparge_ckpt must be set when sparge is True"
self
.
add_module
(
"self_attn_1"
,
ATTN_WEIGHT_REGISTER
[
"Sparge"
](
f
"blocks.
{
self
.
block_index
}
"
),
)
sparge_ckpt
=
torch
.
load
(
self
.
config
[
"sparge_ckpt"
])
sparge_ckpt
=
torch
.
load
(
self
.
config
[
"sparge_ckpt"
])
self
.
self_attn_1
.
load
(
sparge_ckpt
)
self
.
self_attn_1
.
load
(
sparge_ckpt
)
else
:
else
:
# do not load weights
self
.
add_module
(
"self_attn_1"
,
ATTN_WEIGHT_REGISTER
[
self
.
config
[
"attention_type"
]]())
pass
if
self
.
quant_method
in
[
"smoothquant"
,
"awq"
]:
self
.
register_parameter
(
"smooth_norm1_weight"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm1.weight"
),
)
self
.
register_parameter
(
"smooth_norm1_bias"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm1.bias"
),
)
class
WanCrossAttention
(
WeightModule
):
def
__init__
(
self
,
block_index
,
task
,
mm_type
,
config
):
super
().
__init__
()
self
.
block_index
=
block_index
self
.
mm_type
=
mm_type
self
.
task
=
task
self
.
config
=
config
self
.
add_module
(
"norm3"
,
LN_WEIGHT_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.norm3.weight"
,
f
"blocks.
{
self
.
block_index
}
.norm3.bias"
,
eps
=
1e-6
,
),
)
self
.
add_module
(
"cross_attn_q"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.q.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.q.bias"
,
),
)
self
.
add_module
(
"cross_attn_k"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.k.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.k.bias"
,
),
)
self
.
add_module
(
"cross_attn_v"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.v.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.v.bias"
,
),
)
self
.
add_module
(
"cross_attn_o"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.o.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.o.bias"
,
),
)
self
.
add_module
(
"cross_attn_norm_q"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.norm_q.weight"
),
)
self
.
add_module
(
"cross_attn_norm_k"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.norm_k.weight"
),
)
self
.
add_module
(
"cross_attn_1"
,
ATTN_WEIGHT_REGISTER
[
self
.
config
[
"attention_type"
]]())
if
self
.
config
.
task
==
"i2v"
:
self
.
add_module
(
"cross_attn_k_img"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.k_img.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.k_img.bias"
,
),
)
self
.
add_module
(
"cross_attn_v_img"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.v_img.weight"
,
f
"blocks.
{
self
.
block_index
}
.cross_attn.v_img.bias"
,
),
)
self
.
add_module
(
"cross_attn_norm_k_img"
,
RMS_WEIGHT_REGISTER
[
"sgl-kernel"
](
f
"blocks.
{
self
.
block_index
}
.cross_attn.norm_k_img.weight"
),
)
self
.
add_module
(
"cross_attn_2"
,
ATTN_WEIGHT_REGISTER
[
self
.
config
[
"attention_type"
]]())
class
WanFFN
(
WeightModule
):
def
__init__
(
self
,
block_index
,
task
,
mm_type
,
config
):
super
().
__init__
()
self
.
block_index
=
block_index
self
.
mm_type
=
mm_type
self
.
task
=
task
self
.
config
=
config
self
.
quant_method
=
config
[
"mm_config"
].
get
(
"quant_method"
,
None
)
self
.
add_module
(
"ffn_0"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.ffn.0.weight"
,
f
"blocks.
{
self
.
block_index
}
.ffn.0.bias"
,
),
)
self
.
add_module
(
"ffn_2"
,
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"blocks.
{
self
.
block_index
}
.ffn.2.weight"
,
f
"blocks.
{
self
.
block_index
}
.ffn.2.bias"
,
),
)
# For smoothquant or awq
if
self
.
quant_method
in
[
"smoothquant"
,
"awq"
]:
if
self
.
quant_method
in
[
"smoothquant"
,
"awq"
]:
self
.
register_parameter
(
"smooth_norm1_weight"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm1.weight"
))
self
.
register_parameter
(
self
.
register_parameter
(
"smooth_norm1_bias"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm1.bias"
))
"smooth_norm2_weight"
,
self
.
register_parameter
(
"smooth_norm2_weight"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm
3
.weight"
)
)
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm
2
.weight"
)
,
self
.
register_parameter
(
"smooth_norm2_bias"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm3.bias"
)
)
)
elif
self
.
quant_method
is
not
None
:
self
.
register_parameter
(
raise
NotImplementedError
(
f
"This
{
self
.
quant_method
}
method is not implemented yet."
)
"smooth_norm2_bias"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.affine_norm2.bias"
),
self
.
register_parameter
(
"modulation"
,
TENSOR_REGISTER
[
"Default"
](
f
"blocks.
{
self
.
block_index
}
.modulation"
)
)
)
lightx2v/models/runners/wan/wan_runner.py
View file @
486e6279
...
@@ -15,6 +15,7 @@ from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
...
@@ -15,6 +15,7 @@ from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from
lightx2v.models.networks.wan.model
import
WanModel
from
lightx2v.models.networks.wan.model
import
WanModel
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.video_encoders.hf.wan.vae
import
WanVAE
from
lightx2v.models.video_encoders.hf.wan.vae
import
WanVAE
from
lightx2v.models.video_encoders.hf.wan.vae_tiny
import
WanVAE_tiny
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
lightx2v.utils.memory_profiler
import
peak_memory_decorator
from
lightx2v.utils.memory_profiler
import
peak_memory_decorator
from
loguru
import
logger
from
loguru
import
logger
...
@@ -43,6 +44,8 @@ class WanRunner(DefaultRunner):
...
@@ -43,6 +44,8 @@ class WanRunner(DefaultRunner):
checkpoint_path
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"models_t5_umt5-xxl-enc-bf16.pth"
),
checkpoint_path
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"models_t5_umt5-xxl-enc-bf16.pth"
),
tokenizer_path
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"google/umt5-xxl"
),
tokenizer_path
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"google/umt5-xxl"
),
shard_fn
=
None
,
shard_fn
=
None
,
cpu_offload
=
self
.
config
.
cpu_offload
,
offload_granularity
=
self
.
config
.
get
(
"text_encoder_offload_granularity"
,
"model"
),
)
)
text_encoders
=
[
text_encoder
]
text_encoders
=
[
text_encoder
]
model
=
WanModel
(
self
.
config
.
model_path
,
self
.
config
,
init_device
)
model
=
WanModel
(
self
.
config
.
model_path
,
self
.
config
,
init_device
)
...
@@ -53,11 +56,19 @@ class WanRunner(DefaultRunner):
...
@@ -53,11 +56,19 @@ class WanRunner(DefaultRunner):
lora_wrapper
.
apply_lora
(
lora_name
,
self
.
config
.
strength_model
)
lora_wrapper
.
apply_lora
(
lora_name
,
self
.
config
.
strength_model
)
logger
.
info
(
f
"Loaded LoRA:
{
lora_name
}
"
)
logger
.
info
(
f
"Loaded LoRA:
{
lora_name
}
"
)
vae_model
=
WanVAE
(
if
self
.
config
.
get
(
"tiny_vae"
,
False
):
vae_pth
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"Wan2.1_VAE.pth"
),
vae_model
=
WanVAE_tiny
(
device
=
init_device
,
vae_pth
=
self
.
config
.
tiny_vae_path
,
parallel
=
self
.
config
.
parallel_vae
,
device
=
init_device
,
)
)
vae_model
=
vae_model
.
to
(
"cuda"
)
else
:
vae_model
=
WanVAE
(
vae_pth
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"Wan2.1_VAE.pth"
),
device
=
init_device
,
parallel
=
self
.
config
.
parallel_vae
,
use_tiling
=
self
.
config
.
get
(
"use_tiling_vae"
,
False
),
)
if
self
.
config
.
task
==
"i2v"
:
if
self
.
config
.
task
==
"i2v"
:
image_encoder
=
CLIPModel
(
image_encoder
=
CLIPModel
(
dtype
=
torch
.
float16
,
dtype
=
torch
.
float16
,
...
@@ -68,6 +79,14 @@ class WanRunner(DefaultRunner):
...
@@ -68,6 +79,14 @@ class WanRunner(DefaultRunner):
),
),
tokenizer_path
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"xlm-roberta-large"
),
tokenizer_path
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"xlm-roberta-large"
),
)
)
if
self
.
config
.
get
(
"tiny_vae"
,
False
):
org_vae
=
WanVAE
(
vae_pth
=
os
.
path
.
join
(
self
.
config
.
model_path
,
"Wan2.1_VAE.pth"
),
device
=
init_device
,
parallel
=
self
.
config
.
parallel_vae
,
use_tiling
=
self
.
config
.
get
(
"use_tiling_vae"
,
False
),
)
image_encoder
=
[
image_encoder
,
org_vae
]
return
model
,
text_encoders
,
vae_model
,
image_encoder
return
model
,
text_encoders
,
vae_model
,
image_encoder
...
@@ -84,17 +103,21 @@ class WanRunner(DefaultRunner):
...
@@ -84,17 +103,21 @@ class WanRunner(DefaultRunner):
def
run_text_encoder
(
self
,
text
,
text_encoders
,
config
,
image_encoder_output
):
def
run_text_encoder
(
self
,
text
,
text_encoders
,
config
,
image_encoder_output
):
text_encoder_output
=
{}
text_encoder_output
=
{}
n_prompt
=
config
.
get
(
"negative_prompt"
,
""
)
n_prompt
=
config
.
get
(
"negative_prompt"
,
""
)
context
=
text_encoders
[
0
].
infer
([
text
]
,
config
)
context
=
text_encoders
[
0
].
infer
([
text
])
context_null
=
text_encoders
[
0
].
infer
([
n_prompt
if
n_prompt
else
""
]
,
config
)
context_null
=
text_encoders
[
0
].
infer
([
n_prompt
if
n_prompt
else
""
])
text_encoder_output
[
"context"
]
=
context
text_encoder_output
[
"context"
]
=
context
text_encoder_output
[
"context_null"
]
=
context_null
text_encoder_output
[
"context_null"
]
=
context_null
return
text_encoder_output
return
text_encoder_output
@
peak_memory_decorator
@
peak_memory_decorator
def
run_image_encoder
(
self
,
config
,
image_encoder
,
vae_model
):
def
run_image_encoder
(
self
,
config
,
image_encoder
,
vae_model
):
if
self
.
config
.
get
(
"tiny_vae"
,
False
):
clip_image_encoder
,
vae_image_encoder
=
image_encoder
[
0
],
image_encoder
[
1
]
else
:
clip_image_encoder
,
vae_image_encoder
=
image_encoder
,
vae_model
img
=
Image
.
open
(
config
.
image_path
).
convert
(
"RGB"
)
img
=
Image
.
open
(
config
.
image_path
).
convert
(
"RGB"
)
img
=
TF
.
to_tensor
(
img
).
sub_
(
0.5
).
div_
(
0.5
).
cuda
()
img
=
TF
.
to_tensor
(
img
).
sub_
(
0.5
).
div_
(
0.5
).
cuda
()
clip_encoder_out
=
image_encoder
.
visual
([
img
[:,
None
,
:,
:]],
config
).
squeeze
(
0
).
to
(
torch
.
bfloat16
)
clip_encoder_out
=
clip_
image_encoder
.
visual
([
img
[:,
None
,
:,
:]],
config
).
squeeze
(
0
).
to
(
torch
.
bfloat16
)
h
,
w
=
img
.
shape
[
1
:]
h
,
w
=
img
.
shape
[
1
:]
aspect_ratio
=
h
/
w
aspect_ratio
=
h
/
w
max_area
=
config
.
target_height
*
config
.
target_width
max_area
=
config
.
target_height
*
config
.
target_width
...
@@ -111,7 +134,7 @@ class WanRunner(DefaultRunner):
...
@@ -111,7 +134,7 @@ class WanRunner(DefaultRunner):
msk
=
torch
.
concat
([
torch
.
repeat_interleave
(
msk
[:,
0
:
1
],
repeats
=
4
,
dim
=
1
),
msk
[:,
1
:]],
dim
=
1
)
msk
=
torch
.
concat
([
torch
.
repeat_interleave
(
msk
[:,
0
:
1
],
repeats
=
4
,
dim
=
1
),
msk
[:,
1
:]],
dim
=
1
)
msk
=
msk
.
view
(
1
,
msk
.
shape
[
1
]
//
4
,
4
,
lat_h
,
lat_w
)
msk
=
msk
.
view
(
1
,
msk
.
shape
[
1
]
//
4
,
4
,
lat_h
,
lat_w
)
msk
=
msk
.
transpose
(
1
,
2
)[
0
]
msk
=
msk
.
transpose
(
1
,
2
)[
0
]
vae_encode_out
=
vae_
m
ode
l
.
encode
(
vae_encode_out
=
vae_
image_enc
ode
r
.
encode
(
[
[
torch
.
concat
(
torch
.
concat
(
[
[
...
@@ -131,14 +154,14 @@ class WanRunner(DefaultRunner):
...
@@ -131,14 +154,14 @@ class WanRunner(DefaultRunner):
if
self
.
config
.
task
==
"i2v"
:
if
self
.
config
.
task
==
"i2v"
:
self
.
config
.
target_shape
=
(
self
.
config
.
target_shape
=
(
num_channels_latents
,
num_channels_latents
,
(
self
.
config
.
target_video_length
-
1
)
//
4
+
1
,
(
self
.
config
.
target_video_length
-
1
)
//
self
.
config
.
vae_stride
[
0
]
+
1
,
self
.
config
.
lat_h
,
self
.
config
.
lat_h
,
self
.
config
.
lat_w
,
self
.
config
.
lat_w
,
)
)
elif
self
.
config
.
task
==
"t2v"
:
elif
self
.
config
.
task
==
"t2v"
:
self
.
config
.
target_shape
=
(
self
.
config
.
target_shape
=
(
num_channels_latents
,
num_channels_latents
,
(
self
.
config
.
target_video_length
-
1
)
//
4
+
1
,
(
self
.
config
.
target_video_length
-
1
)
//
self
.
config
.
vae_stride
[
0
]
+
1
,
int
(
self
.
config
.
target_height
)
//
self
.
config
.
vae_stride
[
1
],
int
(
self
.
config
.
target_height
)
//
self
.
config
.
vae_stride
[
1
],
int
(
self
.
config
.
target_width
)
//
self
.
config
.
vae_stride
[
2
],
int
(
self
.
config
.
target_width
)
//
self
.
config
.
vae_stride
[
2
],
)
)
lightx2v/models/video_encoders/hf/tae.py
0 → 100644
View file @
486e6279
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
import gc
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32,expandable_segments:True"

DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))


def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3


class MemBlock(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))


class TPool(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))


class TGrow(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)


def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
    """
    Apply a sequential model with memblocks to the given input.

    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
        if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display

    Returns NTCHW tensor of output data.
    """
    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if parallel:
        x = x.reshape(N * T, C, H, W)
        # parallel over input timesteps, iterate over blocks
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
                x = b(x, mem)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
        # need to fix :(
        out = []
        # iterate over input timesteps and also iterate over blocks.
        # because of the cursed TPool/TGrow blocks, this is not a nested loop,
        # it's actually a ***graph traversal*** problem! so let's make a queue
        work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
        # in addition to manually managing our queue, we also need to manually manage our progressbar.
        # we'll update it for every source node that we consume.
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        # we'll also need a separate addressable memory per node as well
        mem = [None] * len(model)
        while work_queue:
            xt, i = work_queue.pop(0)
            if i == 0:
                # new source node consumed
                progress_bar.update(1)
            if i == len(model):
                # reached end of the graph, append result to output list
                out.append(xt)
            else:
                # fetch the block to process
                b = model[i]
                if isinstance(b, MemBlock):
                    # mem blocks are simple since we're visiting the graph in causal order
                    if mem[i] is None:
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt)  # inplace might reduce mysterious pytorch memory allocations? doesn't help though
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt_new, i + 1))
                elif isinstance(b, TPool):
                    # pool blocks are miserable
                    if mem[i] is None:
                        mem[i] = []
                    # pool memory is itself a queue of inputs to pool
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        # pool mem is in invalid state, we should have pooled before this
                        raise ValueError("???")
                    elif len(mem[i]) < b.stride:
                        # pool mem is not yet full, go back to processing the work queue
                        pass
                    else:
                        # pool mem is ready, run the pool block
                        N, C, H, W = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
                        # reset the pool mem
                        mem[i] = []
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt, i + 1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C, H, W = xt.shape
                    # each tgrow has multiple successor nodes
                    for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt_next, i + 1))
                else:
                    # normal block with no funny business
                    xt = b(xt)
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt, i + 1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x
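The parallel flag is the memory/speed switch described in the docstring. A minimal usage sketch (not part of the commit; checkpoint path and tensor sizes are illustrative):

    import torch

    frames = torch.rand(1, 33, 3, 480, 832)          # NTCHW RGB in [0, 1], sizes made up
    tae = TAEHV("taew2_1.pth").cuda().half()

    # Fast path: timesteps are folded into the batch dimension, activation memory grows with T.
    latents = tae.encode_video(frames.cuda().half(), parallel=True)

    # Low-memory path: timesteps are walked one at a time through the block graph.
    latents = tae.encode_video(frames.cuda().half(), parallel=False)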
class TAEHV(nn.Module):
    latent_channels = 16
    image_channels = 3

    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
        """Initialize pretrained TAEHV from the given checkpoint.

        Arg:
            checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
            decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
            decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            conv(64, TAEHV.latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
        )
        if checkpoint_path is not None:
            self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)))

    def patch_tgrow_layers(self, sd):
        """Patch TGrow layers to use a smaller kernel if needed.

        Args:
            sd: state dict to patch
        """
        new_sd = self.state_dict()
        for i, layer in enumerate(self.decoder):
            if isinstance(layer, TGrow):
                key = f"decoder.{i}.conv.weight"
                if sd[key].shape[0] > new_sd[key].shape[0]:
                    # take the last-timestep output channels
                    sd[key] = sd[key][-new_sd[key].shape[0]:]
        return sd

    def encode_video(self, x, parallel=True, show_progress_bar=True):
        """Encode a sequence of frames.

        Args:
            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW latent tensor with ~Gaussian values.
        """
        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

    def decode_video(self, x, parallel=True, show_progress_bar=True):
        """Decode a sequence of frames.

        Args:
            x: input NTCHW latent (C=12) tensor with ~Gaussian values.
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW RGB tensor with ~[0, 1] values.
        """
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        return x[:, self.frames_to_trim:]

    def forward(self, x):
        return self.c(x)


@torch.no_grad()
def main():
    """Run TAEHV roundtrip reconstruction on the given video paths."""
    import sys
    import cv2

    # no highly esteemed deed is commemorated here
    class VideoTensorReader:
        def __init__(self, video_file_path):
            self.cap = cv2.VideoCapture(video_file_path)
            assert self.cap.isOpened(), f"Could not load {video_file_path}"
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)

        def __iter__(self):
            return self

        def __next__(self):
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                raise StopIteration  # End of video or error
            return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)  # BGR HWC -> RGB CHW

    class VideoTensorWriter:
        def __init__(self, video_file_path, width_height, fps=30):
            self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, width_height)
            assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"

        def write(self, frame_tensor):
            assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
            self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))  # RGB CHW -> BGR HWC

        def __del__(self):
            if hasattr(self, "writer"):
                self.writer.release()

    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    dtype = torch.float16
    print("Using device", dev, "and dtype", dtype)
    taehv = TAEHV().to(dev, dtype)
    for video_path in sys.argv[1:]:
        print(f"Processing {video_path}...")
        video_in = VideoTensorReader(video_path)
        video = torch.stack(list(video_in), 0)[None]
        vid_dev = video.to(dev, dtype).div_(255.0)  # convert to device tensor
        if video.numel() < 100_000_000:
            print(f"  {video_path} seems small enough, will process all frames in parallel")
            vid_enc = taehv.encode_video(vid_dev)
            print(f"  Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc)
            print(f"  Decoded {video_path} -> {vid_dec.shape}")
        else:
            print(f"  {video_path} seems large, will process each frame sequentially")
            vid_enc = taehv.encode_video(vid_dev, parallel=False)
            print(f"  Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc, parallel=False)
            print(f"  Decoded {video_path} -> {vid_dec.shape}")
        video_out_path = video_path + ".reconstructed_by_taehv.mp4"
        video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
        for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
            video_out.write(frame)
        print(f"  Saved to {video_out_path}")


if __name__ == "__main__":
    main()
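One behavioural detail worth keeping in mind when sizing outputs: with both temporal upscales enabled, frames_to_trim = 2**2 - 1 = 3, so decode_video returns 4 * T - 3 frames for T latent frames. A sketch, assuming the default decoder_time_upscale=(True, True):

    latent_frames = 21
    decoded_frames = 4 * latent_frames - 3   # two stride-2 TGrow stages, minus the trimmed warm-up frames
    assert decoded_frames == 81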
lightx2v/models/video_encoders/hf/wan/vae.py
View file @ 486e6279
...
@@ -517,7 +517,15 @@ class WanVAE_(nn.Module):
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]
+       self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+       # The minimal tile height and width for spatial tiling to be used
+       self.tile_sample_min_height = 256
+       self.tile_sample_min_width = 256
+       # The minimal distance between two spatial tiles
+       self.tile_sample_stride_height = 192
+       self.tile_sample_stride_width = 192

        # modules
        self.encoder = Encoder3d(
            dim,
...
@@ -546,6 +554,134 @@ class WanVAE_(nn.Module):
        x_recon = self.decode(z)
        return x_recon, mu, log_var

+   def blend_v(self, a, b, blend_extent):
+       blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+       for y in range(blend_extent):
+           b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
+       return b
+
+   def blend_h(self, a, b, blend_extent):
+       blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+       for x in range(blend_extent):
+           b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
+       return b
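blend_v / blend_h are plain linear crossfades over the overlapping rows/columns of neighbouring tiles. A small sketch of the ramp they produce (values illustrative, not part of the commit):

    import torch

    blend_extent = 4
    a = torch.ones(1, 3, 1, 8, 8)    # bottom edge of the tile above
    b = torch.zeros(1, 3, 1, 8, 8)   # top edge of the current tile
    for y in range(blend_extent):
        w = y / blend_extent         # 0.0, 0.25, 0.5, 0.75
        b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - w) + b[:, :, :, y, :] * w
    # the first 4 rows of b now ramp 1.0 -> 0.25, hiding the seam between tiles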
+   def tiled_encode(self, x, scale):
+       _, _, num_frames, height, width = x.shape
+       latent_height = height // self.spatial_compression_ratio
+       latent_width = width // self.spatial_compression_ratio
+
+       tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+       tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+       tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+       tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+       blend_height = tile_latent_min_height - tile_latent_stride_height
+       blend_width = tile_latent_min_width - tile_latent_stride_width
+
+       # Split x into overlapping tiles and encode them separately.
+       # The tiles have an overlap to avoid seams between tiles.
+       rows = []
+       for i in range(0, height, self.tile_sample_stride_height):
+           row = []
+           for j in range(0, width, self.tile_sample_stride_width):
+               self.clear_cache()
+               time = []
+               frame_range = 1 + (num_frames - 1) // 4
+               for k in range(frame_range):
+                   self._enc_conv_idx = [0]
+                   if k == 0:
+                       tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                   else:
+                       tile = x[
+                           :,
+                           :,
+                           1 + 4 * (k - 1) : 1 + 4 * k,
+                           i : i + self.tile_sample_min_height,
+                           j : j + self.tile_sample_min_width,
+                       ]
+                   tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                   mu, log_var = self.conv1(tile).chunk(2, dim=1)
+                   if isinstance(scale[0], torch.Tensor):
+                       mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+                   else:
+                       mu = (mu - scale[0]) * scale[1]
+                   time.append(mu)
+               row.append(torch.cat(time, dim=2))
+           rows.append(row)
+       self.clear_cache()
+
+       result_rows = []
+       for i, row in enumerate(rows):
+           result_row = []
+           for j, tile in enumerate(row):
+               # blend the above tile and the left tile
+               # to the current tile and add the current tile to the result row
+               if i > 0:
+                   tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+               if j > 0:
+                   tile = self.blend_h(row[j - 1], tile, blend_width)
+               result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+           result_rows.append(torch.cat(result_row, dim=-1))
+
+       enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+       return enc
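The tile geometry implied by the defaults above, assuming a spatial compression ratio of 8 (three 2x spatial downsamples); numbers are for illustration only:

    spatial_compression_ratio = 8
    tile_sample_min_height, tile_sample_stride_height = 256, 192

    tile_latent_min_height = tile_sample_min_height // spatial_compression_ratio        # 32
    tile_latent_stride_height = tile_sample_stride_height // spatial_compression_ratio  # 24
    blend_height = tile_latent_min_height - tile_latent_stride_height                   # 8 latent rows of overlap

    # A 480-pixel-tall input is encoded as 3 overlapping tile rows (i = 0, 192, 384),
    # each cropped to 24 latent rows before concatenation and a final crop to 60 rows.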
+   def tiled_decode(self, z, scale):
+       if isinstance(scale[0], torch.Tensor):
+           z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
+       else:
+           z = z / scale[1] + scale[0]
+
+       _, _, num_frames, height, width = z.shape
+       sample_height = height * self.spatial_compression_ratio
+       sample_width = width * self.spatial_compression_ratio
+
+       tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+       tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+       tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+       tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+       blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+       blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+       # Split z into overlapping tiles and decode them separately.
+       # The tiles have an overlap to avoid seams between tiles.
+       rows = []
+       for i in range(0, height, tile_latent_stride_height):
+           row = []
+           for j in range(0, width, tile_latent_stride_width):
+               self.clear_cache()
+               time = []
+               for k in range(num_frames):
+                   self._conv_idx = [0]
+                   tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                   tile = self.conv2(tile)
+                   decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                   time.append(decoded)
+               row.append(torch.cat(time, dim=2))
+           rows.append(row)
+       self.clear_cache()
+
+       result_rows = []
+       for i, row in enumerate(rows):
+           result_row = []
+           for j, tile in enumerate(row):
+               # blend the above tile and the left tile
+               # to the current tile and add the current tile to the result row
+               if i > 0:
+                   tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+               if j > 0:
+                   tile = self.blend_h(row[j - 1], tile, blend_width)
+               result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+           result_rows.append(torch.cat(result_row, dim=-1))
+
+       dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+       return dec
    def encode(self, x, scale):
        self.clear_cache()
        ## cache
...
@@ -660,10 +796,12 @@ class WanVAE:
        dtype=torch.float,
        device="cuda",
        parallel=False,
+       use_tiling=False,
    ):
        self.dtype = dtype
        self.device = device
        self.parallel = parallel
+       self.use_tiling = use_tiling
        mean = [
            -0.7571,
...
@@ -735,7 +873,10 @@ class WanVAE:
        if args.cpu_offload:
            self.to_cuda()
-       out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+       if self.use_tiling:
+           out = [self.model.tiled_encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+       else:
+           out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
        if args.cpu_offload:
            self.to_cpu()
...
@@ -806,6 +947,8 @@ class WanVAE:
            else:
                logger.info("Fall back to naive decode mode")
                images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+       elif self.use_tiling:
+           images = self.model.tiled_decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
        else:
            images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
...
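A hypothetical call site for the new flag (a sketch, not part of the commit; argument lists are abbreviated and only use_tiling is the point here):

    vae = WanVAE(dtype=torch.bfloat16, device="cuda", parallel=False, use_tiling=True)
    latents = vae.encode(videos, args)   # dispatches to self.model.tiled_encode(...) per the branch above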
lightx2v/models/video_encoders/hf/wan/vae_tiny.py
0 → 100644
View file @ 486e6279
import torch
import torch.nn as nn

from ..tae import TAEHV
from lightx2v.utils.memory_profiler import peak_memory_decorator


class DotDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__


class WanVAE_tiny(nn.Module):
    def __init__(self, vae_pth="taew2_1.pth", dtype=torch.bfloat16, device="cuda"):
        super().__init__()
        self.dtype = dtype
        self.device = torch.device("cuda")
        self.taehv = TAEHV(vae_pth).to(self.dtype)
        self.temperal_downsample = [True, True, False]
        self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))

    @peak_memory_decorator
    @torch.no_grad()
    def decode(self, latents, generator=None, return_dict=None, config=None):
        latents = latents.unsqueeze(0)
        n, c, t, h, w = latents.shape
        # low-memory, set parallel=True for faster + higher memory
        return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel=False).transpose(1, 2).mul_(2).sub_(1)
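A short decoding sketch for the tiny VAE (shapes illustrative, assuming the 4x temporal / 8x spatial factors discussed above; not part of the commit):

    import torch

    tiny_vae = WanVAE_tiny(vae_pth="taew2_1.pth").cuda()
    latents = torch.randn(16, 21, 60, 104, device="cuda")  # C, T, H, W latents
    video = tiny_vae.decode(latents)                        # (1, 3, 81, 480, 832) RGB in [-1, 1]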
lightx2v/utils/set_config.py
View file @ 486e6279
import json
import os

from easydict import EasyDict
+from loguru import logger


def get_default_config():
...
@@ -38,4 +39,8 @@ def set_config(args):
        model_config = json.load(f)
        config.update(model_config)
+   if config.target_video_length % config.vae_stride[0] != 1:
+       logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
+       config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1
    return config
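The rounding rule in practice, assuming vae_stride[0] == 4 (so num_frames - 1 must be a multiple of 4):

    for requested in (81, 80, 79):
        rounded = requested // 4 * 4 + 1
        print(requested, "->", rounded)   # 81 -> 81, 80 -> 81, 79 -> 77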