xuwx1 / LightX2V · Commits

Commit daf4c74e
Authored Mar 24, 2025 by helloyongyang
Committed by Yang Yong(雍洋), Apr 08, 2025
first commit

Parent: 6c79160f
Changes: 107
Showing 20 changed files with 3931 additions and 0 deletions (+3931 -0):
lightx2v/text2v/models/schedulers/wan/scheduler.py  +364 -0
lightx2v/text2v/models/text_encoders/hf/clip/__init__.py  +0 -0
lightx2v/text2v/models/text_encoders/hf/clip/model.py  +56 -0
lightx2v/text2v/models/text_encoders/hf/llama/__init__.py  +0 -0
lightx2v/text2v/models/text_encoders/hf/llama/model.py  +69 -0
lightx2v/text2v/models/text_encoders/hf/t5/__init__.py  +0 -0
lightx2v/text2v/models/text_encoders/hf/t5/model.py  +592 -0
lightx2v/text2v/models/text_encoders/hf/t5/tokenizer.py  +85 -0
lightx2v/text2v/models/video_encoders/hf/__init__.py  +0 -0
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/__init__.py  +0 -0
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/autoencoder_kl_causal_3d.py  +603 -0
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/model.py  +45 -0
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/unet_causal_3d_blocks.py  +783 -0
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/vae.py  +355 -0
lightx2v/text2v/models/video_encoders/hf/wan/__init__.py  +0 -0
lightx2v/text2v/models/video_encoders/hf/wan/vae.py  +774 -0
lightx2v/text2v/models/video_encoders/trt/__init__.py  +0 -0
lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py  +39 -0
lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py  +166 -0
lightx2v/utils/__init__.py  +0 -0
lightx2v/text2v/models/schedulers/wan/scheduler.py (new file, mode 100755)
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union

from lightx2v.text2v.models.schedulers.scheduler import BaseScheduler


class WanScheduler(BaseScheduler):
    def __init__(self, args):
        super().__init__(args)
        self.device = torch.device("cuda")
        self.infer_steps = self.args.infer_steps
        self.target_video_length = self.args.target_video_length
        self.sample_shift = self.args.sample_shift
        self.shift = 1
        self.num_train_timesteps = 1000
        self.disable_corrector = []
        self.solver_order = 2
        self.noise_pred = None

        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.args.seed)

        self.prepare_latents(self.args.target_shape, dtype=torch.float32)

        if self.args.task in ["t2v"]:
            self.seq_len = math.ceil(
                (self.args.target_shape[2] * self.args.target_shape[3])
                / (self.args.patch_size[1] * self.args.patch_size[2])
                * self.args.target_shape[1]
            )
        elif self.args.task in ["i2v"]:
            self.seq_len = (
                ((self.args.target_video_length - 1) // self.args.vae_stride[0] + 1)
                * args.lat_h
                * args.lat_w
                // (args.patch_size[1] * args.patch_size[2])
            )

        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)

        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps

        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None

        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = torch.randn(
            target_shape[0],
            target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]

        if shift is None:
            shift = self.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        sigma_last = 0
        timesteps = sigmas * self.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
        assert len(self.timesteps) == self.infer_steps

        self.model_outputs = [None] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        self._begin_index = None

        self.sigmas = self.sigmas.to("cpu")

    def _sigma_to_alpha_sigma_t(self, sigma):
        return 1 - sigma, sigma

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        sigma_t = self.sigmas[self.step_index]
        x0_pred = sample - sigma_t * model_output
        return x0_pred

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError("missing `order` as a required keyword argument")

        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        sigma_t, sigma_s0 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        B_h = torch.expm1(hh)

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0

        x_t = x_t_ - alpha_t * B_h * pred_res
        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError("missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError("missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError("missing `order` as a required keyword argument")

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = (
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        B_h = torch.expm1(hh)

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0

        D1_t = model_t - m0
        x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        timestep = self.timesteps[self.step_index]
        sample = self.latents.to(torch.float32)

        use_corrector = (
            self.step_index > 0
            and self.step_index - 1 not in self.disable_corrector
            and self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(model_output, sample=sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        for i in range(self.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.solver_order:
            self.lower_order_nums += 1

        self.latents = prev_sample
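
The scheduler exposes `latents`, `timesteps`, `step_index`, `noise_pred`, and `step_post()`, but the driving loop lives elsewhere in the repo. A minimal sketch of how a runner might drive one UniPC denoising pass; `args` and `model` are assumptions standing in for the real config and transformer, which are not part of this commit:

    scheduler = WanScheduler(args)
    for i, t in enumerate(scheduler.timesteps):
        scheduler.step_index = i  # index into sigmas for the UniPC update
        # hypothetical denoiser call; the real model call lives in the runner
        scheduler.noise_pred = model(scheduler.latents, t, seq_len=scheduler.seq_len)
        scheduler.step_post()  # corrector (from the 2nd step on) + predictor update
    video_latents = scheduler.latents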
lightx2v/text2v/models/text_encoders/hf/clip/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/text_encoders/hf/clip/model.py (new file, mode 100755)
import torch
from transformers import CLIPTextModel, AutoTokenizer


class TextEncoderHFClipModel:
    def __init__(self, model_path, device):
        self.device = device
        self.model_path = model_path
        self.init()
        self.load()

    def init(self):
        self.max_length = 77

    def load(self):
        self.model = CLIPTextModel.from_pretrained(self.model_path).to(torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")

    def to_cpu(self):
        self.model = self.model.to("cpu")

    def to_cuda(self):
        self.model = self.model.to("cuda")

    @torch.no_grad()
    def infer(self, text, args):
        if args.cpu_offload:
            self.to_cuda()
        tokens = self.tokenizer(
            text,
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        ).to("cuda")

        outputs = self.model(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            output_hidden_states=False,
        )
        # note: despite the name, this is CLIP's pooled (EOS-token) embedding,
        # not the per-token hidden states
        last_hidden_state = outputs["pooler_output"]
        if args.cpu_offload:
            self.to_cpu()
        return last_hidden_state, tokens["attention_mask"]


if __name__ == "__main__":
    from types import SimpleNamespace

    model = TextEncoderHFClipModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder_2", torch.device("cuda"))
    text = "A cat walks on the grass, realistic style."
    # `infer` expects an args namespace with `cpu_offload`; supply a minimal one
    outputs = model.infer(text, SimpleNamespace(cpu_offload=False))
    print(outputs)
lightx2v/text2v/models/text_encoders/hf/llama/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/text_encoders/hf/llama/model.py (new file, mode 100755)
import torch
from transformers import AutoModel, AutoTokenizer


class TextEncoderHFLlamaModel:
    def __init__(self, model_path, device):
        self.device = device
        self.model_path = model_path
        self.init()
        self.load()

    def init(self):
        self.max_length = 351
        self.hidden_state_skip_layer = 2
        self.crop_start = 95
        self.prompt_template = (
            "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
            "1. The main content and theme of the video."
            "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
            "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
            "4. background environment, light, style and atmosphere."
            "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
        )

    def load(self):
        self.model = AutoModel.from_pretrained(self.model_path, low_cpu_mem_usage=True).to(torch.float16).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")

    def to_cpu(self):
        self.model = self.model.to("cpu")

    def to_cuda(self):
        self.model = self.model.to("cuda")

    @torch.no_grad()
    def infer(self, text, args):
        if args.cpu_offload:
            self.to_cuda()
        text = self.prompt_template.format(text)
        tokens = self.tokenizer(
            text,
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        ).to("cuda")

        outputs = self.model(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            output_hidden_states=True,
        )
        # take the hidden states a few layers before the top, then crop off the
        # fixed template prefix so only the user prompt's tokens remain
        last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start :]
        attention_mask = tokens["attention_mask"][:, self.crop_start :]
        if args.cpu_offload:
            self.to_cpu()
        return last_hidden_state, attention_mask


if __name__ == "__main__":
    from types import SimpleNamespace

    model = TextEncoderHFLlamaModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder", torch.device("cuda"))
    text = "A cat walks on the grass, realistic style."
    # `infer` expects an args namespace with `cpu_offload`; supply a minimal one
    outputs = model.infer(text, SimpleNamespace(cpu_offload=False))
    print(outputs)
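
`crop_start = 95` is intended to cover the tokenized template prefix so that only the user prompt's hidden states survive the crop in `infer`. A hedged way to eyeball that boundary (assuming `model` is the instance from the `__main__` block; the exact count depends on the tokenizer, so this is a sanity check, not part of the commit):

    prefix = model.prompt_template.format("")  # the template around an empty user prompt
    prefix_ids = model.tokenizer(prefix, return_tensors="pt")["input_ids"]
    print(prefix_ids.shape[1], model.crop_start)  # expected to land near 95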
lightx2v/text2v/models/text_encoders/hf/t5/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/text_encoders/hf/t5/model.py (new file, mode 100755)
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from .tokenizer import HuggingfaceTokenizer

__all__ = [
    "T5Model",
    "T5Encoder",
    "T5Decoder",
    "T5EncoderModel",
]


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)


class GELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):
    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None.
        mask: [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum("bnij,bjnc->binc", attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):
    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = (
            None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
        )

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = (
            None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
        )

    def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):
    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = (
            max_exact
            + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long()
        )
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets


class T5Encoder(nn.Module):
    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Encoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
        )
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):
    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
        )
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):
    def __init__(
        self,
        vocab_size,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        encoder_layers,
        decoder_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            encoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.decoder = T5Decoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            decoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(
    name,
    encoder_only=False,
    decoder_only=False,
    return_tokenizer=False,
    tokenizer_kwargs={},
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("encoder_layers")
        _ = kwargs.pop("decoder_layers")
    elif decoder_only:
        model_cls = T5Decoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("decoder_layers")
        _ = kwargs.pop("encoder_layers")
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:
        # note: the module in this commit is `tokenizer.py` (the original
        # imported from `.tokenizers`, which does not exist here)
        from .tokenizer import HuggingfaceTokenizer

        tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1,
    )
    cfg.update(**kwargs)
    return _t5("umt5-xxl", **cfg)


class T5EncoderModel:
    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=torch.cuda.current_device(),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = (
            umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device)
            .eval()
            .requires_grad_(False)
        )
        logging.info(f"loading {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")

    def infer(self, texts, args):
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
        ids = ids.cuda()
        mask = mask.cuda()
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        # trim each sequence's padding back off before returning
        return [u[:v] for u, v in zip(context, seq_lens)]


if __name__ == "__main__":
    checkpoint_dir = "/mnt/nvme0/yongyang/projects/wan/Wan2.1-T2V-1.3B"
    t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
    t5_tokenizer = "google/umt5-xxl"

    model = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device=torch.device("cuda"),
        checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint),
        tokenizer_path=os.path.join(checkpoint_dir, t5_tokenizer),
        shard_fn=None,
    )
    text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
    outputs = model.infer(text, args=None)  # `args` is unused by `infer`
    print(outputs)
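
`T5RelativeEmbedding._relative_position_bucket` keeps one bucket per offset below `max_exact` and log-spaced buckets up to `max_dist`. A small sketch to inspect the mapping, with illustrative values (not part of the commit):

    emb = T5RelativeEmbedding(num_buckets=32, num_heads=4, bidirectional=True)
    offsets = torch.arange(-12, 13).unsqueeze(0)  # signed key-minus-query offsets
    print(emb._relative_position_bucket(offsets))
    # |offset| < 8 gets its own exact bucket; larger offsets share log-spaced
    # buckets, with the upper half (16-31) reserved for positive offsets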
lightx2v/text2v/models/text_encoders/hf/t5/tokenizer.py (new file, mode 100644)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ["HuggingfaceTokenizer"]


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string)
        )
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()


class HuggingfaceTokenizer:
    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop("return_mask", False)

        # arguments
        _kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            _kwargs.update(
                {
                    "padding": "max_length",
                    "truncation": True,
                    "max_length": self.seq_len,
                }
            )
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == "whitespace":
            text = whitespace_clean(basic_clean(text))
        elif self.clean == "lower":
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == "canonicalize":
            text = canonicalize(basic_clean(text))
        return text
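
A minimal usage sketch; the tokenizer name and sequence length are illustrative, mirroring how `T5EncoderModel` in model.py wires it up:

    tok = HuggingfaceTokenizer(name="google/umt5-xxl", seq_len=512, clean="whitespace")
    ids, mask = tok("A cat walks on the grass.", return_mask=True, add_special_tokens=True)
    print(ids.shape)        # [1, 512], padded/truncated to seq_len
    print(mask.sum(dim=1))  # number of real (non-pad) tokens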
lightx2v/text2v/models/video_encoders/hf/__init__.py (new empty file, mode 100644)
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/__init__.py (new empty file, mode 100755)
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/autoencoder_kl_causal_3d.py (new file, mode 100755)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
from
dataclasses
import
dataclass
import
torch
import
torch.nn
as
nn
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
try
:
# This diffusers is modified and packed in the mirror.
from
diffusers.loaders
import
FromOriginalVAEMixin
except
ImportError
:
# Use this to be compatible with the original diffusers.
from
diffusers.loaders.single_file_model
import
FromOriginalModelMixin
as
FromOriginalVAEMixin
from
diffusers.utils.accelerate_utils
import
apply_forward_hook
from
diffusers.models.attention_processor
import
(
ADDED_KV_ATTENTION_PROCESSORS
,
CROSS_ATTENTION_PROCESSORS
,
Attention
,
AttentionProcessor
,
AttnAddedKVProcessor
,
AttnProcessor
,
)
from
diffusers.models.modeling_outputs
import
AutoencoderKLOutput
from
diffusers.models.modeling_utils
import
ModelMixin
from
.vae
import
DecoderCausal3D
,
BaseOutput
,
DecoderOutput
,
DiagonalGaussianDistribution
,
EncoderCausal3D
@
dataclass
class
DecoderOutput2
(
BaseOutput
):
sample
:
torch
.
FloatTensor
posterior
:
Optional
[
DiagonalGaussianDistribution
]
=
None
class
AutoencoderKLCausal3D
(
ModelMixin
,
ConfigMixin
,
FromOriginalVAEMixin
):
r
"""
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing
=
True
@
register_to_config
def
__init__
(
self
,
in_channels
:
int
=
3
,
out_channels
:
int
=
3
,
down_block_types
:
Tuple
[
str
]
=
(
"DownEncoderBlockCausal3D"
,),
up_block_types
:
Tuple
[
str
]
=
(
"UpDecoderBlockCausal3D"
,),
block_out_channels
:
Tuple
[
int
]
=
(
64
,),
layers_per_block
:
int
=
1
,
act_fn
:
str
=
"silu"
,
latent_channels
:
int
=
4
,
norm_num_groups
:
int
=
32
,
sample_size
:
int
=
32
,
sample_tsize
:
int
=
64
,
scaling_factor
:
float
=
0.18215
,
force_upcast
:
float
=
True
,
spatial_compression_ratio
:
int
=
8
,
time_compression_ratio
:
int
=
4
,
mid_block_add_attention
:
bool
=
True
,
):
super
().
__init__
()
self
.
time_compression_ratio
=
time_compression_ratio
self
.
encoder
=
EncoderCausal3D
(
in_channels
=
in_channels
,
out_channels
=
latent_channels
,
down_block_types
=
down_block_types
,
block_out_channels
=
block_out_channels
,
layers_per_block
=
layers_per_block
,
act_fn
=
act_fn
,
norm_num_groups
=
norm_num_groups
,
double_z
=
True
,
time_compression_ratio
=
time_compression_ratio
,
spatial_compression_ratio
=
spatial_compression_ratio
,
mid_block_add_attention
=
mid_block_add_attention
,
)
self
.
decoder
=
DecoderCausal3D
(
in_channels
=
latent_channels
,
out_channels
=
out_channels
,
up_block_types
=
up_block_types
,
block_out_channels
=
block_out_channels
,
layers_per_block
=
layers_per_block
,
norm_num_groups
=
norm_num_groups
,
act_fn
=
act_fn
,
time_compression_ratio
=
time_compression_ratio
,
spatial_compression_ratio
=
spatial_compression_ratio
,
mid_block_add_attention
=
mid_block_add_attention
,
)
self
.
quant_conv
=
nn
.
Conv3d
(
2
*
latent_channels
,
2
*
latent_channels
,
kernel_size
=
1
)
self
.
post_quant_conv
=
nn
.
Conv3d
(
latent_channels
,
latent_channels
,
kernel_size
=
1
)
self
.
use_slicing
=
False
self
.
use_spatial_tiling
=
False
self
.
use_temporal_tiling
=
False
# only relevant if vae tiling is enabled
self
.
tile_sample_min_tsize
=
sample_tsize
self
.
tile_latent_min_tsize
=
sample_tsize
//
time_compression_ratio
self
.
tile_sample_min_size
=
self
.
config
.
sample_size
sample_size
=
(
self
.
config
.
sample_size
[
0
]
if
isinstance
(
self
.
config
.
sample_size
,
(
list
,
tuple
))
else
self
.
config
.
sample_size
)
self
.
tile_latent_min_size
=
int
(
sample_size
/
(
2
**
(
len
(
self
.
config
.
block_out_channels
)
-
1
)))
self
.
tile_overlap_factor
=
0.25
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
(
EncoderCausal3D
,
DecoderCausal3D
)):
module
.
gradient_checkpointing
=
value
def
enable_temporal_tiling
(
self
,
use_tiling
:
bool
=
True
):
self
.
use_temporal_tiling
=
use_tiling
def
disable_temporal_tiling
(
self
):
self
.
enable_temporal_tiling
(
False
)
def
enable_spatial_tiling
(
self
,
use_tiling
:
bool
=
True
):
self
.
use_spatial_tiling
=
use_tiling
def
disable_spatial_tiling
(
self
):
self
.
enable_spatial_tiling
(
False
)
def
enable_tiling
(
self
,
use_tiling
:
bool
=
True
):
r
"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger videos.
"""
self
.
enable_spatial_tiling
(
use_tiling
)
self
.
enable_temporal_tiling
(
use_tiling
)
def
disable_tiling
(
self
):
r
"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self
.
disable_spatial_tiling
()
self
.
disable_temporal_tiling
()
def
enable_slicing
(
self
):
r
"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self
.
use_slicing
=
True
def
disable_slicing
(
self
):
r
"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self
.
use_slicing
=
False
@
property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def
attn_processors
(
self
)
->
Dict
[
str
,
AttentionProcessor
]:
r
"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors
=
{}
def
fn_recursive_add_processors
(
name
:
str
,
module
:
torch
.
nn
.
Module
,
processors
:
Dict
[
str
,
AttentionProcessor
]):
if
hasattr
(
module
,
"get_processor"
):
processors
[
f
"
{
name
}
.processor"
]
=
module
.
get_processor
(
return_deprecated_lora
=
True
)
for
sub_name
,
child
in
module
.
named_children
():
fn_recursive_add_processors
(
f
"
{
name
}
.
{
sub_name
}
"
,
child
,
processors
)
return
processors
for
name
,
module
in
self
.
named_children
():
fn_recursive_add_processors
(
name
,
module
,
processors
)
return
processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def
set_attn_processor
(
self
,
processor
:
Union
[
AttentionProcessor
,
Dict
[
str
,
AttentionProcessor
]],
_remove_lora
=
False
):
r
"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count
=
len
(
self
.
attn_processors
.
keys
())
if
isinstance
(
processor
,
dict
)
and
len
(
processor
)
!=
count
:
raise
ValueError
(
f
"A dict of processors was passed, but the number of processors
{
len
(
processor
)
}
does not match the"
f
" number of attention layers:
{
count
}
. Please make sure to pass
{
count
}
processor classes."
)
def
fn_recursive_attn_processor
(
name
:
str
,
module
:
torch
.
nn
.
Module
,
processor
):
if
hasattr
(
module
,
"set_processor"
):
if
not
isinstance
(
processor
,
dict
):
module
.
set_processor
(
processor
,
_remove_lora
=
_remove_lora
)
else
:
module
.
set_processor
(
processor
.
pop
(
f
"
{
name
}
.processor"
),
_remove_lora
=
_remove_lora
)
for
sub_name
,
child
in
module
.
named_children
():
fn_recursive_attn_processor
(
f
"
{
name
}
.
{
sub_name
}
"
,
child
,
processor
)
for
name
,
module
in
self
.
named_children
():
fn_recursive_attn_processor
(
name
,
module
,
processor
)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def
set_default_attn_processor
(
self
):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if
all
(
proc
.
__class__
in
ADDED_KV_ATTENTION_PROCESSORS
for
proc
in
self
.
attn_processors
.
values
()):
processor
=
AttnAddedKVProcessor
()
elif
all
(
proc
.
__class__
in
CROSS_ATTENTION_PROCESSORS
for
proc
in
self
.
attn_processors
.
values
()):
processor
=
AttnProcessor
()
else
:
raise
ValueError
(
f
"Cannot call `set_default_attn_processor` when attention processors are of type
{
next
(
iter
(
self
.
attn_processors
.
values
()))
}
"
)
self
.
set_attn_processor
(
processor
,
_remove_lora
=
True
)
@
apply_forward_hook
def
encode
(
self
,
x
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
)
->
Union
[
AutoencoderKLOutput
,
Tuple
[
DiagonalGaussianDistribution
]]:
"""
Encode a batch of images/videos into latents.
Args:
x (`torch.FloatTensor`): Input batch of images/videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images/videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
assert
len
(
x
.
shape
)
==
5
,
"The input tensor should have 5 dimensions."
if
self
.
use_temporal_tiling
and
x
.
shape
[
2
]
>
self
.
tile_sample_min_tsize
:
return
self
.
temporal_tiled_encode
(
x
,
return_dict
=
return_dict
)
if
self
.
use_spatial_tiling
and
(
x
.
shape
[
-
1
]
>
self
.
tile_sample_min_size
or
x
.
shape
[
-
2
]
>
self
.
tile_sample_min_size
):
return
self
.
spatial_tiled_encode
(
x
,
return_dict
=
return_dict
)
if
self
.
use_slicing
and
x
.
shape
[
0
]
>
1
:
encoded_slices
=
[
self
.
encoder
(
x_slice
)
for
x_slice
in
x
.
split
(
1
)]
h
=
torch
.
cat
(
encoded_slices
)
else
:
h
=
self
.
encoder
(
x
)
moments
=
self
.
quant_conv
(
h
)
posterior
=
DiagonalGaussianDistribution
(
moments
)
if
not
return_dict
:
return
(
posterior
,)
return
AutoencoderKLOutput
(
latent_dist
=
posterior
)
def
_decode
(
self
,
z
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
)
->
Union
[
DecoderOutput
,
torch
.
FloatTensor
]:
assert
len
(
z
.
shape
)
==
5
,
"The input tensor should have 5 dimensions."
if
self
.
use_temporal_tiling
and
z
.
shape
[
2
]
>
self
.
tile_latent_min_tsize
:
return
self
.
temporal_tiled_decode
(
z
,
return_dict
=
return_dict
)
if
self
.
use_spatial_tiling
and
(
z
.
shape
[
-
1
]
>
self
.
tile_latent_min_size
or
z
.
shape
[
-
2
]
>
self
.
tile_latent_min_size
):
return
self
.
spatial_tiled_decode
(
z
,
return_dict
=
return_dict
)
z
=
self
.
post_quant_conv
(
z
)
dec
=
self
.
decoder
(
z
)
if
not
return_dict
:
return
(
dec
,)
return
DecoderOutput
(
sample
=
dec
)
@
apply_forward_hook
def
decode
(
self
,
z
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
,
generator
=
None
)
->
Union
[
DecoderOutput
,
torch
.
FloatTensor
]:
"""
Decode a batch of images/videos.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if
self
.
use_slicing
and
z
.
shape
[
0
]
>
1
:
decoded_slices
=
[
self
.
_decode
(
z_slice
).
sample
for
z_slice
in
z
.
split
(
1
)]
decoded
=
torch
.
cat
(
decoded_slices
)
else
:
decoded
=
self
.
_decode
(
z
).
sample
if
not
return_dict
:
return
(
decoded
,)
return
DecoderOutput
(
sample
=
decoded
)
def
blend_v
(
self
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
blend_extent
:
int
)
->
torch
.
Tensor
:
blend_extent
=
min
(
a
.
shape
[
-
2
],
b
.
shape
[
-
2
],
blend_extent
)
for
y
in
range
(
blend_extent
):
b
[:,
:,
:,
y
,
:]
=
a
[:,
:,
:,
-
blend_extent
+
y
,
:]
*
(
1
-
y
/
blend_extent
)
+
b
[:,
:,
:,
y
,
:]
*
(
y
/
blend_extent
)
return
b
def
blend_h
(
self
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
blend_extent
:
int
)
->
torch
.
Tensor
:
blend_extent
=
min
(
a
.
shape
[
-
1
],
b
.
shape
[
-
1
],
blend_extent
)
for
x
in
range
(
blend_extent
):
b
[:,
:,
:,
:,
x
]
=
a
[:,
:,
:,
:,
-
blend_extent
+
x
]
*
(
1
-
x
/
blend_extent
)
+
b
[:,
:,
:,
:,
x
]
*
(
x
/
blend_extent
)
return
b
def
blend_t
(
self
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
blend_extent
:
int
)
->
torch
.
Tensor
:
blend_extent
=
min
(
a
.
shape
[
-
3
],
b
.
shape
[
-
3
],
blend_extent
)
for
x
in
range
(
blend_extent
):
b
[:,
:,
x
,
:,
:]
=
a
[:,
:,
-
blend_extent
+
x
,
:,
:]
*
(
1
-
x
/
blend_extent
)
+
b
[:,
:,
x
,
:,
:]
*
(
x
/
blend_extent
)
return
b
def
spatial_tiled_encode
(
self
,
x
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
,
return_moments
:
bool
=
False
)
->
AutoencoderKLOutput
:
r
"""Encode a batch of images/videos using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images/videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size
=
int
(
self
.
tile_sample_min_size
*
(
1
-
self
.
tile_overlap_factor
))
blend_extent
=
int
(
self
.
tile_latent_min_size
*
self
.
tile_overlap_factor
)
row_limit
=
self
.
tile_latent_min_size
-
blend_extent
# Split video into tiles and encode them separately.
rows
=
[]
for
i
in
range
(
0
,
x
.
shape
[
-
2
],
overlap_size
):
row
=
[]
for
j
in
range
(
0
,
x
.
shape
[
-
1
],
overlap_size
):
tile
=
x
[:,
:,
:,
i
:
i
+
self
.
tile_sample_min_size
,
j
:
j
+
self
.
tile_sample_min_size
]
tile
=
self
.
encoder
(
tile
)
tile
=
self
.
quant_conv
(
tile
)
row
.
append
(
tile
)
rows
.
append
(
row
)
result_rows
=
[]
for
i
,
row
in
enumerate
(
rows
):
result_row
=
[]
for
j
,
tile
in
enumerate
(
row
):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if
i
>
0
:
tile
=
self
.
blend_v
(
rows
[
i
-
1
][
j
],
tile
,
blend_extent
)
if
j
>
0
:
tile
=
self
.
blend_h
(
row
[
j
-
1
],
tile
,
blend_extent
)
result_row
.
append
(
tile
[:,
:,
:,
:
row_limit
,
:
row_limit
])
result_rows
.
append
(
torch
.
cat
(
result_row
,
dim
=-
1
))
moments
=
torch
.
cat
(
result_rows
,
dim
=-
2
)
if
return_moments
:
return
moments
posterior
=
DiagonalGaussianDistribution
(
moments
)
if
not
return_dict
:
return
(
posterior
,)
return
AutoencoderKLOutput
(
latent_dist
=
posterior
)
def
spatial_tiled_decode
(
self
,
z
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
)
->
Union
[
DecoderOutput
,
torch
.
FloatTensor
]:
r
"""
Decode a batch of images/videos using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size
=
int
(
self
.
tile_latent_min_size
*
(
1
-
self
.
tile_overlap_factor
))
blend_extent
=
int
(
self
.
tile_sample_min_size
*
self
.
tile_overlap_factor
)
row_limit
=
self
.
tile_sample_min_size
-
blend_extent
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows
=
[]
for
i
in
range
(
0
,
z
.
shape
[
-
2
],
overlap_size
):
row
=
[]
for
j
in
range
(
0
,
z
.
shape
[
-
1
],
overlap_size
):
tile
=
z
[:,
:,
:,
i
:
i
+
self
.
tile_latent_min_size
,
j
:
j
+
self
.
tile_latent_min_size
]
tile
=
self
.
post_quant_conv
(
tile
)
decoded
=
self
.
decoder
(
tile
)
row
.
append
(
decoded
)
rows
.
append
(
row
)
result_rows
=
[]
for
i
,
row
in
enumerate
(
rows
):
result_row
=
[]
for
j
,
tile
in
enumerate
(
row
):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if
i
>
0
:
tile
=
self
.
blend_v
(
rows
[
i
-
1
][
j
],
tile
,
blend_extent
)
if
j
>
0
:
tile
=
self
.
blend_h
(
row
[
j
-
1
],
tile
,
blend_extent
)
result_row
.
append
(
tile
[:,
:,
:,
:
row_limit
,
:
row_limit
])
result_rows
.
append
(
torch
.
cat
(
result_row
,
dim
=-
1
))
dec
=
torch
.
cat
(
result_rows
,
dim
=-
2
)
if
not
return_dict
:
return
(
dec
,)
return
DecoderOutput
(
sample
=
dec
)
def
temporal_tiled_encode
(
self
,
x
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
)
->
AutoencoderKLOutput
:
B
,
C
,
T
,
H
,
W
=
x
.
shape
overlap_size
=
int
(
self
.
tile_sample_min_tsize
*
(
1
-
self
.
tile_overlap_factor
))
blend_extent
=
int
(
self
.
tile_latent_min_tsize
*
self
.
tile_overlap_factor
)
t_limit
=
self
.
tile_latent_min_tsize
-
blend_extent
# Split the video into tiles and encode them separately.
row
=
[]
for
i
in
range
(
0
,
T
,
overlap_size
):
tile
=
x
[:,
:,
i
:
i
+
self
.
tile_sample_min_tsize
+
1
,
:,
:]
if
self
.
use_spatial_tiling
and
(
tile
.
shape
[
-
1
]
>
self
.
tile_sample_min_size
or
tile
.
shape
[
-
2
]
>
self
.
tile_sample_min_size
):
tile
=
self
.
spatial_tiled_encode
(
tile
,
return_moments
=
True
)
else
:
tile
=
self
.
encoder
(
tile
)
tile
=
self
.
quant_conv
(
tile
)
if
i
>
0
:
tile
=
tile
[:,
:,
1
:,
:,
:]
row
.
append
(
tile
)
result_row
=
[]
for
i
,
tile
in
enumerate
(
row
):
if
i
>
0
:
tile
=
self
.
blend_t
(
row
[
i
-
1
],
tile
,
blend_extent
)
result_row
.
append
(
tile
[:,
:,
:
t_limit
,
:,
:])
else
:
result_row
.
append
(
tile
[:,
:,
:
t_limit
+
1
,
:,
:])
moments
=
torch
.
cat
(
result_row
,
dim
=
2
)
posterior
=
DiagonalGaussianDistribution
(
moments
)
if
not
return_dict
:
return
(
posterior
,)
return
AutoencoderKLOutput
(
latent_dist
=
posterior
)
def
temporal_tiled_decode
(
self
,
z
:
torch
.
FloatTensor
,
return_dict
:
bool
=
True
)
->
Union
[
DecoderOutput
,
torch
.
FloatTensor
]:
# Split z into overlapping tiles and decode them separately.
B
,
C
,
T
,
H
,
W
=
z
.
shape
overlap_size
=
int
(
self
.
tile_latent_min_tsize
*
(
1
-
self
.
tile_overlap_factor
))
blend_extent
=
int
(
self
.
tile_sample_min_tsize
*
self
.
tile_overlap_factor
)
t_limit
=
self
.
tile_sample_min_tsize
-
blend_extent
row
=
[]
for
i
in
range
(
0
,
T
,
overlap_size
):
tile
=
z
[:,
:,
i
:
i
+
self
.
tile_latent_min_tsize
+
1
,
:,
:]
if
self
.
use_spatial_tiling
and
(
tile
.
shape
[
-
1
]
>
self
.
tile_latent_min_size
or
tile
.
shape
[
-
2
]
>
self
.
tile_latent_min_size
):
decoded
=
self
.
spatial_tiled_decode
(
tile
,
return_dict
=
True
).
sample
else
:
tile
=
self
.
post_quant_conv
(
tile
)
decoded
=
self
.
decoder
(
tile
)
if
i
>
0
:
decoded
=
decoded
[:,
:,
1
:,
:,
:]
row
.
append
(
decoded
)
result_row
=
[]
for
i
,
tile
in
enumerate
(
row
):
if
i
>
0
:
tile
=
self
.
blend_t
(
row
[
i
-
1
],
tile
,
blend_extent
)
result_row
.
append
(
tile
[:,
:,
:
t_limit
,
:,
:])
else
:
result_row
.
append
(
tile
[:,
:,
:
t_limit
+
1
,
:,
:])
dec
=
torch
.
cat
(
result_row
,
dim
=
2
)
if
not
return_dict
:
return
(
dec
,)
return
DecoderOutput
(
sample
=
dec
)
    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        return_posterior: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput2, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            if return_posterior:
                return (dec, posterior)
            else:
                return (dec,)
        if return_posterior:
            return DecoderOutput2(sample=dec, posterior=posterior)
        else:
            return DecoderOutput2(sample=dec)
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
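
# The blend_t helper used by temporal_tiled_encode / temporal_tiled_decode above is
# defined elsewhere in this file. A minimal sketch of the linear temporal cross-fade
# such a helper typically performs (an illustrative assumption, not the repository's
# exact implementation):
def _blend_t_sketch(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # a, b: (B, C, T, H, W); fade the first `blend_extent` frames of `b` into the
    # last `blend_extent` frames of `a`, weighting `b` more heavily as t grows.
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    for t in range(blend_extent):
        w = t / blend_extent
        b[:, :, t, :, :] = a[:, :, a.shape[2] - blend_extent + t, :, :] * (1 - w) + b[:, :, t, :, :] * w
    return b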
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/model.py
0 → 100755
View file @
daf4c74e
import os
import torch
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D


class VideoEncoderKLCausal3DModel:
    def __init__(self, model_path, dtype, device):
        self.model_path = model_path
        self.dtype = dtype
        self.device = device
        self.load()

    def load(self):
        self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
        config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(config)
        ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(dtype=self.dtype, device=self.device)
        self.model.requires_grad_(False)
        self.model.eval()

    def to_cpu(self):
        self.model = self.model.to("cpu")

    def to_cuda(self):
        self.model = self.model.to("cuda")

    def decode(self, latents, generator, args):
        if args.cpu_offload:
            self.to_cuda()
        latents = latents / self.model.config.scaling_factor
        latents = latents.to(dtype=self.dtype, device=torch.device("cuda"))
        self.model.enable_tiling()
        image = self.model.decode(latents, return_dict=False, generator=generator)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
        if args.cpu_offload:
            self.to_cpu()
        return image


if __name__ == "__main__":
    vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme0/yongyang/projects/hy/new/HunyuanVideo/ckpts", dtype=torch.float16, device=torch.device("cuda"))
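
# A minimal decode call against the model constructed above. The latent shape and the
# `args` namespace (only `cpu_offload` is read by `decode`) are illustrative assumptions:
if __name__ == "__main__":
    from types import SimpleNamespace

    latents = torch.randn(1, 16, 9, 30, 53, dtype=torch.float16, device="cuda")
    generator = torch.Generator("cuda").manual_seed(0)
    video = vae_model.decode(latents, generator, SimpleNamespace(cpu_offload=False))
    print(video.shape)  # (1, 3, 33, 240, 424) given 4x temporal / 8x spatial upsampling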
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/unet_causal_3d_blocks.py
0 → 100755
View file @
daf4c74e
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from diffusers.utils import logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.attention_processor import Attention
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.normalization import RMSNorm

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def prepare_causal_attention_mask_ori(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
    seq_len = n_frame * n_hw
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
    for i in range(seq_len):
        i_frame = i // n_hw
        mask[i, : (i_frame + 1) * n_hw] = 0
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
    seq_len = n_frame * n_hw
    mask = torch.full((n_frame, n_frame, n_hw, n_hw), float("-inf"), dtype=dtype, device=device)
    # mask = mask.reshape(n_frame, n_frame, n_hw, n_hw)
    idx_arr = torch.tril(torch.ones(n_frame, n_frame, dtype=dtype, device=device))
    idx_arr = idx_arr > torch.zeros_like(idx_arr)
    for i in range(n_frame):
        for j in range(n_frame):
            if idx_arr[i, j]:
                mask[i, j] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
    # mask[idx_arr] = torch.zeros(n_hw, n_hw, dtype=dtype, device=device)
    mask = mask.view(n_frame, -1, n_hw).transpose(1, 0).reshape(seq_len, -1).transpose(1, 0)
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask.to(device)
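
# Both builders above produce the same block-lower-triangular mask: the query at
# (frame i, pixel p) may attend to every pixel of frames 0..i and to nothing later.
# A small self-check on a toy size (illustrative only, not part of the pipeline):
def _causal_mask_selfcheck_sketch():
    m1 = prepare_causal_attention_mask_ori(3, 2, torch.float32, "cpu")
    m2 = prepare_causal_attention_mask(3, 2, torch.float32, "cpu")
    assert torch.equal(m1, m2)
    # rows 0-1 (frame 0) see columns 0-1; rows 2-3 (frame 1) see columns 0-3; etc.
    print((m1 == 0).int())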
class CausalConv3d(nn.Module):
    """
    Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
    This maintains temporal causality in video generation tasks.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        pad_mode='replicate',
        **kwargs,
    ):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0)  # W, H, T
        self.time_causal_padding = padding

        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
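
# Because CausalConv3d puts all temporal padding (kernel_size - 1) in front, output
# frame t never depends on input frames after t. A quick check of that property
# (illustrative only, not part of the pipeline):
def _causal_conv_selfcheck_sketch():
    torch.manual_seed(0)
    conv = CausalConv3d(1, 1, kernel_size=3)
    x = torch.randn(1, 1, 8, 4, 4)
    y0 = conv(x)
    x2 = x.clone()
    x2[:, :, 5:] += 1.0                                # perturb only frames 5..7
    y1 = conv(x2)
    assert torch.allclose(y0[:, :, :5], y1[:, :, :5])  # frames 0..4 unchanged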
class UpsampleCausal3D(nn.Module):
    """
    A 3D upsampling layer with an optional convolution.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        interpolate=True,
        upsample_factor=(2, 2, 2),
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.interpolate = interpolate
        self.upsample_factor = upsample_factor

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps, elementwise_affine)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)

        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            raise NotImplementedError

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 since the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output size
        # and do not make use of `scale_factor=2`
        if self.interpolate:
            B, C, T, H, W = hidden_states.shape
            first_h, other_h = hidden_states.split((1, T - 1), dim=2)
            if output_size is None:
                if T > 1:
                    other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")

                # first_h = first_h.squeeze(2)
                first_h = first_h.view(B, C, H, W)
                first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
                first_h = first_h.unsqueeze(2)
            else:
                raise NotImplementedError

            if T > 1:
                hidden_states = torch.cat((first_h, other_h), dim=2)
            else:
                hidden_states = first_h

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
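
# UpsampleCausal3D upsamples the first frame only spatially and the remaining T - 1
# frames in all three dims, so (B, C, T, H, W) becomes (B, C, 2T - 1, 2H, 2W) for the
# default factor (2, 2, 2); this preserves the "first frame is the causal anchor"
# convention. A shape check, illustrative only:
def _upsample_shape_selfcheck_sketch():
    up = UpsampleCausal3D(channels=8, use_conv=True)
    y = up(torch.randn(1, 8, 5, 4, 4))
    assert y.shape == (1, 8, 9, 8, 8)  # 2 * 5 - 1 = 9 frames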
class DownsampleCausal3D(nn.Module):
    """
    A 3D downsampling layer with an optional convolution.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
        kernel_size=3,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        stride=2,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = stride
        self.name = name

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps, elementwise_affine)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        if use_conv:
            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
        else:
            raise NotImplementedError

        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        assert hidden_states.shape[1] == self.channels

        hidden_states = self.conv(hidden_states)

        return hidden_states
class ResnetBlockCausal3D(nn.Module):
    r"""
    A Resnet block.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        # default, scale_shift, ada_group, spatial
        time_embedding_norm: str = "default",
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_3d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        linear_cls = nn.Linear

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = linear_cls(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm}")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_3d_out_channels = conv_3d_out_channels or out_channels
        self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
        elif self.down:
            self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")

        self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = CausalConv3d(
                in_channels,
                conv_3d_out_channels,
                kernel_size=1,
                stride=1,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, scale=scale)
            hidden_states = self.upsample(hidden_states, scale=scale)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, scale=scale)
            hidden_states = self.downsample(hidden_states, scale=scale)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
def get_down_block3d(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    downsample_stride: int,
    resnet_eps: float,
    resnet_act_fn: str,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    downsample_type: Optional[str] = None,
    dropout: float = 0.0,
):
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownEncoderBlockCausal3D":
        return DownEncoderBlockCausal3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            downsample_stride=downsample_stride,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} does not exist.")
def get_up_block3d(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    upsample_scale_factor: Tuple,
    resnet_eps: float,
    resnet_act_fn: str,
    resolution_idx: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    upsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warn(f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}.")
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpDecoderBlockCausal3D":
        return UpDecoderBlockCausal3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            upsample_scale_factor=upsample_scale_factor,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlockCausal3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None:
            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        resnets = [
            ResnetBlockCausal3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            logger.warn(f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}.")
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=attn_groups,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                B, C, T, H, W = hidden_states.shape
                hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
                attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
                hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
                hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class DownEncoderBlockCausal3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_stride: int = 2,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    DownsampleCausal3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                        stride=downsample_stride,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None, scale=scale)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, scale)

        return hidden_states
class UpDecoderBlockCausal3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        upsample_scale_factor=(2, 2, 2),
        temb_channels: Optional[int] = None,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlockCausal3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    UpsampleCausal3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        upsample_factor=upsample_scale_factor,
                    )
                ]
            )
        else:
            self.upsamplers = None

        self.resolution_idx = resolution_idx

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb, scale=scale)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
lightx2v/text2v/models/video_encoders/hf/autoencoder_kl_causal_3d/vae.py
0 → 100755
View file @
daf4c74e
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .unet_causal_3d_blocks import (
    CausalConv3d,
    UNetMidBlockCausal3D,
    get_down_block3d,
    get_up_block3d,
)


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor
class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim == 5 or parameters.ndim == 4:
            dim = 1  # (B, C, T, H, W) / (B, C, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            reduce_dim = list(range(1, self.mean.ndim))
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=reduce_dim,
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=reduce_dim,
                )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
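
# The encoder packs mean and log-variance along the channel dim; this class splits them
# and exposes sampling, mode, KL and NLL. A minimal usage sketch with a dummy 5D moments
# tensor (shapes illustrative):
def _posterior_usage_sketch():
    moments = torch.randn(1, 2 * 16, 4, 8, 8)  # (B, 2 * z_dim, T, H, W)
    posterior = DiagonalGaussianDistribution(moments)
    z_random = posterior.sample(generator=torch.Generator().manual_seed(0))
    z_mode = posterior.mode()  # the mean; used for deterministic encoding
    kl = posterior.kl()        # KL to a standard normal, one value per batch element
    assert z_random.shape == z_mode.shape == (1, 16, 4, 8, 8) and kl.shape == (1,)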
lightx2v/text2v/models/video_encoders/hf/wan/__init__.py
0 → 100755
View file @
daf4c74e
lightx2v/text2v/models/video_encoders/hf/wan/vae.py
0 → 100755
View file @
daf4c74e
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

__all__ = [
    "WanVAE",
]

CACHE_T = 2
class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (
            self.padding[2],
            self.padding[2],
            self.padding[1],
            self.padding[1],
            2 * self.padding[0],
            0,
        )
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)
class RMS_norm(nn.Module):
    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
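
# RMS_norm is equivalent to x / rms(x) * gamma (+ bias): F.normalize divides by the L2
# norm over the channel dim, and multiplying by scale = sqrt(dim) turns that L2 norm
# into a root-mean-square. A numerical self-check, illustrative only:
def _rms_norm_selfcheck_sketch():
    norm = RMS_norm(8, channel_first=True, images=True)
    x = torch.randn(2, 8, 16, 16)
    expected = x / x.pow(2).mean(dim=1, keepdim=True).sqrt()
    assert torch.allclose(norm(x), expected, atol=1e-5)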
class Upsample(nn.Upsample):
    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
    def __init__(self, dim, mode):
        assert mode in (
            "none",
            "upsample2d",
            "upsample3d",
            "downsample2d",
            "downsample3d",
        )
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1),
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1),
            )
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # cache the last frame of the previous two chunks
                        cache_x = torch.cat(
                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                            dim=2,
                        )
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        cache_x = torch.cat(
                            [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
                            dim=2,
                        )
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
                    #     # cache the last frame of the previous two chunks
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        # conv_weight.data[:, :, -1, 1, 1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  # * 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1, 0, 2).contiguous().reshape(c1, c2)
        conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)
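
# In "upsample3d" mode, time_conv doubles the channels (dim -> 2 * dim); the
# reshape/stack/reshape sequence in forward then interleaves the two channel halves
# along time, doubling the frame count before the spatial 2x resample. A runnable
# trace of that exact index shuffle (illustrative only):
def _upsample3d_time_doubling_sketch():
    b, c, t, h, w = 1, 4, 3, 8, 8
    x = torch.randn(b, 2 * c, t, h, w)  # as produced by time_conv
    x = x.reshape(b, 2, c, t, h, w)
    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
    x = x.reshape(b, c, t * 2, h, w)
    assert x.shape == (b, c, 2 * t, h, w)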
class ResidualBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False),
            nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous two chunks
                    cache_x = torch.cat(
                        [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.norm(x)

        # compute query, key, value
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x + identity
class Encoder3d(nn.Module):
    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout),
            AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout),
        )

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous two chunks
                cache_x = torch.cat(
                    [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous two chunks
                    cache_x = torch.cat(
                        [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
class Decoder3d(nn.Module):
    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout),
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous two chunks
                cache_x = torch.cat(
                    [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous two chunks
                    cache_x = torch.cat(
                        [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count
class WanVAE_(nn.Module):
    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(
            dim,
            z_dim * 2,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_downsample,
            dropout,
        )
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dim,
            z_dim,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_upsample,
            dropout,
        )

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        ## Split the encoder input x along time into chunks of 1, 4, 4, 4, ... frames.
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b, c, t, h, w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i : i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
            else:
                out_ = self.decoder(
                    x[:, :, i : i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
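
# encode streams frames through the encoder causally: the first chunk holds 1 frame and
# every later chunk holds 4, so t frames (with t = 4k + 1, the frame-count convention
# this model uses) need iter_ = 1 + (t - 1) // 4 passes and yield that many latent
# frames. A quick check of the arithmetic, illustrative only:
def _wan_chunking_selfcheck_sketch():
    for t in (1, 5, 17, 81):
        iter_ = 1 + (t - 1) // 4
        assert sum([1] + [4] * (iter_ - 1)) == t
    # e.g. t = 81 frames -> chunks of 1, 4, 4, ..., 4 -> 21 latent frames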
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0,
    )
    cfg.update(**kwargs)

    # init model
    with torch.device("meta"):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f"loading {pretrained_path}")
    model.load_state_dict(torch.load(pretrained_path, map_location=device, weights_only=True), assign=True)

    return model


class WanVAE:
    def __init__(
        self,
        z_dim=16,
        vae_pth="cache/vae_step_411000.pth",
        dtype=torch.float,
        device="cuda",
    ):
        self.dtype = dtype
        self.device = device

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = (
            _video_vae(
                pretrained_path=vae_pth,
                z_dim=z_dim,
            )
            .eval()
            .requires_grad_(False)
            .to(device)
        )

    def encode(self, videos):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]

    def decode(self, zs, generator, args):
        return self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
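
# Illustrative usage sketch (shapes inferred from the code above; the checkpoint
# path is the class default and may differ in practice):
#
#   vae = WanVAE(vae_pth="cache/vae_step_411000.pth", device="cuda")
#   video = torch.rand(3, 17, 256, 256, device="cuda") * 2 - 1   # [C, T, H, W] in [-1, 1]
#   z = vae.encode([video])[0]           # [16, 5, 32, 32]: 4x temporal, 8x spatial downsampling
#   rec = vae.decode(z, None, None)      # [1, 3, 17, 256, 256], clamped to [-1, 1]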
lightx2v/text2v/models/video_encoders/trt/__init__.py
0 → 100644
View file @
daf4c74e
lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py
0 → 100755
View file @
daf4c74e
import os

import torch

from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d import trt_vae_infer


class VideoEncoderKLCausal3DModel:
    def __init__(self, model_path, dtype, device):
        self.model_path = model_path
        self.dtype = dtype
        self.device = device
        self.load()

    def load(self):
        self.vae_path = os.path.join(self.model_path, 'hunyuan-video-t2v-720p/vae')
        config = AutoencoderKLCausal3D.load_config(self.vae_path)
        self.model = AutoencoderKLCausal3D.from_config(config)
        ckpt = torch.load(os.path.join(self.vae_path, 'pytorch_model.pt'), map_location='cpu', weights_only=True)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(dtype=self.dtype, device=self.device)
        self.model.requires_grad_(False)
        self.model.eval()
        # Replace the eager decoder with the TensorRT engine wrapper
        trt_decoder = trt_vae_infer.HyVaeTrtModelInfer(engine_path=os.path.join(self.vae_path, "vae_decoder.engine"))
        self.model.decoder = trt_decoder

    def decode(self, latents, generator):
        latents = latents / self.model.config.scaling_factor
        latents = latents.to(dtype=self.dtype, device=self.device)
        self.model.enable_tiling()
        image = self.model.decode(latents, return_dict=False, generator=generator)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().float()
        return image


if __name__ == "__main__":
    vae_model = VideoEncoderKLCausal3DModel("/mnt/nvme1/yongyang/models/hy/ckpts", dtype=torch.float16, device=torch.device("cuda"))
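
# Illustrative note: decode() expects latents in the diffusion model's scaled
# space; dividing by config.scaling_factor maps them back to the VAE latent
# space, and (image / 2 + 0.5).clamp(0, 1) converts the decoder output from
# [-1, 1] to [0, 1] so frames can be saved directly.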
lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py
0 → 100644
View file @
daf4c74e
import os
from pathlib import Path
from subprocess import Popen

import numpy as np
import tensorrt as trt
import torch
import torch.nn as nn
from cuda import cudart
from loguru import logger

from lightx2v.common.backend_infer.trt import common

TRT_LOGGER = trt.Logger(trt.Logger.INFO)


class HyVaeTrtModelInfer(nn.Module):
    """
    Implements inference for the TensorRT engine.
    """

    def __init__(self, engine_path):
        """
        :param engine_path: The path to the serialized engine to load from disk.
        """
        super().__init__()  # required before assigning attributes on an nn.Module subclass
        # Load TRT engine
        if not Path(engine_path).exists():
            # dir_name = str(Path(engine_path).parents)
            # onnx_path = self.export_to_onnx(decoder, dir_name)
            # self.convert_to_trt_engine(onnx_path, engine_path)
            raise FileNotFoundError(f"VAE tensorrt engine `{str(engine_path)}` does not exist.")
        self.logger = trt.Logger(trt.Logger.ERROR)
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            assert runtime
            self.engine = runtime.deserialize_cuda_engine(f.read())
        assert self.engine
        self.context = self.engine.create_execution_context()
        assert self.context
        logger.info(f"Loaded VAE tensorrt engine from `{engine_path}`")

    def alloc(self, shape_dict):
        """
        Setup I/O bindings.
        """
        self.inputs = []
        self.outputs = []
        self.allocations = []
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            is_input = False
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                is_input = True
            dtype = self.engine.get_tensor_dtype(name)
            # shape = self.engine.get_tensor_shape(name)
            shape = shape_dict[name]
            if is_input:
                self.context.set_input_shape(name, shape)
                self.batch_size = shape[0]
            size = np.dtype(trt.nptype(dtype)).itemsize
            for s in shape:
                size *= s
            allocation = common.cuda_call(cudart.cudaMalloc(size))
            binding = {
                "index": i,
                "name": name,
                "dtype": np.dtype(trt.nptype(dtype)),
                "shape": list(shape),
                "allocation": allocation,
            }
            self.allocations.append(allocation)
            if is_input:
                self.inputs.append(binding)
            else:
                self.outputs.append(binding)

        assert self.batch_size > 0
        assert len(self.inputs) > 0
        assert len(self.outputs) > 0
        assert len(self.allocations) > 0
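
    # Illustrative note: each device buffer is itemsize * prod(shape) bytes; a
    # float16 input of shape (1, 16, 17, 32, 32) therefore needs
    # 2 * 16 * 17 * 32 * 32 = 557,056 bytes of GPU memory via cudaMalloc.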

    def input_spec(self):
        """
        Get the specs for the input tensor of the network. Useful to prepare memory allocations.
        :return: Two items, the shape of the input tensor and its (numpy) datatype.
        """
        return self.inputs[0]["shape"], self.inputs[0]["dtype"]

    def output_spec(self):
        """
        Get the specs for the output tensor of the network. Useful to prepare memory allocations.
        :return: Two items, the shape of the output tensor and its (numpy) datatype.
        """
        return self.outputs[0]["shape"], self.outputs[0]["dtype"]

    def __call__(self, batch, top=1):
        """
        Execute inference.
        """
        # Prepare the output data
        device = batch.device
        dtype = batch.dtype
        batch = batch.cpu().numpy()

        def get_output_shape(shp):
            b, c, t, h, w = shp
            out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8)
            return out

        shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)}
        self.alloc(shp_dict)
        output = np.zeros(*self.output_spec())
        # Process I/O and execute the network
        common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
        self.context.execute_v2(self.allocations)
        common.memcpy_device_to_host(output, self.outputs[0]["allocation"])
        output = torch.from_numpy(output).to(device).type(dtype)
        return output
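
    # Illustrative note: get_output_shape mirrors the causal VAE's upsampling,
    # T_out = 4 * (t - 1) + 1 with 8x spatial scaling, so a latent batch of
    # shape (1, 16, 17, 32, 32) decodes to a video of shape (1, 3, 65, 256, 256).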

    @staticmethod
    def export_to_onnx(decoder: torch.nn.Module, model_dir):
        logger.info("Starting VAE ONNX export.")
        device = next(decoder.parameters())[0].device
        example_inp = torch.rand(1, 16, 17, 32, 32).to(device).type(next(decoder.parameters())[0].dtype)
        out_path = str(Path(str(model_dir)) / "vae_decoder.onnx")
        torch.onnx.export(
            decoder.eval().half(),
            example_inp.half(),
            out_path,
            input_names=['inp'],
            output_names=['out'],
            opset_version=14,
            dynamic_axes={
                "inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"},
                "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"},
            },
        )
        # onnx_ori = onnx.load(out_path)
        os.system(f"onnxsim {out_path} {out_path}")
        # onnx_opt, check = simplify(onnx_ori)
        # assert check, f"Simplified ONNX model({out_path}) could not be validated."
        # onnx.save(onnx_opt, out_path)
        logger.info("Finished VAE ONNX export.")
        return out_path
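
    # Illustrative note: dynamic_axes leaves the batch dimension fixed at 1 while
    # letting channel, time, height and width vary; the `onnxsim` CLI is assumed
    # to be installed on PATH and simplifies the exported graph in place.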

    @staticmethod
    def convert_to_trt_engine(onnx_path, engine_path):
        logger.info("Starting VAE ONNX to TensorRT engine conversion.")
        cmd = (
            "trtexec "
            f"--onnx={onnx_path} "
            f"--saveEngine={engine_path} "
            "--allowWeightStreaming "
            "--stronglyTyped "
            "--fp16 "
            "--weightStreamingBudget=100 "
            "--minShapes=inp:1x16x9x18x16 "
            "--optShapes=inp:1x16x17x32x16 "
            "--maxShapes=inp:1x16x17x32x32 "
        )
        p = Popen(cmd, shell=True)
        p.wait()
        if not Path(engine_path).exists():
            raise RuntimeError(f"Converting VAE ONNX ({onnx_path}) to TensorRT engine failed.")
        logger.info("Finished VAE TensorRT conversion.")
        return engine_path
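
    # Illustrative note: the min/opt/max profiles bound the dynamic latent input
    # `inp` ([B, C, T, H, W]); latents outside 1x16x9x18x16 .. 1x16x17x32x32
    # would require rebuilding the engine with wider shape profiles.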
\ No newline at end of file
lightx2v/utils/__init__.py
0 → 100755
View file @
daf4c74e